diff --git a/.gitmodules b/.gitmodules index 432e45ba13..23833ffcc4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,16 @@ [submodule "eggroll"] path = eggroll url = https://github.com/WeBankFinTech/eggroll - branch = v2.5.0-alpha + branch = v3.0.0-beta [submodule "fate_client"] path = fate_client url = https://github.com/FederatedAI/FATE-Client - branch = v2.0.0-alpha + branch = v2.0.0-beta [submodule "fate_flow"] path = fate_flow url = https://github.com/FederatedAI/FATE-Flow - branch = v2.0.0-alpha + branch = v2.0.0-beta +[submodule "fate_test"] + path = fate_test + url = https://github.com/FederatedAI/FATE-Test.git + branch = v2.0.0-beta diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d45d6be974..bf91567c7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.5.2 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/psf/black diff --git a/Makefile b/Makefile index 39b65d374b..69dba8e2c9 100644 --- a/Makefile +++ b/Makefile @@ -3,16 +3,20 @@ mkfile_dir := $(dir $(mkfile_path)) .DEFAULT_GOAL:=help ##@ Dev -.PHONY: install-rust_paillier +.PHONY: install-rust install-rust_paillier: ## Install rust_paillier. - @echo "install rust_paillier" - @cd ${mkfile_dir} && maturin develop --release -m rust/tensor/rust_paillier/Cargo.toml + @echo "install fate_utils" + @cd ${mkfile_dir} && \ + . venv/bin/activate.sh && \ + maturin develop --release -m rust/fate_utils/Cargo.toml --target-dir build ##@ Build -.PHONY: build-rust_paillier -build-rust_paillier: ## Build rust_paillier. - @echo "build rust_paillier" - @cd ${mkfile_dir} && maturin build --release -m rust/tensor/rust_paillier/Cargo.toml --out dist --target-dir build +.PHONY: build-rust +build-rust: ## Build fate_utils. + @echo "build fate_utils" + @cd ${mkfile_dir} && \ + . ${mkfile_dir}venv/bin/activate && \ + maturin build --release -m rust/fate_utils/crates/fate_utils/Cargo.toml --out dist --target-dir build .PHONY: build-fate build-fate: ## Build fate diff --git a/README.md b/README.md index 8137eae25e..00d8306e91 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,18 @@ FATE is an open source project hosted by Linux Foundation. The [Technical Charte ### Version < 2.0 Releases history can be found in [releases](https://github.com/FederatedAI/FATE/releases), deployment resources can be found on [wiki](https://github.com/FederatedAI/FATE/wiki/Download) +### Version == 2.0.0-beta +#### Standalone deployment +- Deploying FATE on a single node via PyPI, pre-built docker images or installers. It is for simple testing purposes. Refer to this [guide](./deploy/standalone-deploy/). + +### Cluster deployment +Deploying FATE to multiple nodes to achieve scalability, reliability and manageability. + +- [Cluster deployment by CLI](./deploy/cluster-deploy): Using CLI to deploy a FATE cluster. + ### Quick Start -- [Training Demo With Installing From Pypi](doc/2.0/quick_start.md) +- [Training Demo With Installing FATE AND FATE-Flow From Pypi](doc/2.0/quick_start.md) +- [Training Demo With Installing FATE Only From Pypi](doc/2.0/fate/ml) ## Related Repositories (Projects) - [KubeFATE](https://github.com/FederatedAI/KubeFATE): An operational tool for the FATE platform using cloud native technologies such as containers and Kubernetes. 
@@ -38,6 +48,7 @@ Releases history can be found in [releases](https://github.com/FederatedAI/FATE/ - [AnsibleFATE](https://github.com/FederatedAI/AnsibleFATE): A tool to optimize and automate the configuration and deployment operations via Ansible. - [FATE-Builder](https://github.com/FederatedAI/FATE-Builder): A tool to build package and docker image for FATE and KubeFATE. - [FATE-Client](https://github.com/FederatedAI/FATE-Client): A tool to enable fast federated modeling tasks for FATE. +- [FATE-Test](https://github.com/FederatedAI/FATE-Test): An automated testing tool for FATE, including tests and benchmark comparisons. ## Governance diff --git a/README_zh.md b/README_zh.md index 557c526d8f..2fee074553 100644 --- a/README_zh.md +++ b/README_zh.md @@ -23,8 +23,16 @@ FATE于2019年2月首次对外开源,并成立 ### 2.0以前的版本 FATE 2.0以前的版本在[发布页](https://github.com/FederatedAI/FATE/releases), 下载资源汇总页在[wiki](https://github.com/FederatedAI/FATE/wiki/Download) +### 2.0.0-beta 版本 +#### 单机版部署 +在单节点上部署FATE单机版,支持从 PyPI 直接安装,docker,主机安装包三种方式。 +- [单机版部署教程](./deploy/standalone-deploy) +#### 集群 +- [原生集群安装](./deploy/cluster-deploy): Using CLI to deploy a FATE cluster. + ### 快速开始 -- [从Pypi下载安装并启动训练任务示例](doc/2.0/quick_start.md) +- [从 PyPI 下载安装 FATE 和 FATE-Flow 并启动训练任务示例](doc/2.0/quick_start.md) +- [从 PyPI 下载安装 FATE,并启动训练任务示例](doc/2.0/fate/ml) ## 关联仓库 - [KubeFATE](https://github.com/FederatedAI/KubeFATE) @@ -37,6 +45,7 @@ FATE 2.0以前的版本在[发布页](https://github.com/FederatedAI/FATE/releas - [AnsibleFATE](https://github.com/FederatedAI/AnsibleFATE) - [FATE-Builder](https://github.com/FederatedAI/FATE-Builder) - [FATE-Client](https://github.com/FederatedAI/FATE-Client) +- [FATE-Test](https://github.com/FederatedAI/FATE-Test) ## 社区治理 diff --git a/RELEASE.md b/RELEASE.md index e30b90e326..968b82c613 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,51 @@ +## Release 2.0.0-beta +### Major Features and Improvements +> Arch 2.0:Building Unified and Standardized API for Heterogeneous Computing Engines Interconnection +* Framework: PSI-ECDH protocol support, single entry for histogram statistical computation +* Protocol: Support for ECDH, Secure Aggregation protocols +* Tensor: abstracted PHETensor, smooth switch between various underlying PHE implementations through standard interface +* DataFrame: New data block manager supports mixed-type columns & feature anonymization; added 30+ operator interfaces for statistics, including comparison, indexing, data binning, and transformation, etc. +* Enhanced workflow: Support for Cross Validation workflow + +> Components 2.0: Building Standardized Algorithm Components for different Scheduling Engines +* Input-Output: Further decoupling of FATE-Flow, providing standardized black-box calling processes +* Component Definition: Support for typing-based definition, automatic checking for component parameters, support for multiple types of data and model input and output, in addition to multiple inputs + +> ML:-2.0: Major functionality migration from FATE-v1.x, decoupling call hierarchy +* Data preprocessing: Added DataFrame Transformer, Union and DataSplit migration completed +* Feature Engineering: Migrated HeteroFederatedBinning, HeteroFeatureSelection, DataStatistics, Sampling, FeatureScale +3. Federated Training: Migrated HeteroSecureBoost, HomoNN, vertical CoordinatedLogisticRegression, and CoordinatedLinearRegression +4. 
Evaluation: Migrated Evaluation + +> OSX(Open Site Exchange) 1.0: Building Open Platform for Cross-Site Communication Interconnection +* Improved HTTP/1.X protocol support, support for GRPC-to-HTTP transmission +* Support for TLS secure transmission protocol +* Added routing table configuration interface +* Added routing table connectivity automatic check +* Improved transmission function in cluster mode +* Enhanced flow control in cluster mode +* Support for simple interface authentication + +> FATE Flow 2.0: Building Open and Standardized Scheduling Platform for Scheduling Interconnection +* Migrated functions: data upload/download, process scheduling, component output data/model/metric management, multi-storage adaptation for models, authentication, authorization, feature anonymization, multi-computing/storage/communication engine adaptation, and system high availability +* Optimized process scheduling, with scheduling separated and customizable, and added priority scheduling +* Optimized algorithm component scheduling, dividing execution steps into preprocessing, running, and post-processing +* Optimized multi-version algorithm component registration, supporting registration for mode of components +* Optimized client authentication logic, supporting permission management for multiple clients +* Optimized RESTful interface, making parameter fields and types, return fields, and status codes clearer +* Decoupling the system layer from the algorithm layer, with system configuration moved from the FATE repository to the Flow repository +* Published FATE Flow package to PyPI and added service-level CLI for service management + +> Fate-Client 2.0: Building Scalable Federated DSL for Application Layer Interconnection And Providing Tools For Fast Federated Modeling. +* Migrated Flow CLI and Flow SDK +* Updated federated DSL IR: enhance IR, add DataWarehouse and ModelWarehouse to load data and model from other sources +* Update component definitions to support Fate-v2.0.0-beta +* Flow CLI and PipeLine share configuration + +> Fate-Test: FATE Automated Testing Tool +* Migrated automated testing for functionality, performance, and correctness + + ## Release 2.0.0-alpha ### Feature Highlights > Arch 2.0:Building Unified and Standardized API for Heterogeneous Computing Engines Interconnection diff --git a/deploy/cluster-deploy/allinone/fate-allinone_deployment_guide.zh.md b/deploy/cluster-deploy/allinone/fate-allinone_deployment_guide.zh.md new file mode 100644 index 0000000000..46aa2ec13d --- /dev/null +++ b/deploy/cluster-deploy/allinone/fate-allinone_deployment_guide.zh.md @@ -0,0 +1,745 @@ +# FATE AllinOne部署指南 + +[English](./fate-allinone_deployment_guide.zh.md) +## 1. 服务器配置 + +| 服务器 | | +| :------: | ------------------------------------------------------------ | +| 数量 | 1 or 2 | +| 配置 | 8 core /16GB memory / 500GB硬盘/10M带宽 | +| 操作系统 | CentOS linux 7.2及以上/Ubuntu 18.04 | +| 依赖包 | (部署时自动安装) | +| 用户 | 用户:app,属主:apps(app用户需可以sudo su root而无需密码) | +| 文件系统 | 1. 500G硬盘挂载在 `/data` 目录下; 2.创建 `/data/projects` 目录,目录属主为:`app:apps` | + +## 2. 
集群规划 + +| party | 主机名 | IP地址 | 操作系统 | 安装软件 | 服务 | +| ------ | ------------- | ----------- | ----------------------- | ------------------ | ------------------------------------------------------------ | +| PartyA | VM_0_1_centos | 192.168.0.1 | CentOS 7.2/Ubuntu 18.04 | fate,eggroll,mysql | fate_flow,fateboard,clustermanager,nodemanager,rollsite,mysql | +| PartyB | VM_0_2_centos | 192.168.0.2 | CentOS 7.2/Ubuntu 18.04 | fate,eggroll,mysql | fate_flow,fateboard,clustermanager,nodemanager,rollsite,mysql | + +架构图: + +|![](../../images/arch_zh.png)| +|:--:| +|架构图| + +## 3. 组件说明 + +| 软件产品 | 组件 | 端口 | 说明 | +| -------- | -------------- | --------- | ------------------------------------------------------------ | +| fate | fate_flow | 9360;9380 | 联合学习任务流水线管理模块 | +| fate | fateboard | 8080 | 联合学习过程可视化模块 | +| eggroll | clustermanager | 4670 | cluster manager管理集群 | +| eggroll | nodemanager | 4671 | node manager管理每台机器资源 | +| eggroll | rollsite | 9370 | 跨站点或者说跨party通讯组件,相当于以前版本的proxy+federation | +| mysql | mysql | 3306 | 数据存储,clustermanager和fateflow依赖 | + +## 4. 基础环境配置 + +### 4.1. hostname配置 + +**1)修改主机名** + +**在192.168.0.1 root用户下执行:** + +hostnamectl set-hostname VM_0_1_centos + +**在192.168.0.2 root用户下执行:** + +hostnamectl set-hostname VM_0_2_centos + +**2)加入主机映射** + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行:** + +vim /etc/hosts + +192.168.0.1 VM_0_1_centos + +192.168.0.2 VM_0_2_centos + +### 4.2. 关闭 SELinux(不推荐) + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行:** + +确认是否已安装selinux + +centos系统执行:rpm -qa | grep selinux + +ubuntu系统执行:apt list --installed | grep selinux + +如果已安装了selinux就执行:setenforce 0 + +### 4.3. 修改 Linux 系统参数 + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行:** + +1)清理20-nproc.conf文件 + +cd /etc/security/limits.d + +ls -lrt 20-nproc.conf + +存在则:mv 20-nproc.conf 20-nproc.conf_bak + +2)vim /etc/security/limits.conf + +\* soft nofile 65535 + +\* hard nofile 65535 + +\* soft nproc 65535 + +\* hard nproc 65535 + +重新登陆,ulimit -a查看是否生效 + +### 4.4. 关闭防火墙(不推荐) + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行** + +如果是Centos系统: + +systemctl disable firewalld.service + +systemctl stop firewalld.service + +systemctl status firewalld.service + +如果是Ubuntu系统: + +ufw disable + +ufw status + +### 4.5. 软件环境初始化 + +**1)创建用户** + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行** + +``` +groupadd -g 6000 apps +useradd -s /bin/bash -g apps -d /home/app app +passwd app +``` + +**2)配置sudo** + +**在目标服务器(192.168.0.1 192.168.0.2)root用户下执行** + +vim /etc/sudoers.d/app + +app ALL=(ALL) ALL + +app ALL=(ALL) NOPASSWD: ALL + +Defaults !env_reset + +**3)配置ssh无密登录** + +**a. 在目标服务器(192.168.0.1 192.168.0.2)app用户下执行** + +su app + +ssh-keygen -t rsa + +cat \~/.ssh/id_rsa.pub \>\> /home/app/.ssh/authorized_keys + +chmod 600 \~/.ssh/authorized_keys + +**b.合并id_rsa_pub文件** + +拷贝192.168.0.1的authorized_keys 到192.168.0.2 +\~/.ssh目录下,追加到192.168.0.2的id_rsa.pub到authorized_keys,然后再拷贝到192.168.0.1 + +**在192.168.0.1 app用户下执行** + +scp \~/.ssh/authorized_keys app\@192.168.0.2:/home/app/.ssh + +输入密码 + +**在192.168.0.2 app用户下执行** + +cat \~/.ssh/id_rsa.pub \>\> /home/app/.ssh/authorized_keys + +scp \~/.ssh/authorized_keys app\@192.168.0.1:/home/app/.ssh + +覆盖之前的文件 + +**c. 在目标服务器(192.168.0.1 192.168.0.2)app用户下执行ssh 测试** + +ssh app\@192.168.0.1 + +ssh app\@192.168.0.2 + +### 4.6. 
增加虚拟内存 + +**目标服务器(192.168.0.1 192.168.0.2 192.168.0.3)** + +生产环境使用时,因内存计算需要增加128G虚拟内存,执行前需检查存储空间是否足够。 + +手工创建,root用户执行: + +``` +cd /data +dd if=/dev/zero of=/data/swapfile128G bs=1024 count=134217728 +mkswap /data/swapfile128G +swapon /data/swapfile128G +cat /proc/swaps +echo '/data/swapfile128G swap swap defaults 0 0' >> /etc/fstab +``` + +或者使用5.1章节的代码包中的脚本创建,app用户执行: + +``` +bash /data/projects/fate_cluster_install_${version}_release/tools-install/makeVirtualDisk.sh +Waring: please make sure has enough space of your disk first!!! (请确认有足够的存储空间) +current user has sudo privilege(yes|no):yes (是否有sudo权限,输入yes,不能简写) +Enter store directory:/data (设置虚拟内存文件的存放路径,确保目录存在和不要设置在根目录) +Enter the size of virtual disk(such as 64G/128G):128G (设置虚拟内存文件的大小,32G的倍数,数字后要带单位G,一般设置为128G即可) +/data 32 1 +32768+0 records in +32768+0 records out +34359738368 bytes (34 GB) copied, 200.544 s, 171 MB/s +Setting up swapspace version 1, size = 33554428 KiB +no label, UUID=58ce153c-feac-4989-b684-c100e4edca0b +/data 32 2 +32768+0 records in +32768+0 records out +34359738368 bytes (34 GB) copied, 200.712 s, 171 MB/s +Setting up swapspace version 1, size = 33554428 KiB +no label, UUID=d44e27ed-966b-4477-b46e-fcda4e3057c2 +/data 32 3 +32768+0 records in +32768+0 records out +34359738368 bytes (34 GB) copied, 200.905 s, 171 MB/s +Setting up swapspace version 1, size = 33554428 KiB +no label, UUID=ab5db8d7-bc09-43fb-b23c-fc11aef1a3b6 +/data 32 4 +32768+0 records in +32768+0 records out +34359738368 bytes (34 GB) copied, 201.013 s, 171 MB/s +Setting up swapspace version 1, size = 33554428 KiB +no label, UUID=c125ede3-7ffd-4110-9dc8-ebdf4fab0fd1 +``` + +校验 + +``` +cat /proc/swaps + +Filename Type Size Used Priority +/data/swapfile32G_1 file 33554428 0 -1 +/data/swapfile32G_2 file 33554428 0 -2 +/data/swapfile32G_3 file 33554428 0 -3 +/data/swapfile32G_4 file 33554428 0 -4 + +free -m + total used free shared buff/cache available +Mem: 15715 6885 91 254 8739 8461 +Swap: 131071 0 131071 + +``` + +## 5. 项目部署 + + +注:此指导安装目录默认为/data/projects/,执行用户为app,安装时根据具体实际情况修改。 + +### 5.1. 获取项目 + +**在目标服务器(192.168.0.1 具备外网环境)app用户下执行** + +进入执行节点的/data/projects/目录,执行: + +备注:用具体FATE版本号替换${version},可在[release页面](https://github.com/FederatedAI/FATE/releases)上查看 + +``` +cd /data/projects/ +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/fate_cluster_install_${version}_release.tar.gz +tar xzf fate_cluster_install_${version}_release.tar.gz + +注意:version不带字符v,如fate_cluster_install_2.x.x_release.tar.gz +``` + +### 5.2. 部署前检查 + +**在目标服务器(192.168.0.1 192.168.0.2 )app用户下执行** + +把检查脚本fate_cluster_install_${version}_release/tools-install/check.sh从192.168.0.1拷贝到192.168.0.2 + +``` +#在192.168.0.1和192.168.0.2服务器上分别执行检查脚本 +bash ./check.sh + +#确认app用户已配置sudo +#虚拟内存,size不低于128G,如不满足需参考4.6章节重新设置 +#文件句柄数,不低于65535,如不满足需参考4.3章节重新设置 +#用户进程数,不低于64000,如不满足需参考4.3章节重新设置 +#确认部署前没有fate进程和端口冲突 +#确认/etc/my.cnf是否存在,存在需要mv;确认是否存在/data/projects/fate目录,存在需把fate目录mv备份。 +``` + +### 5.3. 配置文件修改和示例 + +**在目标服务器(192.168.0.1)app用户下执行** + +修改配置文件fate_cluster_install_${version}_release/allInone/conf/setup.conf. 
+ +``` +vi fate_cluster_install_${version}_release/allInone/conf/setup.conf +``` + +配置文件setup.conf说明: + +| 配置项 | 配置项值 | 说明 | +| ------------------- | ----------------------------------------------------- | ------------------------------------------------------------ | +| roles | 默认:"host" "guest" | 部署的角色,有HOST端、GUEST端 | +| version | 默认:${version} | Fate 版本号 | +| pbase | 默认: /data/projects | 项目根目录 | +| pname | 默认:fate | 项目名称 | +| lbase | 默认:/data/logs | 保持默认不要修改 | +| ssh_user | 默认:app | ssh连接目标机器的用户,也是部署后文件的属主 | +| ssh_group | 默认:apps | ssh连接目标的用户的属组,也是部署后文件的属组 | +| ssh_port | 默认:22,根据实际情况修改 | ssh连接端口,部署前确认好端口,不然会报连接错误 | +| eggroll_dbname | 默认:eggroll_meta | eggroll连接的DB名字 | +| fate_flow_dbname | 默认:fate_flow | fate_flow、fateboard等连接的DB名字 | +| mysql_admin_pass | 默认 | mysql的管理员(root)密码 | +| redis_pass | 默认 | redis密码,暂未使用 | +| mysql_user | 默认:fate | msyql的应用连接账号 | +| mysql_port | 默认:3306,根据实际情况修改 | msql服务监听的端口 | +| host_id | 默认 : 10000,根据实施规划修改 | HOST端的party id。 | +| host_ip | 192.168.0.1 | HOST端的ip | +| host_mysql_ip | 默认和host_ip保持一致 | HOST端mysql的ip | +| host_mysql_pass | 默认 | HOST端msyql的应用连接账号 | +| guest_id | 默认 : 9999,根据实施规划修改 | GUEST端的party id | +| guest_ip | 192.168.0.2 | GUEST端的ip | +| guest_mysql_ip | 默认和guest_ip保持一致 | GUEST端mysql的ip | +| guest_mysql_pass | 默认 | GUEST端msyql的应用连接账号 | +| dbmodules | 默认:"mysql" | DB组件的部署模块列表,如mysql | +| basemodules | 默认:"tools" "base" "java" "python" "eggroll" "fate" | 非DB组件的部署模块列表,如 "tools" "base"、 "java"、 "python" 、"eggroll" 、"fate" | +| fateflow_grpc_port | 默认:9360 | fateflow grpc服务端口 | +| fateflow_http_port | 默认:9380 | fateflow http服务端口 | +| fateboard_port | 默认:8080 | fateboard服务端口 | +| rollsite_port | 默认:9370 | rollsite服务端口 | +| clustermanager_port | 默认:4670 | clustermanager服务端口 | +| nodemanager_port | 默认:4671 | nodemanager服务端口 | + +**1)两台主机partyA+partyB同时部署**** + +``` +#to install role +roles=( "host" "guest" ) + +version="${version}" +#project base +pbase="/data/projects" +#project name +pname="fate" + +#log directory +lbase="/data/logs" + +#user who connects dest machine by ssh +ssh_user="app" +ssh_group="apps" +#ssh port +ssh_port=22 + +#eggroll_db name +eggroll_dbname="eggroll_meta" +#fate_flow_db name +fate_flow_dbname="fate_flow" + +#mysql init root password +mysql_admin_pass="fate_dev" + +#redis passwd +redis_pass="" + +#mysql user +mysql_user="fate" +#mysql port +mysql_port="3306" + +#host party id +host_id="10000" +#host ip +host_ip="192.168.0.1" +#host mysql ip +host_mysql_ip="${host_ip}" +host_mysql_pass="fate_deV2999" + +#guest party id +guest_id="9999" +#guest ip +guest_ip="192.168.0.2" +#guest mysql ip +guest_mysql_ip="${guest_ip}" +guest_mysql_pass="fate_deV2999" + +#db module lists +dbmodules=( "mysql" ) + +#base module lists +basemodules=( "tools" "base" "java" "python" "eggroll" "fate" ) + +fateflow_grpc_port=9360 +fateflow_http_port=9380 +fateboard_port=8080 + +rollsite_port=9370 +clustermanager_port=4670 +nodemanager_port=4671 +``` + +**2)只部署一个party** + +``` +#to install role +roles=( "host" ) + +version="${version}" +#project base +pbase="/data/projects" +#project name +pname="fate" + +#log directory +lbase="/data/logs" + +#user who connects dest machine by ssh +ssh_user="app" +ssh_group="apps" +#ssh port +ssh_port=22 + +#eggroll_db name +eggroll_dbname="eggroll_meta" +#fate_flow_db name +fate_flow_dbname="fate_flow" + +#mysql init root password +mysql_admin_pass="fate_dev" + +#redis passwd +redis_pass="" + +#mysql user +mysql_user="fate" +#mysql port +mysql_port="3306" + +#host party id +host_id="10000" +#host ip 
+host_ip="192.168.0.1" +#host mysql ip +host_mysql_ip="${host_ip}" +host_mysql_pass="fate_deV2999" + +#guest party id +guest_id="" +#guest ip +guest_ip="" +#guest mysql ip +guest_mysql_ip="${guest_ip}" +guest_mysql_pass="" + +#db module lists +dbmodules=( "mysql" ) + +#base module lists +basemodules=( "tools" "base" "java" "python" "eggroll" "fate" ) + +fateflow_grpc_port=9360 +fateflow_http_port=9380 +fateboard_port=8080 + +rollsite_port=9370 +clustermanager_port=4670 +nodemanager_port=4671 +``` + +### 5.4. 部署 + +按照上述配置含义修改setup.conf文件对应的配置项后,然后在fate_cluster_install_${version}_release/allInone目录下执行部署脚本: + +``` +cd fate_cluster_install_${version}_release/allInone +nohup bash ./deploy.sh > logs/boot.log 2>&1 & +``` + +部署日志输出在fate_cluster_install_${version}_release/allInone/logs目录下,实时查看是否有报错: + +``` +tail -f ./logs/deploy.log (部署结束,查看一下即可) +tail -f ./logs/deploy-guest.log (实时打印GUEST端的部署情况) +tail -f ./logs/deploy-mysql-guest.log (实时打印GUEST端mysql的部署情况) +tail -f ./logs/deploy-host.log (实时打印HOST端的部署情况) +tail -f ./logs/deploy-mysql-host.log (实时打印HOST端mysql的部署情况) +``` + +### 5.5. 问题定位 + +1)eggroll日志 + + /data/projects/fate/eggroll/logs/eggroll/bootstrap.clustermanager.err + +/data/projects/fate/eggroll/logs/eggroll/clustermanager.jvm.err.log + +/data/projects/fate/eggroll/logs/eggroll/nodemanager.jvm.err.log + +/data/projects/fate/eggroll/logs/eggroll/bootstrap.nodemanager.err + +/data/projects/fate/eggroll/logs/eggroll/bootstrap.rollsite.err + +/data/projects/fate/eggroll/logs/eggroll/rollsite.jvm.err.log + +2)fateflow日志 + +/data/projects/fate/fate_flow/logs/fate_flow + +3)fateboard日志 + +/data/projects/fate/fateboard/logs + +## 6.测试 + +### 6.1. Toy_example部署验证 + +此测试您需要设置2个参数:gid(guest partyid),hid(host_partyid)。 + +#### 6.1.1. 单边测试 + +1)192.168.0.1上执行,gid和hid都设为10000: + +``` +source /data/projects/fate/fate_flow/bin/init_env.sh +flow test toy -gid 10000 -hid 10000 +``` + +类似如下结果表示成功: + +toy test job xxx is success + +提示:如出现max cores per job is 1, please modify job parameters报错提示,需要修改运行时参数task_cores为1,增加命令行参数 '--task-cores 1'. + +2)192.168.0.2上执行,gid和hid都设为9999: + +``` +source /data/projects/fate/fate_flow/bin/init_env.sh +flow test toy -gid 9999 -hid 9999 +``` + +类似如下结果表示成功: + +toy test job 202308291022025779790 is success + +#### 6.1.2 双边测试 + +选定9999为guest方,在192.168.0.2上执行: + +``` +source /data/projects/fate/fate_flow/bin/init_env.sh +flow test toy -gid 9999 -hid 10000 +``` + +类似如下结果表示成功: + +toy test job 202308291022025779790 is success + +## 7.系统运维 + +### 7.1. 服务管理 + +**在目标服务器(192.168.0.1 192.168.0.2)app用户下执行** + +#### 7.1.1. Mysql服务管理 + +启动/关闭/查看/重启mysql服务 + +```bash +cd /data/projects/fate/common/mysql/mysql-* +bash ./service.sh start|stop|status|restart +``` + +#### 7.1.2. Eggroll服务管理 + +```bash +source /data/projects/fate/fate_flow/bin/init_env.sh +cd /data/projects/fate/eggroll +``` + +启动/关闭/查看/重启所有: + +```bash +bash ./bin/eggroll.sh all start/stop/status/restart +``` + +启动/关闭/查看/重启单个模块(可选:clustermanager,nodemanager,rollsite): + +```bash +bash ./bin/eggroll.sh clustermanager start/stop/status/restart +``` + +#### 7.1.3. 
Fate服务管理 + +1) 启动/关闭/查看/重启fate_flow服务 + +```bash +source /data/projects/fate/fate_flow/bin/init_env.sh +cd /data/projects/fate/fate_flow/bin +bash service.sh start|stop|status|restart +``` + +如果逐个模块启动,需要先启动eggroll再启动fateflow,fateflow依赖eggroll的启动。 + +2) 启动/关闭/重启fateboard服务 + +```bash +cd /data/projects/fate/fateboard +bash service.sh start|stop|status|restart +``` + +3) 启动/关闭/重启osx服务 + +```bash +cd /data/projects/fate/fate/proxy/osx +bash service.sh start|stop|status|restart +``` + +如果需要启动rollsite,需要先停用osx再启动rollsite,默认启动osx。 + +### 7.2. 查看进程和端口 + +**在目标服务器(192.168.0.1 192.168.0.2 )app用户下执行** + +#### 7.2.1. 查看进程 + +根据部署规划查看进程是否启动 + +```bash +ps -ef | grep -i clustermanager +ps -ef | grep -i nodemanager +ps -ef | grep -i rollsite +ps -ef | grep -i fate_flow_server.py +ps -ef | grep -i fateboard +ps -ef | grep -i osx +``` + +#### 7.2.2. 查看进程端口 + +根据部署规划查看进程端口是否存在 + +```bash +#clustermanager +netstat -tlnp | grep 4670 +#nodemanager +netstat -tlnp | grep 4671 +#rollsite or osx +netstat -tlnp | grep 9370 +#fate_flow_server +netstat -tlnp | grep 9360 +#fateboard +netstat -tlnp | grep 8080 +``` + +### 7.3. 服务日志 + +| 服务 | 日志路径 | +| ------------------ | ----------------------------------------------- | +| eggroll | /data/projects/fate/eggroll/logs | +| fate_flow&任务日志 | /data/projects/fate/fate_flow/logs | +| fateboard | /data/projects/fate/fateboard/logs | +| mysql | /data/projects/fate/common/mysql/mysql-*/logs | +| osx | /data/projects/fate/fate/proxy/osx/logs/broker/ | + +### 7.4. 空间清理规则 + +#### 7.4.1. fate_flow作业日志 + +所在机器:fate_flow服务所在机器 + +目录:`/data/projects/fate/fate_flow/logs` + +保留期限:N=14天 + +规则:目录以 `$jobid` 开头,清理` $jobid`为 N天前的数据 + +```bash +find /data/projects/fate/fate_flow/logs/ -maxdepth 1 -mindepth 1 -mtime +N -type d ! -path "*/fate_flow" | xargs rm -rf +``` + +#### 7.4.2. fate_flow系统日志 + +所在机器:fate_flow服务所在机器 + +目录:`/data/projects/fate/fate_flow/logs/fate_flow` + +保留期限:N=14天 + +规则:以日期结尾,清理日期为 N天前的数据 + +```bash +find /data/projects/fate/fate_flow/logs/fate_flow/ -maxdepth 1 -mtime +N -name "*.log.*" | xargs rm -rf +``` + +#### 7.4.3. EggRoll Session日志 + +所在机器:eggroll node节点 + +目录:`/data/projects/fate/eggroll/logs/` + +保留期限:N=14天 + +规则:目录以 `$jobid` 开头,清理 `$jobid` 为 N天前的数据 + +```bash +find /data/projects/fate/eggroll/logs/ -maxdepth 1 -mindepth 1 -mtime +N -type d ! -path "*/eggroll" | xargs rm -rf +``` + +#### 7.4.4. EggRoll系统日志 + +所在机器:eggroll node节点 + +目录:`/data/projects/fate/eggroll/logs/eggroll` + +保留期限:N=14天 + +规则:以日期结尾和以年份建立的历史文件夹中文件,清理N天前的数据 + +```bash +find /data/projects/fate/eggroll/logs/eggroll/ -maxdepth 1 -mtime +N -name "*.log.*" | xargs rm -rf +``` + +#### 7.4.5. 计算临时数据 + +所在机器:eggroll node节点 + +目录:`/data/projects/fate/eggroll/data/IN_MEMORY` + +保留期限:N=7天 + +规则:namespace以 `$jobid` 开头,清理 `$jobid` 为 N天前的数据 + +```bash +find /data/projects/fate/eggroll/data/IN_MEMORY/ -maxdepth 1 -mindepth 1 -mtime +N -type d | xargs rm -rf +``` + +#### 7.4.6. 
作业组件输出数据 + +所在机器:eggroll node节点 + +目录:/data/projects/fate/eggroll/data/LMDB + +保留期限:N=14天 + +规则:namespace以 `output_data_$jobid` 开头,清理 `$jobid` 为N天前的数据 + +```bash +find /data/projects/fate/eggroll/data/LMDB/ -maxdepth 1 -mindepth 1 -mtime +N -type d -name "output_data_*" | xargs rm -rf +``` diff --git a/deploy/cluster-deploy/allinone/fate-exchange_deployment_guide.zh.md b/deploy/cluster-deploy/allinone/fate-exchange_deployment_guide.zh.md new file mode 100644 index 0000000000..199e35f013 --- /dev/null +++ b/deploy/cluster-deploy/allinone/fate-exchange_deployment_guide.zh.md @@ -0,0 +1,529 @@ +# FATE exchange部署指南 +[English](./fate-exchange_deployment_guide.md) + +1.服务器配置 +============ + +| 服务器 | | +| :------: | ------------------------------------------------------------ | +| 数量 | 1(根据实际情况配置) | +| 配置 | 8 core /16GB memory / 500GB硬盘/10M带宽 | +| 操作系统 | CentOS linux 7.2及以上/Ubuntu 18.04 | +| 依赖包 | (参见4.5 软件环境初始化) | +| 用户 | 用户:app,属主:apps(app用户需可以sudo su root而无需密码) | +| 文件系统 | 1. 500G硬盘挂载在/ data目录下; 2.创建/ data / projects目录,目录属主为:app:apps | + +2.集群规划 +========== + +| party | partyid | 主机名 | IP地址 | 操作系统 | 安装软件 | 服务 | +| -------- | -------- | ------------- | ----------- | ----------------------- | -------- | ---- | +| exchange | exchange | VM_0_1_centos | 192.168.0.1 | CentOS 7.2/Ubuntu 18.04 | eggroll | osx | + + +# 3.组件说明 + +| 软件产品 | 组件 | 端口 | 说明 | +| -------- | -------- | ---- | ------------------------------------------------------------ | +| eggroll | rollsite | 9370 | 跨站点或者说跨party通讯组件,相当于proxy+federation,每个party只能有一个此服务 | + +4.基础环境配置 +============== + +4.1 hostname配置 +---------------- + +**1)修改主机名** + +**在192.168.0.1 root用户下执行:** + +hostnamectl set-hostname VM_0_1_centos + +**2)加入主机映射** + +**在目标服务器(192.168.0.1)root用户下执行:** + +vim /etc/hosts + +192.168.0.1 VM_0_1_centos + +4.2 关闭selinux +--------------- + +**在目标服务器(192.168.0.1)root用户下执行:** + +确认是否已安装selinux + +centos系统执行:rpm -qa | grep selinux + +ubuntu系统执行:apt list --installed | grep selinux + +如果已安装了selinux就执行:setenforce 0 + +4.3 修改Linux系统参数 +--------------------------- + +**在目标服务器(192.168.0.1)root用户下执行:** + +1)vim /etc/security/limits.conf + +\* soft nofile 65535 + +\* hard nofile 65535 + +2)vim /etc/security/limits.d/20-nproc.conf + +\* soft nproc unlimited + +4.4 关闭防火墙 +-------------- + +**在目标服务器(192.168.0.1)root用户下执行** + +如果是Centos系统: + +systemctl disable firewalld.service + +systemctl stop firewalld.service + +systemctl status firewalld.service + +如果是Ubuntu系统: + +ufw disable + +ufw status + +4.5 软件环境初始化 +------------------ + +**在目标服务器(192.168.0.1)root用户下执行** + +**1)创建用户** + +``` +groupadd -g 6000 apps +useradd -s /bin/bash -g apps -d /home/app app +passwd app +``` + +**2)创建目录** + +``` +mkdir -p /data/projects/fate +mkdir -p /data/projects/install +chown -R app:apps /data/projects +``` + +**3)安装依赖** + +``` +#centos +yum -y install gcc gcc-c++ make openssl-devel gmp-devel mpfr-devel libmpc-devel libaio numactl autoconf automake libtool libffi-devel snappy snappy-devel zlib zlib-devel bzip2 bzip2-devel lz4-devel libasan lsof sysstat telnet psmisc +#ubuntu +apt-get install -y gcc g++ make openssl supervisor libgmp-dev libmpfr-dev libmpc-dev libaio1 libaio-dev numactl autoconf automake libtool libffi-dev libssl1.0.0 libssl-dev liblz4-1 liblz4-dev liblz4-1-dbg liblz4-tool zlib1g zlib1g-dbg zlib1g-dev +cd /usr/lib/x86_64-linux-gnu +if [ ! 
-f "libssl.so.10" ];then + ln -s libssl.so.1.0.0 libssl.so.10 + ln -s libcrypto.so.1.0.0 libcrypto.so.10 +fi +``` + +5.项目部署 +========== + +注:此指导安装目录默认为/data/projects/install,执行用户为app,安装时根据具体实际情况修改。 + +5.1 获取安装包 +------------ + +在目标服务器(192.168.0.1 具备外网环境)app用户下执行: +备注:用具体FATE版本号替换${version} +``` +cd /data/projects/install +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/resources/jdk-8u345.tar.xz +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/fate_install_${version}_release.tar.gz + +注意:version不带字符v,如fate_install_2.x.x_release.tar.gz +``` + +## 5.2 操作系统参数检查 + +**在目标服务器(192.168.0.1)app用户下执行** + +``` +#文件句柄数,不低于65535,如不满足需参考4.3章节重新设置 +ulimit -n +65535 + +#用户进程数,不低于64000,如不满足需参考4.3章节重新设置 +ulimit -u +65535 +``` + +## 5.3 部署jdk + +**在目标服务器(192.168.0.1)app用户下执行**: + +``` +#创建jdk安装目录 +mkdir -p /data/projects/fate/common/jdk +#解压缩 +cd /data/projects/install +tar -xvJf jdk-8u345.tar.xz -C /data/projects/fate/common/jdk +cd /data/projects/fate/common/jdk +``` + + +5.4 部署rollsite(rollsite和osx二选一) +-------- + +### **5.4.1软件部署** + +``` +#部署软件 +#在目标服务器(192.168.0.1)app用户下执行: +cd /data/projects/install +tar xf fate_install_*.tar.gz +cd fate_install_* +tar xvf eggroll.tar.gz -C /data/projects/fate +tar xvf bin.tar.gz -C /data/projects/fate + +#设置环境变量文件 +#在目标服务器(192.168.0.1)app用户下执行: +cat >/data/projects/fate/bin/init_env.sh < /data/projects/fate/eggroll/conf/eggroll.properties < /data/projects/fate/eggroll/conf/route_table.json << EOF +{ + "route_table": + { + "9999": + { + "default":[ + { + "port": 9370, + "ip": "192.168.0.2" + } + ] + }, + "10000": + { + "default":[ + { + "port": 9370, + "ip": "192.168.0.3" + } + ] + } + }, + "permission": + { + "default_allow": true + } +} +EOF +``` + +### 5.4.4 各party默认rollsite路由信息修改 + +**需要连接exchange的各party的rollsite模块,app用户修改** + +修改/data/projects/fate/eggroll/conf/route_table.json部分,默认路由信息指向部署好的exchange,不需要配置对端fateflow信息,修改后需重启rollsite: + +``` + "default": { + "default": [ + { + "ip": "192.168.0.1", + "port": 9370 + } + ] + } +``` + +## 5.5 部署osx(rollsite和osx二选一) + +### 5.5.1 软件部署 + +``` +#部署软件 +#在目标服务器(192.168.0.1)app用户下执行: +cd /data/projects/install +tar xf fate_install_*.tar.gz +cd fate_install_* +tar xvf bin.tar.gz -C /data/projects/fate +mkdir -p /data/projects/fate/fate/proxy +tar xvf osx.tar.gz -C /data/projects/fate/fate/proxy + +#设置环境变量文件 +#在目标服务器(192.168.0.1)app用户下执行: +cat >/data/projects/fate/bin/init_env.sh < /data/projects/fate/fate/proxy/osx/conf/broker/broker.properties < /data/projects/fate/fate/proxy/osx/conf/broker/route_table.json << EOF +{ + "route_table": + { + "9999": + { + "default":[ + { + "port": 9370, + "ip": "192.168.0.2" + } + ] + }, + "10000": + { + "default":[ + { + "port": 9370, + "ip": "192.168.0.3" + } + ] + } + }, + "permission": + { + "default_allow": true + } +} +EOF +``` + +### 5.5.4 各party默认osx路由信息修改 + +**需要连接exchange的各party的osx模块,app用户修改** + +修改/data/projects/fate/fate/proxy/osx/conf/broker/route_table.json部分,默认路由信息指向部署好的exchange,不需要配置对端fateflow信息,修改后需重启osx: + +``` + "default": { + "default": [ + { + "ip": "192.168.0.1", + "port": 9370 + } + ] + } +``` + + + +## 5.6 启动服务 + +**在目标服务器(192.168.0.1)app用户下执行** + +``` +eggroll和osx二选一 +#启动osx服务时(默认使用osx,需先停用rollsite) +source /data/projects/fate/bin/init_env.sh +cd /data/projects/fate/eggroll +bash ./bin/eggroll.sh rollsite stop +cd /data/projects/fate/fate/proxy/osx +bash service.sh start + +#启动eggroll服务(使用rollsite,需先停用osx) +source /data/projects/fate/bin/init_env.sh +cd /data/projects/fate/fate/proxy/osx +bash service.sh stop 
+cd /data/projects/fate/eggroll +bash ./bin/eggroll.sh rollsite start +``` + +## 5.7 验证和问题定位 + +1)跑一个双边toy测试,看是否可以测试通过,通过则表示配置无误,具体用例参考allinone部署文档。 + +2)查看exchange日志,看第1步用例涉及到的partyid是否有路由信息, + +​ 日志:/data/projects/fate/eggroll/logs/eggroll/rollsite.jvm.log + +3)rollsite错误日志 + +​ /data/projects/fate/eggroll/logs/eggroll/bootstrap.rollsite.err + +​ /data/projects/fate/eggroll/logs/eggroll/rollsite.jvm.err.log + +4)osx错误日志 + +​ /data/projects/fate/fate/proxy/osx/logs/broker/broker-error.log + +6.系统运维 +================ + +6.1 服务管理 +------------ + +**在目标服务器(192.168.0.1)app用户下执行** + +### 6.1.1 rollsite服务管理 + +``` +cd /data/projects/fate/eggroll +``` + +启动/关闭/查看/重启rollsite: + +``` +bash ./bin/eggroll.sh rollsite start/stop/status/restart +``` + +### 6.1.2 osx服务管理 + +``` +cd /data/projects/fate/fate/proxy/osx +``` + +启动/关闭/查看/重启osx: + +``` +bash service.sh start/stop/status/restart +``` + +## 6.2 查看进程和端口 + +**在目标服务器(192.168.0.1)app用户下执行** + +### 6.2.1 查看进程 + +``` +#查看osx进程是否启动 +ps -ef | grep -i osx + +#查看rollsite进程是否启动 +ps -ef | grep -i rollsite +``` + +### 6.2.2 查看进程端口 + +``` +#查看进程端口是否存在 +#rollsite or osx +netstat -tlnp | grep 9370 +``` + + + +## 6.3 服务日志 + +| 服务 | 日志路径 | +| -------- | ----------------------------------------------- | +| rollsite | /data/projects/fate/eggroll/logs | +| osx | /data/projects/fate/fate/proxy/osx/logs/broker/ | + diff --git a/deploy/standalone-deploy/README.md b/deploy/standalone-deploy/README.md new file mode 100644 index 0000000000..3c53f59908 --- /dev/null +++ b/deploy/standalone-deploy/README.md @@ -0,0 +1,209 @@ +# FATE Single-Node Deployment Guide + +[中文](./README.zh.md) + +## 1. Introduction + +**Server Configuration:** + +- **Quantity:** 1 +- **Configuration:** 8 cores / 16GB memory / 500GB hard disk +- **Operating System:** CentOS Linux release 7 +- **User:** User: app owner:apps + +The single-node version provides 3 deployment methods, which can be selected based on your needs: + +- Install FATE from PyPI +- Install FATE using Docker Images +- Install FATE on the host machine (using pre-compiled installation packages) + +## 2. Install FATE from PyPI (Recommended) + +### 2.1 Installing Python Environment +- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment. +- Create a virtual environment: +```shell +# FATE requires Python >= 3.8 +conda create -n fate_env python=3.8 +conda activate fate_env +``` + +### 2.2 Installing FATE +This section introduces two ways to installing FATE from pypi, with and without FATE-Flow service. + +#### 2.2.1 Installing FATE With FATE-FLow Service +FATE-Flow provides federated job life cycle management, includes scheduling, data management, model and metric management, etc. + +##### 2.2.1.1 Installing FATE, FATE-Flow, FATE-Client +```shell +pip install fate_client[fate,fate_flow]==2.0.0.b0 +``` +#### 2.2.1.2 Service Initialization +```shell +fate_flow init --ip 127.0.0.1 --port 9380 --home $HOME_DIR +pipeline init --ip 127.0.0.1 --port 9380 +``` +- `ip`: The IP address where the service runs. +- `port`: The HTTP port the service runs on. +- `home`: The data storage directory, including data, models, logs, job configurations, and SQLite databases. + +#### 2.2.1.3 Start Fate-Flow Service + +```shell +fate_flow start +``` + +#### 2.2.1.4 Testing + +- [Test Items](#5-Test-Items) + +### 2.2.2 Installing FATE Directly +FATE provides multiple federated algorithm and secure protocols, +users can directly import fate and use built-in algorithms and secure protocols directly. 
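After completing the installation in the next step, a quick way to confirm the package is usable is an import smoke test. The only assumption below is that the `pyfate` distribution installs a top-level `fate` package, as stated above ("users can directly import fate"); the snippet does not start any service or run any algorithm.

```python
# Import smoke test only; checks that the `fate` package can be located.
import importlib.util

spec = importlib.util.find_spec("fate")
print("fate package importable:", spec is not None)
```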
+ +#### 2.2.2.1 Installing FATE +```shell +pip install pyfate==2.0.0.b0 +``` +#### 2.2.2.2 Using Guides +Refer to [examples](../../doc/2.0/fate/ml) + + +## 3. Install FATE using Docker Images + +**Note:** Replace `${version}` in the examples below with the actual version number. + +### 3.1 Pre-deployment Environment Check + +- The host machine should have access to the external network to pull installation packages and Docker images from public networks. +- Dependency on [Docker](https://download.docker.com/linux/). Docker version 18.09 is recommended. You can verify the Docker environment using the following command: `docker --version`. For Docker start/stop and other operations, refer to `docker --help`. +- Before execution, check if port 8080 is already occupied. If you need to re-execute, use Docker commands to delete previous containers and images. + +Set the necessary environment variables for deployment (note that environment variables set in this way are only valid for the current terminal session. If you open a new terminal session, such as logging in again or opening a new window, you will need to reset them). + +```bash +export version={FATE version number for this deployment, e.g., 2.0.0-beta} +``` + +Example: + +```bash +export version=2.0.0-beta +``` + +### 3.2 Pull Docker Images + +#### 3.2.1 Via Public Image Services + +```bash +# Docker Hub +docker pull federatedai/standalone_fate:${version} + +# Tencent Cloud Container Image +docker pull ccr.ccs.tencentyun.com/federatedai/standalone_fate:${version} +docker tag ccr.ccs.tencentyun.com/federatedai/standalone_fate:${version} federatedai/standalone_fate:${version} +``` + +#### 3.2.2 Via Image Package + +```bash +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/standalone_fate_docker_image_${version}_release.tar.gz +docker load -i standalone_fate_docker_image_${version}_release.tar.gz +docker images | grep federatedai/standalone_fate +``` + +If you see an image corresponding to `${version}`, it means the image download was successful. + +### 3.3 Start + +```bash +docker run -it --name standalone_fate -p 8080:8080 federatedai/standalone_fate:${version} +``` + +### 3.4 Testing + +```bash +source /data/projects/fate/fate_flow/bin/init_env.sh +``` + +- [Test Items](#5-Test-Items) + +## 4. Install FATE on the Host Machine (Using Pre-Compiled Installation Packages) + +**Note:** Replace `${version}` in the examples below with the actual version number. + +### 4.1 Pre-deployment Environment Check + +Check if local ports 8080, 9360, and 9380 are already occupied. + +```bash +netstat -apln|grep 8080; +netstat -apln|grep 9360; +netstat -apln|grep 9380 +``` + +Because operating system dependencies need to be installed, root privileges are required. You can execute the subsequent operations as the root user. If you don't use the root user, grant sudo privileges to the user you want to use: + +```bash +echo "{username to be used} ALL=(ALL) NOPASSWD:ALL" | tee /etc/sudoers.d/{username to be used} +``` + +### 4.2 Get Installation Package + +Download the installation package and unpack it. + +```bash +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/standalone_fate_install_${version}_release.tar.gz; +tar -xzvf standalone_fate_install_${version}_release.tar.gz +``` + +### 4.3 Installation + +Navigate to the unpacked directory and use `bin/init.sh` for installation. 
+ +This script will automatically: + +- Install necessary operating system dependency packages +- Install Python 3.6 environment +- Install Python package dependencies +- Install JDK environment +- Configure FATE environment variable scripts +- Configure FateFlow +- Configure Fateboard +- Install FATE client + +```bash +cd standalone_fate_install_${version}_release; +bash bin/init.sh init +``` + +### 4.4 Start + +```bash +bash bin/init.sh status +bash bin/init.sh start +``` + +### 4.5 Testing + +- Load environment variables + +```bash +source bin/init_env.sh +``` + +- [Test Items](#4-Test-Items) + +## 5. Test Items + +### 5.1 Toy Test + +```bash +flow test toy -gid 10000 -hid 10000 +``` + +If successful, the screen will display statements similar to the following: + +```bash +toy test job xxx is success +``` diff --git a/deploy/standalone-deploy/README.zh.md b/deploy/standalone-deploy/README.zh.md new file mode 100644 index 0000000000..7ee53f9b25 --- /dev/null +++ b/deploy/standalone-deploy/README.zh.md @@ -0,0 +1,210 @@ +# FATE 单机部署指南 + +[English](README.md) + +## 1. 说明 + +**服务器配置:** + +- **数量:** 1 +- **配置:** 8 核 / 16GB 内存 / 500GB 硬盘 +- **操作系统:** CentOS Linux release 7 +- **用户:** User: app owner:apps + +单机版提供 3 种部署方式,可以根据实际情况选择: + +- 从 PyPI 安装 FATE +- 使用 Docker 镜像安装 FATE +- 在主机中安装 FATE (使用已编译的安装包) + +## 2. 从 PyPI 安装 FATE +### 2.1 安装Python环境 +- [conda](https://docs.conda.io/projects/miniconda/en/latest/) 环境准备及安装 +- 创建虚拟环境 +```shell +# fate的运行环境为python>=3.8 +conda create -n fate_env python=3.8 +conda activate fate_env +``` + +### 2.2 安装 FATE +本节介绍从 PyPI 安装携带FATE-Flow服务,以及无服务直接使用FATE包的两种 FATE 安装方法 + +#### 2.2.1 安装 FATE,同时携带 FATE-Flow 服务 +FATE-Flow提供了联邦作业生命周期管理,包括调度、数据管理、模型和指标管理等。 + +##### 2.2.1.1 安装FATE、FATE-Flow、FATE-Client +```shell +pip install fate_client[fate,fate_flow]==2.0.0.b0 +``` + +#### 2.2.1.2 服务初始化 +```shell +fate_flow init --ip 127.0.0.1 --port 9380 --home $HOME_DIR +pipeline --ip 127.0.0.1 --port 9380 +``` +- `ip`:服务运行的ip +- `port`:服务运行的 http 端口 +- `home`:数据存储目录。主要包括:数据/模型/日志/作业配置/sqlite.db等内容 + +#### 2.2.1.3 服务启动 +```shell +fate_flow start +``` + +#### 2.2.1.4 测试 + +- [测试项](#5-测试项) + +### 2.2.2 直接安装FATE +FATE提供多种联邦算法和安全协议, 用户可以在安装 FATE 后直接使用内置算法和安全协议。 + +#### 2.2.1.1 安装 FATE +```shell +pip install pyfate==2.0.0.b0 +``` + +#### 2.2.2.2 使用指引 +参考 [examples](../../doc/2.0/fate/ml) + + +## 3. 
使用 Docker 镜像安装 FATE(推荐) + +建议使用 Docker 镜像,这样可以大大降低遇到问题的可能性。 + +**注意:** 请使用实际的版本号替换示例中的 `${version}`。 + +### 2.1 部署前环境检查 + +- 主机需要能够访问外部网络,从公共网络中拉取安装包和 Docker 镜像。 +- 依赖 [docker](https://download.docker.com/linux/),Docker 建议版本为 18.09。您可以使用以下命令验证 Docker 环境:`docker --version`。有关 Docker 的起停和其他操作,请参考 `docker --help`。 +- 在执行之前,请检查端口 8080 是否已被占用。如果要重新执行,请使用 Docker 命令删除以前的容器和镜像。 + +设置部署所需环境变量(注意,通过以下方式设置的环境变量仅在当前终端会话中有效。如果打开新的终端会话,例如重新登录或打开新窗口,请重新设置)。 + +```bash +export version={本次部署的 FATE 版本号, 如 2.0.0-beta} +``` + +示例: + +```bash +export version=2.0.0-beta +``` + +### 3.2 拉取镜像 + +#### 3.2.1 通过公共镜像服务 + +```bash +# Docker Hub +docker pull federatedai/standalone_fate:${version} + +# 腾讯云容器镜像 +docker pull ccr.ccs.tencentyun.com/federatedai/standalone_fate:${version} +docker tag ccr.ccs.tencentyun.com/federatedai/standalone_fate:${version} federatedai/standalone_fate:${version} +``` + +#### 3.2.2 通过镜像包 + +```bash +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/standalone_fate_docker_image_${version}_release.tar.gz +docker load -i standalone_fate_docker_image_${version}_release.tar.gz +docker images | grep federatedai/standalone_fate +``` + +如果您能看到对应 `${version}` 的镜像,则表示镜像下载成功。 + +### 3.3 启动 + +```bash +docker run -it --name standalone_fate -p 8080:8080 federatedai/standalone_fate:${version} +``` + +### 3.4 测试 + +```bash +source /data/projects/fate/fate_flow/bin/init_env.sh +``` + +- [测试项](#5-测试项) + +## 4. 在主机中安装 FATE(使用已编译的安装包) + +**注意:** 请使用实际的版本号替换示例中的 `${version}`。 + +### 4.1 部署前环境检查 + +检查本地端口 8080、9360 和 9380 是否被占用。 + +```bash +netstat -apln|grep 8080; +netstat -apln|grep 9360; +netstat -apln|grep 9380 +``` + +由于需要安装操作系统依赖包,所以需要 root 权限。您可以使用 root 用户执行后续操作,如果不使用 root 用户,请为要使用的用户分配 sudo 权限: + +```bash +echo "{要使用的用户名} ALL=(ALL) NOPASSWD:ALL" | tee /etc/sudoers.d/{要使用的用户名} +``` + +### 4.2 获取安装包 + +下载安装包并解压缩。 + +```bash +wget https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/${version}/release/standalone_fate_install_${version}_release.tar.gz; +tar -xzvf standalone_fate_install_${version}_release.tar.gz +``` + +### 4.3 安装 + +进入解压后的目录并使用 `bin/init.sh` 进行安装。 + +该脚本将自动完成以下任务: + +- 安装必要的操作系统依赖包 +- 安装 Python 3.6 环境 +- 安装 Python 包依赖 +- 安装 JDK 环境 +- 配置 FATE 环境变量脚本 +- 配置 FateFlow +- 配置 Fateboard +- 安装 FATE 客户端 + +```bash +cd standalone_fate_install_${version}_release; +bash bin/init.sh init +``` + +### 4.4 启动 + +```bash +bash bin/init.sh status +bash bin/init.sh start +``` + +### 4.5 测试 + +- 加载环境变量 + +```bash +source bin/init_env.sh +``` + +5 [测试项](#5-测试项) + +## 5. 
测试项 + +### 5.1 Toy 测试 + +```bash +flow test toy -gid 10000 -hid 10000 +``` + +如果成功,屏幕将显示类似下方的语句: + +```bash +toy test job xxx is success +``` diff --git a/doc/2.0/README.md b/doc/2.0/README.md index c0f069b844..5b0aa5341e 100644 --- a/doc/2.0/README.md +++ b/doc/2.0/README.md @@ -1,5 +1,5 @@ #### 文档索引 +- [FATE Flow V2.0 Quick Start](https://github.com/FederatedAI/FATE-Flow/blob/v2.0.0-beta/doc/quick_start.md) - [FATE V2.0 Quick Start](./quick_start.md) -- [FATE Flow V2.0 Quick Start](https://github.com/FederatedAI/FATE-Flow/blob/v2.0.0-alpha/doc/quick_start.md) -- [FATE FLOW V2.0 方案](https://github.com/FederatedAI/FATE-Flow/blob/v2.0.0-alpha/doc/2.0.0-alpha.md) -- [OSX方案](./osx/osx.md) \ No newline at end of file +- [OSX方案](./osx/osx.md) +- [FATE Components](./fate/components) \ No newline at end of file diff --git a/doc/2.0/fate/components/README.md b/doc/2.0/fate/components/README.md new file mode 100644 index 0000000000..aa21af877a --- /dev/null +++ b/doc/2.0/fate/components/README.md @@ -0,0 +1,41 @@ +# Federated Machine Learning + +[[中文](README.zh.md)] + +FATE-ML includes implementation of many common machine learning +algorithms on federated learning. All modules are developed in a +decoupling modular approach to enhance scalability. Specifically, we +provide: + +1. Federated Statistic: PSI, Union, Pearson Correlation, etc. +2. Federated Feature Engineering: Feature Sampling, Feature Binning, + Feature Selection, etc. +3. Federated Machine Learning Algorithms: LR, GBDT, DNN +4. Model Evaluation: Binary | Multiclass | Regression evaluation +5. Secure Protocol: Provides multiple security protocols for secure + multi-party computing and interaction between participants. + +## Algorithm List + +| Algorithm | Module Name | Description | Data Input | Data Output | Model Input | Model Output | +|--------------------------------------------------|------------------------|------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------|----------------------------------------------------------------------------|-------------------------------|--------------| +| [PSI](psi.md) | PSI | Compute intersect data set of multiple parties without leakage of difference set information. Mainly used in hetero scenario task. | input_data | output_data | | | +| [Sampling](sample.md) | Sample | Federated Sampling data so that its distribution become balance in each party.This module supports local and federation scenario. | input_data | output_data | | | +| [Data Split](data_split.md) | DataSplit | Split one data table into 3 tables by given ratio or count, this module supports local and federation scenario | input_data | train_output_data, validate_output_data, test_output_data | | | +| [Feature Scale](feature_scale.md) | FeatureScale | module for feature scaling and standardization. | train_data, test_data | train_output_data, test_output_data | input_model | output_model | +| [Data Statistics](statistics.md) | Statistics | This component will do some statistical work on the data, including statistical mean, maximum and minimum, median, etc. | input_data | output_data | | output_model | +| [Hetero Feature Binning](feature_binning.md) | HeteroFeatureBinning | With binning input data, calculates each column's iv and woe and transform data according to the binned information. 
| train_data, test_data | train_output_data, test_output_data | input_model | output_model | +| [Hetero Feature Selection](feature_selection.md) | HeteroFeatureSelection | Provide 3 types of filters. Each filters can select columns according to user config | train_data, test_data | train_output_data, test_output_data | input_models, input_model | output_model | +| [Coordinated-LR](logistic_regression.md) | CoordinatedLR | Build hetero logistic regression model through multiple parties. | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model | +| [Coordinated-LinR](linear_regression.md) | CoordinatedLinR | Build hetero linear regression model through multiple parties. | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model | +| [Homo-LR](logistic_regression.md) | HomoLR | Build homo logistic regression model through multiple parties. | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model | +| [Homo-NN](homo_nn.md) | HomoNN | Build homo neural network model through multiple parties. | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model | +| [Hetero Secure Boosting](ensemble.md) | HeteroSecureBoost | Build hetero secure boosting model through multiple parties | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model | +| [Evaluation](evaluation.md) | Evaluation | Output the model evaluation metrics for user. | input_data | | | | +| [Union](union.md) | Union | Combine multiple data tables into one. | input_data_list | output_data | | | + +## Secure Protocol + +- [Encrypt](secure_protocol.md#encrypt) + - [Paillier encryption](secure_protocol.md#paillier-encryption) +- [Diffie Hellman Key Exchange](secure_protocol.md#diffie-hellman-key-exchange) diff --git a/doc/2.0/fate/components/README.zh.md b/doc/2.0/fate/components/README.zh.md new file mode 100644 index 0000000000..3e73d4b5b2 --- /dev/null +++ b/doc/2.0/fate/components/README.zh.md @@ -0,0 +1,34 @@ +# 联邦机器学习 + +Federatedml模块包括许多常见机器学习算法联邦化实现。所有模块均采用去耦的模块化方法开发,以增强模块的可扩展性。具体来说,我们提供: + +1. 联邦统计: 包括隐私交集计算,并集计算 +2. 联邦特征工程:包括联邦采样,联邦特征分箱,联邦特征选择等。 +3. 联邦机器学习算法:包括横向和纵向的联邦LR, GBDT, DNN等 +4. 模型评估:提供对二分类,多分类,回归评估 +5. 
安全协议:提供了多种安全协议,以进行更安全的多方交互计算。
+
+## 算法清单
+
+| 算法 | 模块名 | 描述 | 数据输入 | 数据输出 | 模型输入 | 模型输出 |
+|--------------------------------------------------|------------------------|--------------------------------------------|-----------------------------------------------|----------------------------------------------------------------------------|-------------------------------|--------------|
+| [PSI](psi.md) | PSI | 计算两方的相交数据集,而不会泄漏任何差异数据集的信息。主要用于纵向任务。 | input_data | output_data | | |
+| [Sampling](sample.md) | Sample | 对数据进行联邦采样,使得数据分布在各方之间变得平衡。这一模块同时支持本地和联邦场景。 | input_data | output_data | | |
+| [Data Split](data_split.md) | DataSplit | 按照给定比例或数量,将一个数据表切分为训练、验证、测试三个数据表,支持本地和联邦场景。 | input_data | train_output_data, validate_output_data, test_output_data | | |
+| [Feature Scale](feature_scale.md) | FeatureScale | 特征归一化和标准化。 | train_data, test_data | train_output_data, test_output_data | input_model | output_model |
+| [Data Statistics](statistics.md) | Statistics | 对数据进行统计分析,包括均值、最大值、最小值、中位数等统计量。 | input_data | output_data | | output_model |
+| [Hetero Feature Binning](feature_binning.md) | HeteroFeatureBinning | 使用分箱的输入数据,计算每个列的iv和woe,并根据合并后的信息转换数据。 | train_data, test_data | train_output_data, test_output_data | input_model | output_model |
+| [Hetero Feature Selection](feature_selection.md) | HeteroFeatureSelection | 提供多种类型的filter。每个filter都可以根据用户配置选择列。 | train_data, test_data | train_output_data, test_output_data | input_models, input_model | output_model |
+| [Coordinated-LR](logistic_regression.md) | CoordinatedLR | 通过多方构建纵向逻辑回归模块。 | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model |
+| [Coordinated-LinR](linear_regression.md) | CoordinatedLinR | 通过多方建立纵向线性回归模块。 | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model |
+| [Homo-LR](logistic_regression.md) | HomoLR | 通过多方构建横向逻辑回归模块。 | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model |
+| [Homo-NN](homo_nn.md) | HomoNN | 通过多方构建横向神经网络模块。 | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model |
+| [Hetero Secure Boosting](ensemble.md) | HeteroSecureBoost | 通过多方构建纵向SecureBoost模型。 | train_data, validate_data, test_data, cv_data | train_output_data, validate_output_data, test_output_data, cv_output_datas | input_model, warm_start_model | output_model |
+| [Evaluation](evaluation.md) | Evaluation | 输出模型评估指标,供用户查看。 | input_data | | | |
+| [Union](union.md) | Union | 将多个数据表合并成一个。 | input_data_list | output_data | | |
+
+## 安全协议
+
+- [Encrypt](secure_protocol.md#encrypt)
+  - [Paillier encryption](secure_protocol.md#paillier-encryption)
+- [Diffie Hellman Key Exchange](secure_protocol.md#diffie-hellman-key-exchange)
diff --git a/doc/2.0/fate/components/data_split.md b/doc/2.0/fate/components/data_split.md
new file mode 100644
index 0000000000..84a41bb478
--- /dev/null
+++ b/doc/2.0/fate/components/data_split.md
@@ -0,0 +1,32 @@
+# Data Split
+
+Data Split module splits data into train, test, and/or validate
+sets of arbitrary sizes. The module is based on sampling methods.
+
+# Use
+
+Data Split supports local (same as homogeneous) and heterogeneous (only Guest has y) modes.
+
+Below lists supported split modes and scenarios.
+
+| Split Mode   | Federated Heterogeneous                                                   | Federated Homogeneous (Local)                                             |
+|--------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------|
+| Random       | [✓](../../../examples/pipeline/data_split/test_data_split.py)            | [✓](../../../examples/pipeline/data_split/test_data_split_multi_host.py) |
+| Stratified   | [✓](../../../examples/pipeline/data_split/test_data_split_stratified.py) | [✓](../../../examples/pipeline/data_split/test_data_split_stratified.py) |
+
+Data Split module takes a single data input as specified in the job config file
+and always outputs three tables (train, test, and validate
+data sets). Each data output may be used as input of another module. Below are the
+rules regarding set sizes:
+
+1. If all three set sizes are None, the
+   original data input will be split in the following ratio: 80% to train
+   set, 20% to validate set, and an empty test set;
+
+2. If only test size or
+   validate size is given, train size is set to the complement of the given
+   size;
+
+3. Only one of the three sizes is needed to split input data, but
+   all three may be specified. The module takes either int (instance count)
+   or float (fraction) value for set sizes, but mixed-type inputs are not accepted.
diff --git a/doc/2.0/fate/components/feature_binning.md b/doc/2.0/fate/components/feature_binning.md
new file mode 100644
index 0000000000..bc56d55c1f
--- /dev/null
+++ b/doc/2.0/fate/components/feature_binning.md
@@ -0,0 +1,48 @@
+# Hetero Feature Binning
+
+Feature binning or data binning is a data pre-processing technique. It
+can be used to reduce the effects of minor observation errors, calculate
+information values and so on.
+
+Currently, we provide quantile binning and bucket binning methods. To
+achieve the quantile binning approach, we use a special data structure
+mentioned in this
+[paper](https://www.researchgate.net/profile/Michael_Greenwald/publication/2854033_Space-Efficient_Online_Computation_of_Quantile_Summaries/links/0f317533ee009cd3f3000000/Space-Efficient-Online-Computation-of-Quantile-Summaries.pdf).
+Feel free to check out the detailed algorithm in the paper.
+
+As for calculating the federated iv and woe values, the following figure
+illustrates the principle.
+
+![Figure 1 (Federated Feature Binning
+Principle)](../images/binning_principle.png)
+
+As the figure shows, party B, which holds the data labels, encrypts its
+labels with additive homomorphic encryption and then sends them to A. A
+computes each bin's encrypted label sum and sends it back. Then B can
+calculate woe and iv based on the given information.
+
+For multiple hosts, it is similar to the one-host case. Guest sends its
+encrypted label information to all hosts, and each of the hosts
+calculates and sends back the statistical info.
+
+![Figure 2: Multi-Host Binning
+Principle](../images/multiple_host_binning.png)
+
+## Features
+
+1. Support Quantile Binning based on quantile summary algorithm.
+2. Support Bucket Binning.
+3. Support calculating woe and iv values.
+4. Support transforming data into bin indexes or woe values (guest only).
+5. Support multiple-host binning.
+6. Support asymmetric binning methods on Host & Guest sides.
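The per-bin label statistics exchanged in the figures above are all that the label holder needs to compute woe and iv locally. The following library-free sketch shows that arithmetic for a single binned feature, assuming the per-bin positive/negative label counts have already been recovered as described; note that sign conventions for woe differ between implementations.

```python
import math

# A minimal sketch of per-bin woe/iv computation, assuming the guest already
# holds the decrypted count of positive and negative labels in each bin.
# Uses the ln(event_rate / non_event_rate) convention; other tools flip the sign.
def woe_iv(bin_pos_counts, bin_neg_counts, eps=1e-6):
    total_pos = sum(bin_pos_counts)
    total_neg = sum(bin_neg_counts)
    woe, iv = [], 0.0
    for pos, neg in zip(bin_pos_counts, bin_neg_counts):
        pos_rate = max(pos / total_pos, eps)  # share of positive labels in this bin
        neg_rate = max(neg / total_neg, eps)  # share of negative labels in this bin
        w = math.log(pos_rate / neg_rate)
        woe.append(w)
        iv += (pos_rate - neg_rate) * w
    return woe, iv

# Example: 3 bins of one feature, 100 positive / 200 negative samples in total.
woe, iv = woe_iv([10, 30, 60], [120, 60, 20])
print([round(w, 3) for w in woe], round(iv, 3))
```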
+ +Below lists supported features with links to examples: + +| Cases | Scenario | +|--------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Input Data with Categorical Features | [bucket binning](../../../examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py)
[quantile binning](../../../examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py) | +| Output Data Transformed | [bin index](../../../examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py)
[woe value(guest-only)](../../../examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py) | +| Skip Metrics Calculation | [multi_host](../../../examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py) | + + diff --git a/doc/2.0/fate/components/feature_scale.md b/doc/2.0/fate/components/feature_scale.md new file mode 100644 index 0000000000..6702e9da36 --- /dev/null +++ b/doc/2.0/fate/components/feature_scale.md @@ -0,0 +1,17 @@ +# Feature Scale + +Feature scale is a process that scales each feature along its column. +The Feature Scale module supports min-max scale and standard scale. + +1. min-max scale: this estimator scales and translates each feature + individually such that it falls in the given range on the training set, + e.g. between the min and max value of each feature. +2. standard scale: standardize features by removing the mean and + scaling to unit variance. + +# Use + +| Scale Method | Federated Heterogeneous | +|--------------|------------------------------------------------------------------------| +| Min-Max | [✓](../../../examples/pipeline/sample/test_sample_unilateral.py) | +| Standard | [✓](../../../examples/pipeline/sample/test_sample_unilateral.py) | diff --git a/doc/2.0/fate/components/feature_selection.md b/doc/2.0/fate/components/feature_selection.md new file mode 100644 index 0000000000..dc40d88f2b --- /dev/null +++ b/doc/2.0/fate/components/feature_selection.md @@ -0,0 +1,57 @@ +# Hetero Feature Selection + +Feature selection is a process that selects a subset of features for +model construction. Taking advantage of feature selection can improve +model performance. + +In this version, we provide several filter methods for feature +selection. Note that the module works in a cascade manner: the +selected result of filter A is used as the input of the next filter B. +Users should pay attention to the listing order when +supplying multiple filters to the `filter_methods` param in job configuration. + +## Features + +Below lists available input models and their corresponding filter methods with links to examples: + +| Input Models | Filter Method | +|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| None | [manual](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py) | +| Binning | [iv_filter(threshold)](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py)
[iv_filter(top_k)](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py)
[iv_filter(top_percentile)](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py) | +| Statistic | [statistic_filter](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py) | + +Most of the filter methods above share the same set of configurable parameters. +Below lists their acceptable parameter values. + +| Filter Method | Parameter Name | metrics | filter_type | take_high | +|-------------------------------------|-------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------|--------------| +| IV Filter | filter_param | "iv" | "threshold", "top_k", "top_percentile" | True | +| Statistic Filter | statistic_param | "max", "min", "mean", "median", "std", "var", "coefficient_of_variance", "skewness", "kurtosis", "missing_count", "missing_ratio", quantile(e.g."95%") | "threshold", "top_k", "top_percentile" | True/False | + +1. iv\_filter: Use iv as the criterion to select features. Supports + three modes: threshold value, top-k and top-percentile. + + - threshold value: Filter out those columns whose iv is smaller + than the threshold. You can also set a different threshold for + each party. + - top-k: Sort features from larger iv to smaller and take the top + k features in the sorted result. + - top-percentile: Sort features from larger iv to smaller and + take the top percentile. + +2. statistic\_filter: Use statistic values calculated by the DataStatistics + component. Supports coefficient of variance, missing value, + percentile value, etc. You can pick the columns with higher or lower + statistic values as you need. + +3. manual: Manually indicate features that need to be filtered or kept. + +Besides, we support multi-host federated feature selection for iv +filters. Starting from ver 2.0.0-beta, all data sets obtain an anonymous header +during transformation from local files. Guest uses the iv filters' logic to judge +whether a feature is kept or not, then sends the filter results back to hosts. +During this selection process, Guest does not learn the real names of the host(s)' features. + +![Figure 4: Multi-Host Selection +Principle](../images/multi_host_selection.png) \ No newline at end of file diff --git a/doc/2.0/fate/components/linear_regression.md b/doc/2.0/fate/components/linear_regression.md new file mode 100644 index 0000000000..adce075e3f --- /dev/null +++ b/doc/2.0/fate/components/linear_regression.md @@ -0,0 +1,80 @@ +# Federated Linear Regression + +Linear Regression (LinR) is a simple statistical model widely used for +predicting continuous numbers. FATE provides Heterogeneous Linear +Regression (CoordinatedLinR). + +Below lists features of the Coordinated LinR model: + +| Linear Model | Multi-Host | Cross Validation | Warm-Start | +|-------------------|--------------------------------------------------------------------------------|------------------------------------------------------------------------|---------------------------------------------------------------| +| Hetero LinR | [✓](../../../examples/pipeline/coordinated_linr/test_linr_multi_host.py) | [✓](../../../examples/pipeline/coordinated_linr/test_linr_cv.py) | [✓](../../../examples/pipeline/test_linr_warm_start.py) | + +## Coordinated LinR + +CoordinatedLinR also supports multi-Host training. + +Here we simplify participants of the federation process into three +parties. Party A represents Guest, party B represents Host.
Party C, +which is also known as “Arbiter,” is a third party that works as the +coordinator. Party C is responsible for generating private and public +keys. + +The process of HeteroLinR training is shown below: + +![Figure 1 (Federated HeteroLinR +Principle)](../images/HeteroLinR.png) + +A sample alignment process is conducted before training. The sample +alignment process identifies overlapping samples in databases of all +parties. The federated model is built based on the overlapping samples. +The whole sample alignment process is conducted in encryption mode, and +so confidential information (e.g. sample ids) will not be leaked. + +In the training process, party A and party B each compute the elements +needed for the final gradients. Arbiter aggregates, calculates, and +transfers back the final gradients to the corresponding parties. For more +details on the secure model-building process, please refer to this +[paper](https://arxiv.org/pdf/1902.04885.pdf). + +## Features + +1. L1 & L2 regularization + +2. Mini-batch mechanism + +3. Weighted training + +4. Torch optimization methods: + + > - rmsprop: RMSProp + > - adadelta: AdaDelta + > - adagrad: AdaGrad + > - adam: Adam + > - adamw: AdamW + > - adamax: Adamax + > - asgd: ASGD + > - nadam: NAdam + > - radam: RAdam + > - rprop: RProp + > - sgd: gradient descent with arbitrary batch size + > + > Further algorithm details can refer to [this paper](https://arxiv.org/abs/1912.00513v2). + +5. Torch Learning Rate Scheduler methods: + + > - constant + > - step + > - linear + +6. Three convergence criteria: + + > - diff + > Use difference of loss between two iterations, not available + > for multi-host training + > + > - abs + > Use the absolute value of loss + > + > - weight\_diff + > Use difference of model weights + +7. Support multi-host modeling task. diff --git a/doc/2.0/fate/components/logistic_regression.md b/doc/2.0/fate/components/logistic_regression.md new file mode 100644 index 0000000000..4b95e282ef --- /dev/null +++ b/doc/2.0/fate/components/logistic_regression.md @@ -0,0 +1,112 @@ +# Federated Logistic Regression + +Logistic Regression (LR) is a widely used statistical model for +classification problems. FATE provides two modes of federated LR: +Homogeneous LR (HomoLR) and Heterogeneous LR (HeteroLR and Hetero_SSHE_LR). + +Below lists features of each LR model: + +| Linear Model | Multiclass(OVR) | Multi-Host | Cross Validation | Warm-Start | +|-----------------|-----------------------------------------------------------------------------|----------------------------------------------------------------------------|-------------------------------------------------------------------------|----------------------------------------------------------------------------| +| Hetero LR | [✓](../../../examples/pipeline/coordinated_lr/test_lr_multi_class.py) | [✓](../../../examples/pipeline/coordinated_lr/test_lr_multi_host.py) | [✓](../../../examples/pipeline/coordinated_lr/test_lr_cv.py) | [✓](../../../examples/pipeline/coordinated_lr/test_lr_warm_start.py) | +| Homo LR | [✓]() | [✓]() | [✓]() | [✓]() | + +We simplified the federation process into three parties. Party A +represents Guest, party B represents Host, while party C, which is also +known as "Arbiter", is a third party that holds a private key for each +party and works as a coordinator. + +## Homogeneous LR + +In Homogeneous LR (HomoLR), all parties hold data with the same feature space, +so no sample alignment is needed.
In each iteration of the training process, every party computes its local gradients and sends the +protected (encrypted or securely aggregated) gradients to +arbiter. The arbiter aggregates these gradients to form a federated +gradient that will then be distributed to all parties for updating their +local models. Similar to traditional LR, the training process will stop +when the federated model converges or the whole training process reaches +a predefined max-iteration threshold. More details are available in +[Practical Secure Aggregation for Privacy-Preserving Machine Learning](https://dl.acm.org/citation.cfm?id=3133982). + +## Coordinated LR + +The CoordinatedLR carries out the federated learning in a different way. As +shown in Figure 1, a sample alignment process is conducted before +training. This sample alignment process is to identify overlapping +samples stored in databases of the two involved parties. The federated +model is built based on those overlapping samples. The whole sample +alignment process will **not** leak confidential information (e.g., +sample ids) of the two parties since it is conducted in an encrypted +way. + +![Figure 1 (Federated HeteroLR Principle)](../images/HeteroLR.png) + +In the training process, party A and party B compute the elements +needed for the final gradients. Arbiter aggregates them, computes the +gradients, and then transfers them back to each party. More details are available in +[Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption](https://arxiv.org/abs/1711.10677). + +## Multi-host hetero-lr + +For the multi-host scenario, the gradient computation stays the same as in the +single-host case. However, we use the second-norm of the difference of +model weights between two consecutive iterations as the convergence +criterion. Since the arbiter can obtain the complete model weights, the +convergence decision is made on the Arbiter side. + +![Figure 2 (Federated Multi-host HeteroLR +Principle)](../images/hetero_lr_multi_host.png) + +## Features + +- Both Homo-LR and Hetero-LR (CoordinatedLR) + +> 1. L1 & L2 regularization +> +> 2. Mini-batch mechanism +> +> 3. Weighted training +> +> 4. Torch optimization methods: +> > - rmsprop: RMSProp +> > - adadelta: AdaDelta +> > - adagrad: AdaGrad +> > - adam: Adam +> > - adamw: AdamW +> > - adamax: Adamax +> > - asgd: ASGD +> > - nadam: NAdam +> > - radam: RAdam +> > - rprop: RProp +> > - sgd: gradient descent with arbitrary batch size +> +> 5. Torch Learning Rate Scheduler methods: +> > - constant +> > - step +> > - linear +> +> 6. Three convergence criteria: +> > - diff +> > Use difference of loss between two iterations, not available +> > for multi-host training; +> > +> > - abs +> > Use the absolute value of loss; +> > +> > - weight\_diff +> > Use difference of model weights +> +> 7. Support multi-host modeling task. + + +Hetero-LR extra features + +1. When modeling a multi-host task, only the "weight\_diff" convergence criterion is supported.
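To make the optimizer and scheduler names listed under Features above concrete, the sketch below constructs the corresponding plain PyTorch objects for a toy model. This is a local illustration only: the model shape, learning rates, and scheduler arguments are made-up values, and how these names are passed through FATE's job configuration is not shown here.

```python
import torch

# toy logistic-regression model; the input width (20) is made up for illustration
model = torch.nn.Sequential(torch.nn.Linear(20, 1), torch.nn.Sigmoid())

# the optimizer names listed above map to standard torch.optim classes
optimizers = {
    "sgd": torch.optim.SGD(model.parameters(), lr=0.1),
    "adam": torch.optim.Adam(model.parameters(), lr=0.01),
    "adamw": torch.optim.AdamW(model.parameters(), lr=0.01),
    "rmsprop": torch.optim.RMSprop(model.parameters(), lr=0.01),
    "adagrad": torch.optim.Adagrad(model.parameters(), lr=0.01),
}

# the scheduler names map to torch.optim.lr_scheduler classes
# (in practice you would attach only one scheduler to a given optimizer)
opt = optimizers["sgd"]
schedulers = {
    "constant": torch.optim.lr_scheduler.ConstantLR(opt, factor=0.5, total_iters=5),
    "step": torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5),
    "linear": torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.5, total_iters=10),
}

# one local training step, just to show the usual optimizer/scheduler call order
x, y = torch.randn(8, 20), torch.randint(0, 2, (8, 1)).float()
loss = torch.nn.functional.binary_cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
schedulers["step"].step()
```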
diff --git a/doc/2.0/fate/components/psi.md b/doc/2.0/fate/components/psi.md new file mode 100644 index 0000000000..53f65a2874 --- /dev/null +++ b/doc/2.0/fate/components/psi.md @@ -0,0 +1,18 @@ +# PSI + +This module implements PSI (Private Set Intersection) +based on the [elliptic curve Diffie-Hellman scheme](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie–Hellman). +ECDH mode currently uses [Curve25519](https://en.wikipedia.org/wiki/Curve25519), +which offers 128 bits of security with a key size of 256 bits. + +Below is an illustration of ECDH intersection. + +![Figure 1 (ECDH +PSI)](../images/ecdh_intersection.png) + +For details on how to hash a value to the given curve, +please refer [here](https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-10#section-6.7.1). + +Note that starting in ver 2.0.0-beta, uploaded data should always have a sample id and a match id; +for data sets that originally only contain an id, please specify 'extend_sid' in the upload config +as in this [example](../../../examples/pipeline/upload/test_upload.py). \ No newline at end of file diff --git a/doc/2.0/fate/components/sample.md b/doc/2.0/fate/components/sample.md new file mode 100644 index 0000000000..d33d0f8ba3 --- /dev/null +++ b/doc/2.0/fate/components/sample.md @@ -0,0 +1,11 @@ +# Federated Sampling + +The Sample module supports random sampling and stratified sampling. + +- `hetero_sync` should be set to True for the heterogeneous scenario; +- `replace` must be set to True if upsampling is needed. + +| Sample Type | Federated Heterogeneous | Federated Homogeneous(Local) | +|-------------------|-------------------------------------------------------------|--------------------------------------------------------------------------------| +| By Fraction | [✓](../../../examples/pipeline/sample/test_sample.py) | [✓](../../../examples/pipeline/sample/test_data_split_multi_host.py) | +| By Exact Number | [✓](../../../examples/pipeline/sample/test_sample.py) | [✓](../../../examples/pipeline/data_split/test_data_split_stratified.py) | diff --git a/doc/2.0/fate/components/statistics.md b/doc/2.0/fate/components/statistics.md new file mode 100644 index 0000000000..76dad423d0 --- /dev/null +++ b/doc/2.0/fate/components/statistics.md @@ -0,0 +1,27 @@ +# Data Statistics + +This component computes statistics on the data, including +mean, maximum and minimum, median, etc. + +The indicators that can be computed for each column are listed as follows. + +1. count: Number of data entries +2. sum: The sum of this column +3. mean: The mean of this column +4. variance/stddev: Variance and standard deviation of this column +5. median: Median of this column +6. min/max: Min and max value of this column +7. coefficient of variance: The formula is abs(stddev / mean) +8. missing\_count/missing\_ratio: Number and ratio of missing values in + this column +9. skewness: definition may be found + [here](https://en.wikipedia.org/wiki/Skewness) +10. kurtosis: definition may be found + [here](https://en.wikipedia.org/wiki/Kurtosis) +11. percentile: The value at a given percentile. Accepts 0% to 100%, where the + number before the "%" should be an integer. + +These statistic results can be used in feature selection as a criterion, +as in this [example](../../../examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py).
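As a purely local illustration of the indicators listed above, the following pandas sketch computes the same kind of per-column summary on made-up data (the column names and values are hypothetical; inside a FATE job the component produces these statistics as model output rather than a local DataFrame).

```python
import pandas as pd

# made-up data with one missing value
df = pd.DataFrame({"x0": [1.0, 2.0, None, 4.0], "x1": [3.0, 1.0, 0.5, 2.5]})

summary = pd.DataFrame({
    "count": df.count(),
    "sum": df.sum(),
    "mean": df.mean(),
    "stddev": df.std(),
    "variance": df.var(),
    "median": df.median(),
    "min": df.min(),
    "max": df.max(),
    "coefficient_of_variance": (df.std() / df.mean()).abs(),
    "missing_count": df.isna().sum(),
    "missing_ratio": df.isna().mean(),
    "skewness": df.skew(),
    "kurtosis": df.kurt(),
    "95%": df.quantile(0.95),
})
print(summary)
```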
diff --git a/doc/2.0/fate/components/union.md b/doc/2.0/fate/components/union.md new file mode 100644 index 0000000000..b3384f7805 --- /dev/null +++ b/doc/2.0/fate/components/union.md @@ -0,0 +1,13 @@ +# Union + +Union module combines given tables into one while keeping unique entry +ids. Union is a local module. This module can be run on the +side of Host or Guest, and running this module does not require any +interaction with outside parties. + +## Use + +Union currently supports concatenation along axis 0. + +For tables to be concatenated, their header, including sample id, match id, and label column (if label exists), +should match. Example of such a union task may be found [here](../../../examples/pipeline/union/test_union.py). diff --git a/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb b/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb new file mode 100644 index 0000000000..3f2476ca06 --- /dev/null +++ b/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb @@ -0,0 +1,1114 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hetero-SecureBoost Tutorial\n", + "\n", + "In a vertically partitioned data setting, multiple parties have different feature sets for the same common user samples. Federated learning enables these parties to collaboratively train a model without sharing their actual data. The model is trained locally at each party, and only model updates are shared, not the actual data. \n", + "SecureBoost is a specialized tree-boosting framework designed for vertical federated learning. It performs entity alignment under a privacy-preserving protocol and constructs boosting trees across multiple parties using an encryption strategy. It allows for high-quality, lossless model training without needing a trusted third party.\n", + "\n", + "In this tutorial, we will show you how to run a Hetero-SecureBoost task under FATE-2.0 without using a Pipeline. You can refer to this example for local model experimentation, algorithm modification, and testing, although we do not recommend using it directly in a production environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Hetero-Secureboost Step by Step\n", + "\n", + "To run a Hetero-Secureboost task, several steps are needed:\n", + "1. Import required libraries and create fate context\n", + "2. Prepare tabular data and transform them into fate dataframe\n", + "3. guest&host run the python script, fit the Hetero-Secureboost model\n", + "\n", + "### Import Libs and Create Context\n", + "We import these lib from later use." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from fate.arch.dataframe import PandasReader\n", + "import sys\n", + "from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest\n", + "from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use 'create_ctx' to create fate context. When creating fate context, please make sure that every party's context is initialized with the same unique context_name. Here, we provide 'get_current_datetime_str' in order to get a unique context name\n", + "according to the current date and time. 
We are running a two party vertical federation task, so guest party id and host party id are:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "guest = (\"guest\", \"10000\")\n", + "host = (\"host\", \"9999\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def get_current_datetime_str():\n", + " return datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + "\n", + "\n", + "def create_ctx(local, context_name):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.standalone import CSession\n", + " from fate.arch.federation.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " # prepare log\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + " logger.addHandler(console_handler)\n", + " # init fate context\n", + " computing = CSession()\n", + " return Context(\n", + " computing=computing, federation=StandaloneFederation(computing, context_name, local, [guest, host])\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare Data\n", + "\n", + "We can read a csv file and transform it into a Fate-DataFrame:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# guest create context\n", + "guest_ctx = create_ctx(guest, get_current_datetime_str())\n", + "# read csv\n", + "df = pd.read_csv('./../../../../examples/data/breast_hetero_guest.csv')\n", + "# add sample_id column, sample id & match id are needed in FATE dataframe\n", + "df[\"sample_id\"] = [i for i in range(len(df))]\n", + "reader = PandasReader(sample_id_name=\"sample_id\", match_id_name=\"id\", label_name=\"y\", dtype=\"float32\") \n", + "data_guest = reader.to_frame(guest_ctx, df)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_guest" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sample_ididx0x1x2x3x4x5x6x7...x10x11x12x13x14x15x16x17x18x19
00133.00.449512-1.2472260.4131780.303781-0.123848-0.184227-0.2190760.268537...-0.337360-0.728193-0.442587-0.272757-0.608018-0.577235-0.5011260.143371-0.466431-0.554102
15274.01.0800231.2078300.9568880.978402-0.555822-0.645696-0.399365-0.038153...0.0578480.392164-0.0500270.120414-0.532348-0.770613-0.519694-0.531097-0.769127-0.394858
2776.0-0.169639-1.943019-0.167192-0.2721502.3299370.006804-0.2514670.429234...0.017786-0.368046-0.105966-0.1691292.1197600.162743-0.672216-0.5770020.6269080.896114
39399.0-0.660984-0.472313-0.688248-0.634204-0.390718-0.796360-0.756680-0.839314...-0.221505-0.139439-0.317344-0.336122-0.526014-0.326291-0.368166-1.037840-0.698901-0.273818
411246.0-0.263364-0.432753-0.322891-0.322206-1.722935-1.120051-0.570489-0.976796...-0.8740500.696974-0.986625-0.589142-0.260004-0.547055-0.036596-1.040273-0.111671-0.584362
..................................................................
564499515.0-0.791630-0.158159-0.791224-0.7499590.607733-0.366730-0.574758-0.592724...-0.585313-0.375303-0.680696-0.4872740.512028-0.506814-0.361535-0.1177860.459820-0.975096
565504357.0-0.073075-0.716655-0.142066-0.174028-0.635527-0.936601-0.926297-0.723241...-0.5445290.265160-0.558919-0.431169-0.467679-0.980254-0.883294-0.933377-0.617779-0.646017
566509377.0-0.1895202.075826-0.250397-0.263902-1.508016-1.081769-0.955299-0.973701...-0.852755-0.121295-0.725744-0.559439-0.699688-0.751610-0.808558-1.073364-0.741279-0.798452
567510234.0-1.295188-0.786467-1.308161-1.067361-0.834079-1.202869-0.907465-0.831834...-0.685649-0.701704-0.817325-0.6093831.533069-0.842710-0.664258-0.3525040.398070-0.096418
568511341.0-1.284111-0.570050-1.249259-1.064801-0.821981-0.228573-0.057493-0.670622...-0.796813-0.497046-0.711388-0.621924-0.3623410.5159640.609634-0.533043-0.3683570.089304
\n", + "

569 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " sample_id id x0 x1 x2 x3 x4 \\\n", + "0 0 133.0 0.449512 -1.247226 0.413178 0.303781 -0.123848 \n", + "1 5 274.0 1.080023 1.207830 0.956888 0.978402 -0.555822 \n", + "2 7 76.0 -0.169639 -1.943019 -0.167192 -0.272150 2.329937 \n", + "3 9 399.0 -0.660984 -0.472313 -0.688248 -0.634204 -0.390718 \n", + "4 11 246.0 -0.263364 -0.432753 -0.322891 -0.322206 -1.722935 \n", + ".. ... ... ... ... ... ... ... \n", + "564 499 515.0 -0.791630 -0.158159 -0.791224 -0.749959 0.607733 \n", + "565 504 357.0 -0.073075 -0.716655 -0.142066 -0.174028 -0.635527 \n", + "566 509 377.0 -0.189520 2.075826 -0.250397 -0.263902 -1.508016 \n", + "567 510 234.0 -1.295188 -0.786467 -1.308161 -1.067361 -0.834079 \n", + "568 511 341.0 -1.284111 -0.570050 -1.249259 -1.064801 -0.821981 \n", + "\n", + " x5 x6 x7 ... x10 x11 x12 \\\n", + "0 -0.184227 -0.219076 0.268537 ... -0.337360 -0.728193 -0.442587 \n", + "1 -0.645696 -0.399365 -0.038153 ... 0.057848 0.392164 -0.050027 \n", + "2 0.006804 -0.251467 0.429234 ... 0.017786 -0.368046 -0.105966 \n", + "3 -0.796360 -0.756680 -0.839314 ... -0.221505 -0.139439 -0.317344 \n", + "4 -1.120051 -0.570489 -0.976796 ... -0.874050 0.696974 -0.986625 \n", + ".. ... ... ... ... ... ... ... \n", + "564 -0.366730 -0.574758 -0.592724 ... -0.585313 -0.375303 -0.680696 \n", + "565 -0.936601 -0.926297 -0.723241 ... -0.544529 0.265160 -0.558919 \n", + "566 -1.081769 -0.955299 -0.973701 ... -0.852755 -0.121295 -0.725744 \n", + "567 -1.202869 -0.907465 -0.831834 ... -0.685649 -0.701704 -0.817325 \n", + "568 -0.228573 -0.057493 -0.670622 ... -0.796813 -0.497046 -0.711388 \n", + "\n", + " x13 x14 x15 x16 x17 x18 x19 \n", + "0 -0.272757 -0.608018 -0.577235 -0.501126 0.143371 -0.466431 -0.554102 \n", + "1 0.120414 -0.532348 -0.770613 -0.519694 -0.531097 -0.769127 -0.394858 \n", + "2 -0.169129 2.119760 0.162743 -0.672216 -0.577002 0.626908 0.896114 \n", + "3 -0.336122 -0.526014 -0.326291 -0.368166 -1.037840 -0.698901 -0.273818 \n", + "4 -0.589142 -0.260004 -0.547055 -0.036596 -1.040273 -0.111671 -0.584362 \n", + ".. ... ... ... ... ... ... ... \n", + "564 -0.487274 0.512028 -0.506814 -0.361535 -0.117786 0.459820 -0.975096 \n", + "565 -0.431169 -0.467679 -0.980254 -0.883294 -0.933377 -0.617779 -0.646017 \n", + "566 -0.559439 -0.699688 -0.751610 -0.808558 -1.073364 -0.741279 -0.798452 \n", + "567 -0.609383 1.533069 -0.842710 -0.664258 -0.352504 0.398070 -0.096418 \n", + "568 -0.621924 -0.362341 0.515964 0.609634 -0.533043 -0.368357 0.089304 \n", + "\n", + "[569 rows x 22 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_guest.as_pd_df()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For host side, creating the dataframe is the same as the guest side, except that the label_name is not needed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# host create context\n", + "host_ctx = create_ctx(host, get_current_datetime_str())\n", + "# read csv\n", + "df = pd.read_csv('./../../../../examples/data/breast_hetero_host.csv')\n", + "# add sample_id column, sample id & match id are needed in FATE dataframe\n", + "df[\"sample_id\"] = [i for i in range(len(df))]\n", + "reader = PandasReader(sample_id_name=\"sample_id\", match_id_name=\"id\", dtype=\"float32\") \n", + "data_host = reader.to_frame(host_ctx, df)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sample_ididx0x1x2x3x4x5x6x7...x10x11x12x13x14x15x16x17x18x19
00133.00.449512-1.2472260.4131780.303781-0.123848-0.184227-0.2190760.268537...-0.337360-0.728193-0.442587-0.272757-0.608018-0.577235-0.5011260.143371-0.466431-0.554102
15274.01.0800231.2078300.9568880.978402-0.555822-0.645696-0.399365-0.038153...0.0578480.392164-0.0500270.120414-0.532348-0.770613-0.519694-0.531097-0.769127-0.394858
2776.0-0.169639-1.943019-0.167192-0.2721502.3299370.006804-0.2514670.429234...0.017786-0.368046-0.105966-0.1691292.1197600.162743-0.672216-0.5770020.6269080.896114
39399.0-0.660984-0.472313-0.688248-0.634204-0.390718-0.796360-0.756680-0.839314...-0.221505-0.139439-0.317344-0.336122-0.526014-0.326291-0.368166-1.037840-0.698901-0.273818
411246.0-0.263364-0.432753-0.322891-0.322206-1.722935-1.120051-0.570489-0.976796...-0.8740500.696974-0.986625-0.589142-0.260004-0.547055-0.036596-1.040273-0.111671-0.584362
..................................................................
564499515.0-0.791630-0.158159-0.791224-0.7499590.607733-0.366730-0.574758-0.592724...-0.585313-0.375303-0.680696-0.4872740.512028-0.506814-0.361535-0.1177860.459820-0.975096
565504357.0-0.073075-0.716655-0.142066-0.174028-0.635527-0.936601-0.926297-0.723241...-0.5445290.265160-0.558919-0.431169-0.467679-0.980254-0.883294-0.933377-0.617779-0.646017
566509377.0-0.1895202.075826-0.250397-0.263902-1.508016-1.081769-0.955299-0.973701...-0.852755-0.121295-0.725744-0.559439-0.699688-0.751610-0.808558-1.073364-0.741279-0.798452
567510234.0-1.295188-0.786467-1.308161-1.067361-0.834079-1.202869-0.907465-0.831834...-0.685649-0.701704-0.817325-0.6093831.533069-0.842710-0.664258-0.3525040.398070-0.096418
568511341.0-1.284111-0.570050-1.249259-1.064801-0.821981-0.228573-0.057493-0.670622...-0.796813-0.497046-0.711388-0.621924-0.3623410.5159640.609634-0.533043-0.3683570.089304
\n", + "

569 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " sample_id id x0 x1 x2 x3 x4 \\\n", + "0 0 133.0 0.449512 -1.247226 0.413178 0.303781 -0.123848 \n", + "1 5 274.0 1.080023 1.207830 0.956888 0.978402 -0.555822 \n", + "2 7 76.0 -0.169639 -1.943019 -0.167192 -0.272150 2.329937 \n", + "3 9 399.0 -0.660984 -0.472313 -0.688248 -0.634204 -0.390718 \n", + "4 11 246.0 -0.263364 -0.432753 -0.322891 -0.322206 -1.722935 \n", + ".. ... ... ... ... ... ... ... \n", + "564 499 515.0 -0.791630 -0.158159 -0.791224 -0.749959 0.607733 \n", + "565 504 357.0 -0.073075 -0.716655 -0.142066 -0.174028 -0.635527 \n", + "566 509 377.0 -0.189520 2.075826 -0.250397 -0.263902 -1.508016 \n", + "567 510 234.0 -1.295188 -0.786467 -1.308161 -1.067361 -0.834079 \n", + "568 511 341.0 -1.284111 -0.570050 -1.249259 -1.064801 -0.821981 \n", + "\n", + " x5 x6 x7 ... x10 x11 x12 \\\n", + "0 -0.184227 -0.219076 0.268537 ... -0.337360 -0.728193 -0.442587 \n", + "1 -0.645696 -0.399365 -0.038153 ... 0.057848 0.392164 -0.050027 \n", + "2 0.006804 -0.251467 0.429234 ... 0.017786 -0.368046 -0.105966 \n", + "3 -0.796360 -0.756680 -0.839314 ... -0.221505 -0.139439 -0.317344 \n", + "4 -1.120051 -0.570489 -0.976796 ... -0.874050 0.696974 -0.986625 \n", + ".. ... ... ... ... ... ... ... \n", + "564 -0.366730 -0.574758 -0.592724 ... -0.585313 -0.375303 -0.680696 \n", + "565 -0.936601 -0.926297 -0.723241 ... -0.544529 0.265160 -0.558919 \n", + "566 -1.081769 -0.955299 -0.973701 ... -0.852755 -0.121295 -0.725744 \n", + "567 -1.202869 -0.907465 -0.831834 ... -0.685649 -0.701704 -0.817325 \n", + "568 -0.228573 -0.057493 -0.670622 ... -0.796813 -0.497046 -0.711388 \n", + "\n", + " x13 x14 x15 x16 x17 x18 x19 \n", + "0 -0.272757 -0.608018 -0.577235 -0.501126 0.143371 -0.466431 -0.554102 \n", + "1 0.120414 -0.532348 -0.770613 -0.519694 -0.531097 -0.769127 -0.394858 \n", + "2 -0.169129 2.119760 0.162743 -0.672216 -0.577002 0.626908 0.896114 \n", + "3 -0.336122 -0.526014 -0.326291 -0.368166 -1.037840 -0.698901 -0.273818 \n", + "4 -0.589142 -0.260004 -0.547055 -0.036596 -1.040273 -0.111671 -0.584362 \n", + ".. ... ... ... ... ... ... ... \n", + "564 -0.487274 0.512028 -0.506814 -0.361535 -0.117786 0.459820 -0.975096 \n", + "565 -0.431169 -0.467679 -0.980254 -0.883294 -0.933377 -0.617779 -0.646017 \n", + "566 -0.559439 -0.699688 -0.751610 -0.808558 -1.073364 -0.741279 -0.798452 \n", + "567 -0.609383 1.533069 -0.842710 -0.664258 -0.352504 0.398070 -0.096418 \n", + "568 -0.621924 -0.362341 0.515964 0.609634 -0.533043 -0.368357 0.089304 \n", + "\n", + "[569 rows x 22 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_host.as_pd_df()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the Hetero-Secureboost Script\n", + "\n", + "Once contexts are prepared and data are loaded, we can initialize Secureboost instances and fit models.Here we show you the complete python script for running a Hetero-Secureboost task. In this example, we will not use PSI (Private Set Intersection) for data intersection; instead, we will train the tree model directly with aligned data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from fate.arch.dataframe import PandasReader\n", + "import sys\n", + "from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest\n", + "from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost\n", + "from datetime import datetime\n", + "\n", + "guest = (\"guest\", \"10000\")\n", + "host = (\"host\", \"9999\")\n", + "\n", + "def get_current_datetime_str():\n", + " return datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + "\n", + "\n", + "def create_ctx(local, context_name):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.standalone import CSession\n", + " from fate.arch.federation.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " # prepare log\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.INFO)\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.INFO)\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + " logger.addHandler(console_handler)\n", + " # init fate context\n", + " computing = CSession()\n", + " return Context(\n", + " computing=computing, federation=StandaloneFederation(computing, context_name, local, [guest, host])\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + " party = sys.argv[1]\n", + " max_depth = 2\n", + " num_tree = 2\n", + "\n", + " if party == \"guest\":\n", + " ctx = create_ctx(guest, get_current_datetime_str())\n", + " df = pd.read_csv(\"./../../../../examples/data/breast_hetero_guest.csv\")\n", + " df[\"sample_id\"] = [i for i in range(len(df))]\n", + "\n", + " reader = PandasReader(sample_id_name=\"sample_id\", match_id_name=\"id\", label_name=\"y\", dtype=\"float32\")\n", + " data_guest = reader.to_frame(ctx, df)\n", + "\n", + " trees = HeteroSecureBoostGuest(num_tree, max_depth=max_depth)\n", + " trees.fit(ctx, data_guest)\n", + " pred = trees.get_train_predict().as_pd_df()\n", + "\n", + " print(pred)\n", + " from sklearn.metrics import roc_auc_score\n", + " print('auc is {}'.format(roc_auc_score(pred['label'], pred['predict_score'])))\n", + "\n", + " # save model\n", + " import json\n", + " with open('./guest_tree.json', 'w') as f:\n", + " f.write(json.dumps(trees.get_model(), indent=4))\n", + "\n", + " # load model\n", + " model_dict = json.load(open('./guest_tree.json'))\n", + "\n", + " elif party == \"host\":\n", + " ctx = create_ctx(host, get_current_datetime_str())\n", + " df_host = pd.read_csv(\"./../../../../examples/data/breast_hetero_host.csv\")\n", + " df_host[\"sample_id\"] = [i for i in range(len(df_host))]\n", + "\n", + " reader_host = PandasReader(sample_id_name=\"sample_id\", match_id_name=\"id\", dtype=\"float32\")\n", + " data_host = reader_host.to_frame(ctx, df_host)\n", + "\n", + " trees = HeteroSecureBoostHost(num_tree, max_depth=max_depth)\n", + " trees.fit(ctx, data_host)\n", + "\n", + " # save model\n", + " import json\n", + " with open('./host_tree.json', 'w') as f:\n", + " f.write(json.dumps(trees.get_model(), indent=4))\n", + "\n", + " # load model\n", + " model_dict = json.load(open('./host_tree.json'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We save this script to a file named 'run_hetero_sbt.py' and execute it simultaneously in two terminals. 
The guest party terminal command is:\n", + "\n", + "```\n", + "python -i ./run_hetero_sbt.py guest\n", + "```\n", + "\n", + "The host party terminal command is:\n", + "\n", + "```\n", + "python -i ./run_hetero_sbt.py host\n", + "```\n", + "\n", + "We add -i option so that you can check the result of the script in the terminal." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Geust Terminal Outputs:\n", + "```\n", + "2023-09-13 16:42:47,053 - fate.ml.ensemble.algo.secureboost.hetero.guest - INFO - start to fit a guest tree\n", + "2023-09-13 16:42:47,583 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - encrypt kit setup through setter\n", + "2023-09-13 16:42:48,214 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - gh are packed\n", + "2023-09-13 16:42:50,978 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:51,067 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - layer 0 done: next layer will split 2 nodes, active samples num 569\n", + "2023-09-13 16:42:53,802 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:53,979 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - layer 1 done: next layer will split 4 nodes, active samples num 569\n", + "2023-09-13 16:42:54,769 - fate.ml.ensemble.algo.secureboost.hetero.guest - INFO - fitting guest decision tree 0 done\n", + "2023-09-13 16:42:54,769 - fate.ml.ensemble.algo.secureboost.hetero.guest - INFO - start to fit a guest tree\n", + "2023-09-13 16:42:55,419 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - encrypt kit setup through setter\n", + "2023-09-13 16:42:56,075 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - gh are packed\n", + "2023-09-13 16:42:58,780 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:58,875 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - layer 0 done: next layer will split 2 nodes, active samples num 569\n", + "2023-09-13 16:43:01,620 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:43:01,779 - fate.ml.ensemble.learner.decision_tree.hetero.guest - INFO - layer 1 done: next layer will split 4 nodes, active samples num 569\n", + "2023-09-13 16:43:02,564 - fate.ml.ensemble.algo.secureboost.hetero.guest - INFO - fitting guest decision tree 1 done\n", + " sample_id id label predict_score predict_result predict_detail\n", + "0 0 133.0 1 0.620374 1 \"{'0': 0.3796257972717285, '1': 0.620374202728...\n", + "1 5 274.0 0 0.288331 0 \"{'0': 0.7116693258285522, '1': 0.288330674171...\n", + "2 7 76.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "3 9 399.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "4 11 246.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + ".. ... ... ... ... ... 
...\n", + "564 499 515.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "565 504 357.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "566 509 377.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "567 510 234.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "568 511 341.0 1 0.730982 1 \"{'0': 0.26901811361312866, '1': 0.73098188638...\n", + "\n", + "[569 rows x 6 columns]\n", + "auc is 0.9778024417314095\n", + "```\n", + "\n", + "Host Terminal Outputs:\n", + "\n", + "```\n", + "2023-09-13 16:42:45,399 - fate.ml.ensemble.algo.secureboost.hetero.host - INFO - data binning done\n", + "2023-09-13 16:42:45,399 - fate.ml.ensemble.algo.secureboost.hetero.host - INFO - start to fit a host tree\n", + "2023-09-13 16:42:49,417 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - cur layer node num: 1, next layer node num: 2\n", + "2023-09-13 16:42:50,792 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:50,933 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - layer 0 done: next layer will split 2 nodes, active samples num 569\n", + "2023-09-13 16:42:52,058 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - cur layer node num: 2, next layer node num: 4\n", + "2023-09-13 16:42:53,503 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:53,728 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - layer 1 done: next layer will split 4 nodes, active samples num 569\n", + "2023-09-13 16:42:53,967 - fate.ml.ensemble.algo.secureboost.hetero.host - INFO - fitting host decision tree 0 done\n", + "2023-09-13 16:42:53,967 - fate.ml.ensemble.algo.secureboost.hetero.host - INFO - start to fit a host tree\n", + "2023-09-13 16:42:57,265 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - cur layer node num: 1, next layer node num: 2\n", + "2023-09-13 16:42:58,649 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:42:58,787 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - layer 0 done: next layer will split 2 nodes, active samples num 569\n", + "2023-09-13 16:43:00,027 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - cur layer node num: 2, next layer node num: 4\n", + "2023-09-13 16:43:01,377 - fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree - INFO - drop leaf samples, new sample count is 569, 0 samples dropped\n", + "2023-09-13 16:43:01,598 - fate.ml.ensemble.learner.decision_tree.hetero.host - INFO - layer 1 done: next layer will split 4 nodes, active samples num 569\n", + "2023-09-13 16:43:01,839 - fate.ml.ensemble.algo.secureboost.hetero.host - INFO - fitting host decision tree 1 done\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fate-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/2.0/fate/ml/homo_nn_tutorial.ipynb 
b/doc/2.0/fate/ml/homo_nn_tutorial.ipynb new file mode 100644 index 0000000000..946ac272ff --- /dev/null +++ b/doc/2.0/fate/ml/homo_nn_tutorial.ipynb @@ -0,0 +1,725 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Homo-NN Tutorial\n", + "\n", + "Homo (horizontal) federated learning allows parties to collaboratively train a neural network model without sharing their actual data. In a horizontally partitioned data setting, multiple parties have the same feature set but different user samples. In this scenario, each party trains the model locally on its own subset of data and only shares the model updates.\n", + "\n", + "The Homo-NN algorithm is designed for horizontal federated neural network training. In this tutorial, we demonstrate how to run a Homo-NN task under FATE-2.0 without using a Pipeline. This is particularly useful for local experimentation, model/training setting modifications and testing. However, it is not recommended for direct usage in a production environment. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Homo-NN Step by Step\n", + "\n", + "To run a Homo-NN task, several steps are needed:\n", + "1. Import required libraries and create fate context\n", + "2. Prepare datasets, models, loss and optimizers\n", + "3. Configure client (guest & hosts) parameters and trainers; configure arbiter parameters and trainer\n", + "4. Run the training script\n", + "\n", + "### Import Libs and Create Context\n", + "We import these libs for later use. The FedAVGCLient is the trainer on the client side and the FedAVGServer is the trainer on the server side. In TrainingArguments we can set training parameters (the same as the transformers Trainer), while in FedAVGArguments we can adjust options for federation.\n", + "\n", + "If fate is not installed, run:\n", + "```\n", + "pip install fate==2.0.0b0\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedAVGServer, TrainingArguments, FedAVGArguments\n", + "import torch as t\n", + "import pandas as pd\n", + "from fate.ml.nn.dataset.table import TableDataset\n", + "import sys\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use 'create_ctx' to create fate context. When creating fate context, please make sure that every party's context is initialized with the same unique context_name. Here, we provide 'get_current_datetime_str' in order to get a unique context name\n", + "according to the current date and time.
We are running a three party homo-federation task, so parties are:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_current_datetime_str():\n", + " return datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + "\n", + "\n", + "def create_ctx(local, context_name):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.standalone import CSession\n", + " from fate.arch.federation.standalone import StandaloneFederation\n", + "\n", + " # init fate context\n", + " computing = CSession()\n", + " return Context(\n", + " computing=computing, federation=StandaloneFederation(computing, context_name, local, [guest, host])\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare\n", + "\n", + "Using FATE's Homo-NN framework is very similar to directly using PyTorch's nn.Module, Dataset, and Optimizer, as well as Transformers. Developed based on HuggingFace's Transformers Trainer, FATE-2.0 NN Trainer enables seamless integration for a variety of neural network tasks. Whether you're dealing with tabular data, natural language processing (NLP), or computer vision (CV), our Homo-NN framework easily adapts to PyTorch's components. Therefore, you can prepare your project just as you would when using PyTorch's Module, Dataset, and Optimizer. Here, we adopt a binary classification task as the example.\n", + "\n", + "### Read Csv and Prepare Dataset\n", + "\n", + "We read the csv data of guest side, and put it into the TensorDataset." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "guest_data_path = './../FATE/examples/data/breast_homo_guest.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import TensorDataset\n", + "\n", + "idx = 'id'\n", + "label_name = 'y'\n", + "guest_data = pd.read_csv(guest_data_path, index_col=idx)\n", + "\n", + "X = guest_data.drop(columns=[label_name])\n", + "y = guest_data.pop(label_name).values\n", + "\n", + "X = torch.tensor(X.values, dtype=torch.float32)\n", + "y = torch.tensor(y, dtype=torch.float32).view(-1, 1)\n", + "dataset = TensorDataset(X, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or you can use our built-in TableDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from fate.ml.nn.dataset.table import TableDataset\n", + "\n", + "dataset_ = TableDataset(to_tensor=True)\n", + "dataset_.load(guest_data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([ 0.2549, -1.0466, 0.2097, 0.0742, -0.4414, -0.3776, -0.4859, 0.3471,\n", + " -0.2876, -0.7335, 0.4495, -1.2472, 0.4132, 0.3038, -0.1238, -0.1842,\n", + " -0.2191, 0.2685, 0.0160, -0.7893, -0.3374, -0.7282, -0.4426, -0.2728,\n", + " -0.6080, -0.5772, -0.5011, 0.1434, -0.4664, -0.5541]),\n", + " tensor([1.]))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { 
+ "text/plain": [ + "torch.Size([30])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0][0].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare Model, Loss and Optimizer\n", + "\n", + "We initialize the pytorch model, loss and optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "model = torch.nn.Sequential(\n", + " torch.nn.Linear(30, 16),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(16, 1),\n", + " torch.nn.Sigmoid()\n", + ")\n", + "\n", + "loss_fn = torch.nn.BCELoss()\n", + "opt = torch.optim.Adam(model.parameters(), lr=0.01)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Trainer without Federation\n", + "\n", + "Here, we first perform a non-federated local training using FedAVGClient to verify whether our model, dataset, optimizer, and loss function can work properly. We can enable FedAVG to operate in local mode by calling 'set_local_mode()'.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "ctx = create_ctx(guest, get_current_datetime_str()) # create a guest context for local testing\n", + "train_arg = TrainingArguments(num_train_epochs=4, per_device_train_batch_size=128, disable_tqdm=False)\n", + "trainer = FedAVGCLient(ctx, model, train_arg, FedAVGArguments(), train_set=dataset, loss_fn=loss_fn, optimizer=opt)\n", + "trainer.set_local_mode()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [8/8 00:00, Epoch 4/4]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
20.067400
40.060600
60.053600
80.049600

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=8, training_loss=0.057804910466074944, metrics={'train_runtime': 0.0137, 'train_samples_per_second': 66308.488, 'train_steps_per_second': 584.216, 'train_loss': 0.057804910466074944, 'epoch': 4.0})" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Excellent! Everything is functioning as expected. Now, we can proceed to develop a script that contains the code for both the clients (guest & host) and the server (arbiter).\n", + "\n", + " Once the script is ready, we can execute our task using three separate terminals." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the Homo-NN Script\n", + "\n", + "Once model, dataset, optimizer and loss are prepared, we can initialize FedAVGClient instances and fit models. The server side is responsible for model aggreation, so we only need to intialize a FedAVGServer on arbiter side, and call its train function.\n", + "\n", + "**Please note** that when configuring `FedAVGArguments`, **make sure that the number of aggregation rounds is consistent across all parties**. We will **check the aggregation round count**, and **if there is any inconsistency, the task will fail**.\n", + "\n", + "Here is the complete python script:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedArguments, TrainingArguments, FedAVGServer\n", + "import torch as t\n", + "import pandas as pd\n", + "from fate.ml.nn.dataset.table import TableDataset\n", + "import sys\n", + "from datetime import datetime\n", + "\n", + "\n", + "arbiter = (\"arbiter\", 10000)\n", + "guest = (\"guest\", 10000)\n", + "host = (\"host\", 9999)\n", + "\n", + "def get_current_datetime_str():\n", + " return datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + "\n", + "\n", + "def create_ctx(local):\n", + " from fate.arch import Context\n", + " from fate.arch.computing.standalone import CSession\n", + " from fate.arch.federation.standalone import StandaloneFederation\n", + " import logging\n", + "\n", + " logger = logging.getLogger()\n", + " logger.setLevel(logging.DEBUG)\n", + "\n", + " console_handler = logging.StreamHandler()\n", + " console_handler.setLevel(logging.DEBUG)\n", + "\n", + " formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n", + " console_handler.setFormatter(formatter)\n", + "\n", + " logger.addHandler(console_handler)\n", + " computing = CSession()\n", + " return Context(\n", + " computing=computing, federation=StandaloneFederation(computing, get_current_datetime_str(), local, [guest, host, arbiter])\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + " model = t.nn.Sequential(t.nn.Linear(30, 1), t.nn.Sigmoid())\n", + "\n", + " if sys.argv[1] == \"guest\":\n", + "\n", + " ds = TableDataset(return_dict=False, to_tensor=True)\n", + " ds.load(\"../../../../../../examples/data/breast_homo_guest.csv\")\n", + " ctx = create_ctx(guest)\n", + " fed_args = FedArguments(aggregate_strategy=\"epochs\", aggregate_freq=1, aggregator=\"secure_aggregate\")\n", + " args = TrainingArguments(\n", + " num_train_epochs=5, per_device_train_batch_size=16, logging_strategy=\"steps\", 
logging_steps=5\n",
+ "        )\n",
+ "        trainer = FedAVGCLient(\n",
+ "            ctx=ctx,\n",
+ "            model=model,\n",
+ "            fed_args=fed_args,\n",
+ "            training_args=args,\n",
+ "            loss_fn=t.nn.BCELoss(),\n",
+ "            optimizer=t.optim.SGD(model.parameters(), lr=0.01),\n",
+ "            train_set=ds,\n",
+ "        )\n",
+ "        trainer.train()\n",
+ "\n",
+ "    elif sys.argv[1] == \"host\":\n",
+ "\n",
+ "        ds = TableDataset(return_dict=False, to_tensor=True)\n",
+ "        ds.load(\"../../../../../../examples/data/breast_homo_host.csv\")\n",
+ "        ctx = create_ctx(host)\n",
+ "        fed_args = FedArguments(aggregate_strategy=\"epochs\", aggregate_freq=1, aggregator=\"secure_aggregate\")\n",
+ "        args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16)\n",
+ "        trainer = FedAVGCLient(\n",
+ "            ctx=ctx,\n",
+ "            model=model,\n",
+ "            fed_args=fed_args,\n",
+ "            training_args=args,\n",
+ "            loss_fn=t.nn.BCELoss(),\n",
+ "            optimizer=t.optim.SGD(model.parameters(), lr=0.01),\n",
+ "            train_set=ds,\n",
+ "        )\n",
+ "        trainer.train()\n",
+ "\n",
+ "    else:\n",
+ "        ctx = create_ctx(arbiter)\n",
+ "        trainer = FedAVGServer(ctx)\n",
+ "        trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We save this script to a file named 'run_homo_nn.py' and execute it simultaneously in three terminals. The guest and host party terminals' commands are:\n",
+ "\n",
+ "```\n",
+ "python -i ./run_homo_nn.py guest\n",
+ "```\n",
+ "\n",
+ "```\n",
+ "python -i ./run_homo_nn.py host\n",
+ "```\n",
+ "\n",
+ "\n",
+ "The arbiter party terminal command is:\n",
+ "\n",
+ "```\n",
+ "python -i ./run_homo_nn.py arbiter\n",
+ "```\n",
+ "\n",
+ "We add the -i option so that you can check the result of the script in the terminal."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Guest Terminal Outputs:\n",
+ "\n",
+ "```\n",
+ "2023-09-13 20:36:13,319 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('guest', 10000)\n",
+ "2023-09-13 20:36:13,319 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n",
+ "2023-09-13 20:36:13,400 - fate.ml.nn.algo.homo.fedavg - INFO - Using secure_aggregate aggregator\n",
+ "2023-09-13 20:36:13,400 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote data, type=\n",
+ "2023-09-13 20:36:13,400 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote object with type: \n",
+ "2023-09-13 20:36:13,413 - fate.ml.aggregator.base - INFO - computing weights\n",
+ "2023-09-13 20:36:13,413 - fate.arch._standalone - DEBUG - [federation.standalone.remote.local_weight_fedavg_0.default]remote data, type=\n",
+ "2023-09-13 20:36:13,413 - fate.arch._standalone - DEBUG - [federation.standalone.remote.local_weight_fedavg_0.default]remote object with type: \n",
+ "2023-09-13 20:36:13,428 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_weight_fedavg_0.default]\n",
+ "2023-09-13 20:36:13,529 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-agg_weight_fedavg_0-default-arbiter-10000-guest-10000 type Object\n",
+ "2023-09-13 20:36:13,536 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_weight_fedavg_0.default] got object with type: \n",
+ "2023-09-13 20:36:13,545 - fate.ml.aggregator.base - INFO - aggregate weight is 0.4989010989010989\n",
+ "2023-09-13 20:36:13,545 - fate.arch._standalone - DEBUG - 
[federation.standalone.remote.fedavg_model_dh_pubkey.default]remote data, type=\n", + "2023-09-13 20:36:13,545 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_dh_pubkey.default]remote object with type: \n", + "2023-09-13 20:36:13,556 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_dh_pubkey.default]\n", + "2023-09-13 20:36:13,556 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_dh_pubkey-default-host-9999-guest-10000 type Object\n", + "2023-09-13 20:36:13,561 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_dh_pubkey.default] got object with type: \n", + "2023-09-13 20:36:13,566 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_loss_dh_pubkey.default]remote data, type=\n", + "2023-09-13 20:36:13,566 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_loss_dh_pubkey.default]remote object with type: \n", + "2023-09-13 20:36:13,576 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_loss_dh_pubkey.default]\n", + "2023-09-13 20:36:13,677 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_loss_dh_pubkey-default-host-9999-guest-10000 type Object\n", + "2023-09-13 20:36:13,683 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_loss_dh_pubkey.default] got object with type: \n", + "2023-09-13 20:36:13,690 - fate.ml.nn.trainer.trainer_base - INFO - computed max_aggregation is 5\n", + "2023-09-13 20:36:13,690 - fate.ml.nn.trainer.trainer_base - INFO - parameters is {'num_train_epochs': 5, 'max_steps': 75, 'num_update_steps_per_epoch': 15, 'epochs_trained': 0, 'steps_trained_in_current_epoch': 0, 'max_aggregation': 5, 'aggregate_freq': 1, 'aggregation_strategy': 'epochs'}\n", + "2023-09-13 20:36:13,690 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fed_para_0.default]remote data, type=\n", + "2023-09-13 20:36:13,690 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fed_para_0.default]remote object with type: \n", + "2023-09-13 20:36:13,722 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.7378, 'learning_rate': 0.01, 'epoch': 0.33}\n", + "2023-09-13 20:36:13,729 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.741, 'learning_rate': 0.01, 'epoch': 0.67}\n", + "2023-09-13 20:36:13,736 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.6012, 'learning_rate': 0.01, 'epoch': 1.0}\n", + "2023-09-13 20:36:13,736 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:13,737 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-0]remote data, type=\n", + "2023-09-13 20:36:13,737 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-0]remote object with type: \n", + "2023-09-13 20:36:13,746 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-0]\n", + "2023-09-13 20:36:13,846 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-0-arbiter-10000-guest-10000 type Object\n", + "2023-09-13 20:36:13,851 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-0] got object with type: \n", + "2023-09-13 20:36:13,856 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 1 / 5\n", + "2023-09-13 20:36:13,863 - fate.ml.nn.trainer.trainer_base - INFO - 
{'loss': 0.5073, 'learning_rate': 0.01, 'epoch': 1.33}\n", + "2023-09-13 20:36:13,871 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4734, 'learning_rate': 0.01, 'epoch': 1.67}\n", + "2023-09-13 20:36:13,877 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4241, 'learning_rate': 0.01, 'epoch': 2.0}\n", + "2023-09-13 20:36:13,878 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:13,878 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-1]remote data, type=\n", + "2023-09-13 20:36:13,878 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-1]remote object with type: \n", + "2023-09-13 20:36:13,888 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-1]\n", + "2023-09-13 20:36:13,989 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-1-arbiter-10000-guest-10000 type Object\n", + "2023-09-13 20:36:13,993 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-1] got object with type: \n", + "2023-09-13 20:36:13,998 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 2 / 5\n", + "2023-09-13 20:36:14,006 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3753, 'learning_rate': 0.01, 'epoch': 2.33}\n", + "2023-09-13 20:36:14,014 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4512, 'learning_rate': 0.01, 'epoch': 2.67}\n", + "2023-09-13 20:36:14,021 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4011, 'learning_rate': 0.01, 'epoch': 3.0}\n", + "2023-09-13 20:36:14,021 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,022 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-2]remote data, type=\n", + "2023-09-13 20:36:14,022 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-2]remote object with type: \n", + "2023-09-13 20:36:14,031 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-2]\n", + "2023-09-13 20:36:14,131 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-2-arbiter-10000-guest-10000 type Object\n", + "2023-09-13 20:36:14,136 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-2] got object with type: \n", + "2023-09-13 20:36:14,140 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 3 / 5\n", + "2023-09-13 20:36:14,148 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3772, 'learning_rate': 0.01, 'epoch': 3.33}\n", + "2023-09-13 20:36:14,155 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3531, 'learning_rate': 0.01, 'epoch': 3.67}\n", + "2023-09-13 20:36:14,162 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.338, 'learning_rate': 0.01, 'epoch': 4.0}\n", + "2023-09-13 20:36:14,162 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,162 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-3]remote data, type=\n", + "2023-09-13 20:36:14,162 - fate.arch._standalone - DEBUG - 
[federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-3]remote object with type: \n", + "2023-09-13 20:36:14,171 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-3]\n", + "2023-09-13 20:36:14,272 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-3-arbiter-10000-guest-10000 type Object\n", + "2023-09-13 20:36:14,277 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-3] got object with type: \n", + "2023-09-13 20:36:14,282 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 4 / 5\n", + "2023-09-13 20:36:14,290 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.32, 'learning_rate': 0.01, 'epoch': 4.33}\n", + "2023-09-13 20:36:14,297 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3118, 'learning_rate': 0.01, 'epoch': 4.67}\n", + "2023-09-13 20:36:14,303 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3319, 'learning_rate': 0.01, 'epoch': 5.0}\n", + "2023-09-13 20:36:14,304 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,304 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-4]remote data, type=\n", + "2023-09-13 20:36:14,304 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-4]remote object with type: \n", + "2023-09-13 20:36:14,314 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-4]\n", + "2023-09-13 20:36:14,415 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-4-arbiter-10000-guest-10000 type Object\n", + "2023-09-13 20:36:14,419 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-4] got object with type: \n", + "2023-09-13 20:36:14,424 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 5 / 5\n", + "2023-09-13 20:36:14,424 - fate.ml.nn.trainer.trainer_base - INFO - {'train_runtime': 1.024, 'train_samples_per_second': 1108.351, 'train_steps_per_second': 73.239, 'train_loss': 0.4496373208363851, 'epoch': 5.0}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Host Terminal Outputs:\n", + "\n", + "```\n", + "2023-09-13 20:36:12,803 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('host', 9999)\n", + "2023-09-13 20:36:12,803 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n", + "2023-09-13 20:36:12,888 - fate.ml.nn.algo.homo.fedavg - INFO - Using secure_aggregate aggregator\n", + "2023-09-13 20:36:12,888 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote data, type=\n", + "2023-09-13 20:36:12,888 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote object with type: \n", + "2023-09-13 20:36:12,899 - fate.ml.aggregator.base - INFO - computing weights\n", + "2023-09-13 20:36:12,899 - fate.arch._standalone - DEBUG - [federation.standalone.remote.local_weight_fedavg_0.default]remote data, type=\n", + "2023-09-13 20:36:12,899 - fate.arch._standalone - DEBUG - 
[federation.standalone.remote.local_weight_fedavg_0.default]remote object with type: \n", + "2023-09-13 20:36:12,909 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_weight_fedavg_0.default]\n", + "2023-09-13 20:36:13,512 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-agg_weight_fedavg_0-default-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:13,518 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_weight_fedavg_0.default] got object with type: \n", + "2023-09-13 20:36:13,523 - fate.ml.aggregator.base - INFO - aggregate weight is 0.5010989010989011\n", + "2023-09-13 20:36:13,524 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_dh_pubkey.default]remote data, type=\n", + "2023-09-13 20:36:13,524 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_dh_pubkey.default]remote object with type: \n", + "2023-09-13 20:36:13,539 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_dh_pubkey.default]\n", + "2023-09-13 20:36:13,640 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_dh_pubkey-default-guest-10000-host-9999 type Object\n", + "2023-09-13 20:36:13,646 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_dh_pubkey.default] got object with type: \n", + "2023-09-13 20:36:13,652 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_loss_dh_pubkey.default]remote data, type=\n", + "2023-09-13 20:36:13,653 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_loss_dh_pubkey.default]remote object with type: \n", + "2023-09-13 20:36:13,661 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_loss_dh_pubkey.default]\n", + "2023-09-13 20:36:13,661 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_loss_dh_pubkey-default-guest-10000-host-9999 type Object\n", + "2023-09-13 20:36:13,666 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_loss_dh_pubkey.default] got object with type: \n", + "2023-09-13 20:36:13,671 - fate.ml.nn.trainer.trainer_base - INFO - computed max_aggregation is 5\n", + "2023-09-13 20:36:13,671 - fate.ml.nn.trainer.trainer_base - INFO - parameters is {'num_train_epochs': 5, 'max_steps': 75, 'num_update_steps_per_epoch': 15, 'epochs_trained': 0, 'steps_trained_in_current_epoch': 0, 'max_aggregation': 5, 'aggregate_freq': 1, 'aggregation_strategy': 'epochs'}\n", + "2023-09-13 20:36:13,671 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fed_para_0.default]remote data, type=\n", + "2023-09-13 20:36:13,671 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fed_para_0.default]remote object with type: \n", + "2023-09-13 20:36:13,688 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:13,689 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-0]remote data, type=\n", + "2023-09-13 20:36:13,689 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-0]remote object with type: \n", + "2023-09-13 20:36:13,708 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-0]\n", + "2023-09-13 20:36:13,909 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-0-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:13,914 - 
fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-0] got object with type: \n", + "2023-09-13 20:36:13,919 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 1 / 5\n", + "2023-09-13 20:36:13,919 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.5443, 'learning_rate': 0.01, 'epoch': 1.0}\n", + "2023-09-13 20:36:13,940 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:13,940 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-1]remote data, type=\n", + "2023-09-13 20:36:13,941 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-1]remote object with type: \n", + "2023-09-13 20:36:13,949 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-1]\n", + "2023-09-13 20:36:14,050 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-1-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:14,055 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-1] got object with type: \n", + "2023-09-13 20:36:14,060 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 2 / 5\n", + "2023-09-13 20:36:14,060 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4891, 'learning_rate': 0.01, 'epoch': 2.0}\n", + "2023-09-13 20:36:14,081 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,081 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-2]remote data, type=\n", + "2023-09-13 20:36:14,081 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-2]remote object with type: \n", + "2023-09-13 20:36:14,090 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-2]\n", + "2023-09-13 20:36:14,191 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-2-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:14,195 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-2] got object with type: \n", + "2023-09-13 20:36:14,200 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 3 / 5\n", + "2023-09-13 20:36:14,200 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.4143, 'learning_rate': 0.01, 'epoch': 3.0}\n", + "2023-09-13 20:36:14,221 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,222 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-3]remote data, type=\n", + "2023-09-13 20:36:14,222 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-3]remote object with type: \n", + "2023-09-13 20:36:14,234 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-3]\n", + "2023-09-13 20:36:14,334 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-3-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:14,339 - fate.arch._standalone - DEBUG - 
[federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-3] got object with type: \n", + "2023-09-13 20:36:14,344 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 4 / 5\n", + "2023-09-13 20:36:14,344 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3502, 'learning_rate': 0.01, 'epoch': 4.0}\n", + "2023-09-13 20:36:14,365 - fate.ml.nn.trainer.trainer_base - INFO - aggregation on epoch end\n", + "2023-09-13 20:36:14,365 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-4]remote data, type=\n", + "2023-09-13 20:36:14,365 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_mixed_client_values.default.aggregation-4]remote object with type: \n", + "2023-09-13 20:36:14,375 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-4]\n", + "2023-09-13 20:36:14,476 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_aggregated_values-default.aggregation-4-arbiter-10000-host-9999 type Object\n", + "2023-09-13 20:36:14,480 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_aggregated_values.default.aggregation-4] got object with type: \n", + "2023-09-13 20:36:14,485 - fate.ml.nn.trainer.trainer_base - INFO - Aggregation count: 5 / 5\n", + "2023-09-13 20:36:14,486 - fate.ml.nn.trainer.trainer_base - INFO - {'loss': 0.3157, 'learning_rate': 0.01, 'epoch': 5.0}\n", + "2023-09-13 20:36:14,486 - fate.ml.nn.trainer.trainer_base - INFO - {'train_runtime': 1.5982, 'train_samples_per_second': 713.303, 'train_steps_per_second': 46.928, 'train_loss': 0.4227139218648275, 'epoch': 5.0}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Arbiter Terminal Outputs:\n", + "\n", + "```\n", + "2023-09-13 20:36:12,315 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('arbiter', 10000)\n", + "2023-09-13 20:36:12,316 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n", + "2023-09-13 20:36:12,316 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_type.default]\n", + "2023-09-13 20:36:13,418 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-agg_type-default-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,425 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_type.default] got object with type: \n", + "2023-09-13 20:36:13,432 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_type.default]\n", + "2023-09-13 20:36:13,432 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-agg_type-default-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,436 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_type.default] got object with type: \n", + "2023-09-13 20:36:13,441 - fate.ml.nn.algo.homo.fedavg - INFO - Using secure_aggregate aggregator\n", + "2023-09-13 20:36:13,441 - fate.arch._standalone - DEBUG - [federation.standalone.get.local_weight_fedavg_0.default]\n", + "2023-09-13 20:36:13,441 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-local_weight_fedavg_0-default-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,446 - fate.arch._standalone - DEBUG - [federation.standalone.get.local_weight_fedavg_0.default] got object with type: \n", + "2023-09-13 20:36:13,450 - 
fate.arch._standalone - DEBUG - [federation.standalone.get.local_weight_fedavg_0.default]\n", + "2023-09-13 20:36:13,450 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-local_weight_fedavg_0-default-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,454 - fate.arch._standalone - DEBUG - [federation.standalone.get.local_weight_fedavg_0.default] got object with type: \n", + "2023-09-13 20:36:13,459 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_weight_fedavg_0.default]remote data, type=\n", + "2023-09-13 20:36:13,459 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_weight_fedavg_0.default]remote object with type: \n", + "2023-09-13 20:36:13,470 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_weight_fedavg_0.default]remote data, type=\n", + "2023-09-13 20:36:13,470 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_weight_fedavg_0.default]remote object with type: \n", + "2023-09-13 20:36:13,482 - fate.ml.nn.trainer.trainer_base - INFO - Initialized aggregator Done: \n", + "2023-09-13 20:36:13,482 - fate.arch._standalone - DEBUG - [federation.standalone.get.fed_para_0.default]\n", + "2023-09-13 20:36:13,682 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fed_para_0-default-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,692 - fate.arch._standalone - DEBUG - [federation.standalone.get.fed_para_0.default] got object with type: \n", + "2023-09-13 20:36:13,700 - fate.arch._standalone - DEBUG - [federation.standalone.get.fed_para_0.default]\n", + "2023-09-13 20:36:13,800 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fed_para_0-default-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,805 - fate.arch._standalone - DEBUG - [federation.standalone.get.fed_para_0.default] got object with type: \n", + "2023-09-13 20:36:13,809 - fate.ml.nn.trainer.trainer_base - INFO - checked parameters are {'max_aggregation': 5}\n", + "2023-09-13 20:36:13,810 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-0]\n", + "2023-09-13 20:36:13,810 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-0-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,815 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-0] got object with type: \n", + "2023-09-13 20:36:13,819 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-0]\n", + "2023-09-13 20:36:13,819 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-0-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,824 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-0] got object with type: \n", + "2023-09-13 20:36:13,829 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-0]remote data, type=\n", + "2023-09-13 20:36:13,829 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-0]remote object with type: \n", + "2023-09-13 20:36:13,838 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-0]remote data, type=\n", + "2023-09-13 20:36:13,838 - 
fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-0]remote object with type: \n", + "2023-09-13 20:36:13,847 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-1]\n", + "2023-09-13 20:36:13,947 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-1-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,954 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-1] got object with type: \n", + "2023-09-13 20:36:13,958 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-1]\n", + "2023-09-13 20:36:13,958 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-1-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:13,963 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-1] got object with type: \n", + "2023-09-13 20:36:13,968 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-1]remote data, type=\n", + "2023-09-13 20:36:13,968 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-1]remote object with type: \n", + "2023-09-13 20:36:13,977 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-1]remote data, type=\n", + "2023-09-13 20:36:13,977 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-1]remote object with type: \n", + "2023-09-13 20:36:13,986 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-2]\n", + "2023-09-13 20:36:14,086 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-2-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,094 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-2] got object with type: \n", + "2023-09-13 20:36:14,098 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-2]\n", + "2023-09-13 20:36:14,098 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-2-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,103 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-2] got object with type: \n", + "2023-09-13 20:36:14,107 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-2]remote data, type=\n", + "2023-09-13 20:36:14,107 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-2]remote object with type: \n", + "2023-09-13 20:36:14,116 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-2]remote data, type=\n", + "2023-09-13 20:36:14,116 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-2]remote object with type: \n", + "2023-09-13 20:36:14,125 - 
fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-3]\n", + "2023-09-13 20:36:14,225 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-3-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,232 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-3] got object with type: \n", + "2023-09-13 20:36:14,239 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-3]\n", + "2023-09-13 20:36:14,239 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-3-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,243 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-3] got object with type: \n", + "2023-09-13 20:36:14,247 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-3]remote data, type=\n", + "2023-09-13 20:36:14,248 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-3]remote object with type: \n", + "2023-09-13 20:36:14,257 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-3]remote data, type=\n", + "2023-09-13 20:36:14,257 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-3]remote object with type: \n", + "2023-09-13 20:36:14,266 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-4]\n", + "2023-09-13 20:36:14,366 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-4-guest-10000-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,377 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-4] got object with type: \n", + "2023-09-13 20:36:14,382 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-4]\n", + "2023-09-13 20:36:14,382 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-fedavg_model_mixed_client_values-default.aggregation-4-host-9999-arbiter-10000 type Object\n", + "2023-09-13 20:36:14,386 - fate.arch._standalone - DEBUG - [federation.standalone.get.fedavg_model_mixed_client_values.default.aggregation-4] got object with type: \n", + "2023-09-13 20:36:14,391 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-4]remote data, type=\n", + "2023-09-13 20:36:14,391 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-4]remote object with type: \n", + "2023-09-13 20:36:14,400 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-4]remote data, type=\n", + "2023-09-13 20:36:14,400 - fate.arch._standalone - DEBUG - [federation.standalone.remote.fedavg_model_aggregated_values.default.aggregation-4]remote object with type: \n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fate-dev", + "language": "python", + "name": "python3" + }, + 
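The tutorial above stops once training completes. As a small follow-up, here is a minimal sketch of how the client-side model weights could be persisted and reloaded with plain PyTorch after `trainer.train()` returns; whether the in-memory `model` already holds the final aggregated weights at that point is an assumption about FedAVGCLient's in-place update behavior, and the file name below is hypothetical.

```python
import torch as t

# same architecture as in run_homo_nn.py
model = t.nn.Sequential(t.nn.Linear(30, 1), t.nn.Sigmoid())

# ... after trainer.train() has finished on the guest/host side ...
# assumption: the local `model` object was updated in place during training
t.save(model.state_dict(), "fedavg_model.pt")  # hypothetical output path

# reload the weights later for local scoring
restored = t.nn.Sequential(t.nn.Linear(30, 1), t.nn.Sigmoid())
restored.load_state_dict(t.load("fedavg_model.pt"))
restored.eval()
```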
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/2.0/images/HeteroLR.png b/doc/2.0/images/HeteroLR.png new file mode 100644 index 0000000000..775d99af3a Binary files /dev/null and b/doc/2.0/images/HeteroLR.png differ diff --git a/doc/2.0/images/HeteroLinR.png b/doc/2.0/images/HeteroLinR.png new file mode 100644 index 0000000000..842405493e Binary files /dev/null and b/doc/2.0/images/HeteroLinR.png differ diff --git a/doc/2.0/images/binning_principle.png b/doc/2.0/images/binning_principle.png new file mode 100644 index 0000000000..40aa7fcd4e Binary files /dev/null and b/doc/2.0/images/binning_principle.png differ diff --git a/doc/2.0/images/ecdh_intersection.png b/doc/2.0/images/ecdh_intersection.png new file mode 100644 index 0000000000..7cecca182f Binary files /dev/null and b/doc/2.0/images/ecdh_intersection.png differ diff --git a/doc/2.0/images/hetero_lr_multi_host.png b/doc/2.0/images/hetero_lr_multi_host.png new file mode 100644 index 0000000000..67da2e3ee3 Binary files /dev/null and b/doc/2.0/images/hetero_lr_multi_host.png differ diff --git a/doc/2.0/images/multi_host_selection.png b/doc/2.0/images/multi_host_selection.png new file mode 100644 index 0000000000..8d68839e87 Binary files /dev/null and b/doc/2.0/images/multi_host_selection.png differ diff --git a/doc/2.0/images/multiple_host_binning.png b/doc/2.0/images/multiple_host_binning.png new file mode 100644 index 0000000000..0c8d4127ea Binary files /dev/null and b/doc/2.0/images/multiple_host_binning.png differ diff --git a/doc/2.0/quick_start.md b/doc/2.0/quick_start.md index eab38a25a6..f4945b525d 100644 --- a/doc/2.0/quick_start.md +++ b/doc/2.0/quick_start.md @@ -3,8 +3,16 @@ 1. install `fate_client` with extra package `fate` ```sh -python -m pip install -U pip && python -m pip install fate_client[fate]==2.0.0a0 +python -m pip install -U pip && python -m pip install fate_client[fate,fate_flow]==2.0.0b0 ``` +after installing packages successfully, initialize fate_flow service and fate_client + +```sh +mkdir fate_workspace +fate_flow init --ip 127.0.0.1 --port 9380 --home $(pwd)/fate_workspace +pipeline init --ip 127.0.0.1 --port 9380 +``` + 2. download example data @@ -13,90 +21,96 @@ wget https://raw.githubusercontent.com/wiki/FederatedAI/FATE/example/data/breast wget https://raw.githubusercontent.com/wiki/FederatedAI/FATE/example/data/breast_hetero_host.csv ``` -3. run example with fate_client - +3. 
transform example data to dataframe using in fate ```python import os +from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline import StandalonePipeline -from fate_client.pipeline.components.fate import ( - Evaluation, - FeatureScale, - HeteroLR, - Intersection, - Reader, -) base_path = os.path.abspath(os.path.join(__file__, os.path.pardir)) guest_data_path = os.path.join(base_path, "breast_hetero_guest.csv") host_data_path = os.path.join(base_path, "breast_hetero_host.csv") -# create pipeline -pipeline = StandalonePipeline().set_roles(guest="9999", host="10000", arbiter="10001") - -# create reader component -reader_0 = Reader(name="reader_0") -reader_0.guest.component_param( - path=f"file://${guest_data_path}", - format="csv", - id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32", -) -reader_0.hosts[0].component_param( - path=f"file://${host_data_path}", - format="csv", - id_name="id", - delimiter=",", - label_name=None, - dtype="float32", +data_pipeline = FateFlowPipeline().set_roles(local="0") +guest_meta = { + "delimiter": ",", "dtype": "float64", "label_type": "int64","label_name": "y", "match_id_name": "id" +} +host_meta = { + "delimiter": ",", "input_format": "dense", "match_id_name": "id" +} +data_pipeline.transform_local_file_to_dataframe(file=guest_data_path, namespace="experiment", name="breast_hetero_guest", + meta=guest_meta, head=True, extend_sid=True) +data_pipeline.transform_local_file_to_dataframe(file=host_data_path, namespace="experiment", name="breast_hetero_host", + meta=host_meta, head=True, extend_sid=True) +``` +4. run example + +```python +from fate_client.pipeline.components.fate import ( + HeteroSecureBoost, + PSI, + Evaluation ) +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel + -# create intersection component -intersection_0 = Intersection(name="intersection_0", method="raw", input_data=reader_0.outputs["output_data"]) -intersection_1 = Intersection(name="intersection_1", method="raw", input_data=reader_0.outputs["output_data"]) +# create pipeline for training +pipeline = FateFlowPipeline().set_roles(guest="9999", host="10000") -# create feature scale component -feature_scale_0 = FeatureScale( - name="feature_scale_0", method="standard", train_data=intersection_0.outputs["output_data"] +# create psi component_desc +psi_0 = PSI("psi_0") +psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", namespace="experiment") ) -feature_scale_1 = FeatureScale( - name="feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"], +psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", namespace="experiment") ) -# create lr component -lr_0 = HeteroLR( - name="lr_0", - train_data=feature_scale_0.outputs["train_output_data"], - validate_data=feature_scale_1.outputs["test_output_data"], - max_iter=100, - learning_rate=0.03, - batch_size=-1, +# create hetero secure_boost component_desc +hetero_secureboost_0 = HeteroSecureBoost( + 'hetero_secureboost_0', num_trees=1, max_depth=5, + train_data=psi_0.outputs['output_data'], + validate_data=psi_0.outputs["output_data"] ) -# create evaluation component -evaluation_0 = Evaluation(name="evaluation_0", runtime_roles="guest", input_data=lr_0.outputs["train_output_data"]) +# create evaluation component_desc +evaluation_0 = Evaluation( + 'evaluation_0', 
runtime_roles=['guest'], metrics=['auc'], input_data=[hetero_secureboost_0.outputs['train_data_output']] +) -# add components -pipeline.add_task(reader_0) -pipeline.add_task(feature_scale_0) -pipeline.add_task(feature_scale_1) -pipeline.add_task(intersection_0) -pipeline.add_task(intersection_1) -pipeline.add_task(lr_0) +# add training task +pipeline.add_task(psi_0) +pipeline.add_task(hetero_secureboost_0) pipeline.add_task(evaluation_0) -# train +# compile and train pipeline.compile() -print(pipeline.get_dag()) pipeline.fit() -print(pipeline.get_task_info("feature_scale_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_data()) -print(pipeline.get_task_info("evaluation_0").get_output_metrics()) -print(pipeline.deploy([intersection_0, feature_scale_0, lr_0])) + +# print metric and model info +print (pipeline.get_task_info("hetero_secureboost_0").get_output_model()) +print (pipeline.get_task_info("evaluation_0").get_output_metric()) + +# deploy task for inference +pipeline.deploy([psi_0, hetero_secureboost_0]) + +# create pipeline for predicting +predict_pipeline = FateFlowPipeline() + +# add input to deployed_pipeline +deployed_pipeline = pipeline.get_deployed_pipeline() +deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", namespace=f"experiment") +) +deployed_pipeline.psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", namespace=f"experiment") +) + +# add task to predict pipeline +predict_pipeline.add_task(deployed_pipeline) + +# compile and predict +predict_pipeline.compile() +predict_pipeline.predict() ``` diff --git a/doc/images/FATE_logo.png b/doc/images/FATE_logo.png new file mode 100644 index 0000000000..672e286a8c Binary files /dev/null and b/doc/images/FATE_logo.png differ diff --git a/doc/tutorial/pipeline_tutorial_hetero.ipynb b/doc/tutorial/pipeline_tutorial_hetero.ipynb new file mode 100644 index 0000000000..8a392cc17b --- /dev/null +++ b/doc/tutorial/pipeline_tutorial_hetero.ipynb @@ -0,0 +1,688 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pipeline Tutorial with Hetero Components" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### install" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Pipeline` is distributed along with [fate_client](https://pypi.org/project/fate-client/).\n", + "\n", + "```bash\n", + "pip install fate_client\n", + "```\n", + "\n", + "To use Pipeline, we need to first specify which `FATE Flow Service` to connect to. 
Once `fate_client` is installed, one can find a command line entry point named `pipeline`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Usage: pipeline [OPTIONS] COMMAND [ARGS]...\n",
+ "\n",
+ "Options:\n",
+ "  --help  Show this message and exit.\n",
+ "\n",
+ "Commands:\n",
+ "  init       pipeline init\n",
+ "  show       - DESCRIPTION: Show pipeline config details for Flow server.\n",
+ "  site-info  pipeline site info\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pipeline --help"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Assume we have a `FATE Flow Service` at 127.0.0.1:9380 (the default in standalone deployment), then execute"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Pipeline configuration succeeded.\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pipeline init --ip 127.0.0.1 --port 9380"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Hetero Example"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before starting a modeling task, the data to be used should be transformed into a dataframe. Please refer to this [guide](./pipeline_tutorial_transform_local_file_to_dataframe.ipynb)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `pipeline` package provides components to compose a `FATE pipeline`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from fate_client.pipeline import FateFlowPipeline\n",
+ "from fate_client.pipeline.components.fate import PSI, CoordinatedLR, Evaluation\n",
+ "from fate_client.pipeline.interface import DataWarehouseChannel"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Make a `pipeline` instance:\n",
+ "\n",
+ "  - initiator:\n",
+ "    * role: guest\n",
+ "    * party: 9999\n",
+ "  - roles:\n",
+ "    * guest: 9999\n",
+ "    * host: 10000\n",
+ "    * arbiter: 10000\n",
+ "    "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline = FateFlowPipeline().set_roles(guest='9999', host='10000', arbiter='10000')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Add a `PSI` component to perform PSI for the hetero scenario. Since this is the first component, specify its input data frame through `DataWarehouseChannel`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "psi_0 = PSI(\"psi_0\")\n",
+ "psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=\"breast_hetero_guest\",\n",
+ "                                                              namespace=\"experiment\"))\n",
+ "psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=\"breast_hetero_host\",\n",
+ "                                                                 namespace=\"experiment\"))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we add a CoordinatedLR training component and another LR component that predicts with the model from the previous component. Here we show how to feed output data and model from one component to another."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lr_0 = CoordinatedLR(\"lr_0\",\n",
+ "                     epochs=5,\n",
+ "                     batch_size=None,\n",
+ "                     optimizer={\"method\": \"SGD\", \"optimizer_params\": {\"lr\": 0.1}, \"penalty\": \"l2\", \"alpha\": 0.001},\n",
+ "                     init_param={\"fit_intercept\": True, \"method\": \"zeros\"},\n",
+ "                     train_data=psi_0.outputs[\"output_data\"],\n",
+ "                     learning_rate_scheduler={\"method\": \"linear\", \"scheduler_params\": {\"start_factor\": 0.7,\n",
+ "                                                                                       \"total_iters\": 100}})\n",
+ "lr_1 = CoordinatedLR(\"lr_1\",\n",
+ "                     input_model=lr_0.outputs[\"output_model\"],\n",
+ "                     test_data=psi_0.outputs[\"output_data\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To show the evaluation result, an \"Evaluation\" component is needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "evaluation_0 = Evaluation(\"evaluation_0\",\n",
+ "                          runtime_roles=[\"guest\"],\n",
+ "                          default_eval_setting=\"binary\",\n",
+ "                          input_data=lr_0.outputs[\"train_output_data\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Add components to the pipeline, in order of execution:\n",
+ "\n",
+ " - `psi_0` is responsible for finding overlapping match ids\n",
+ " - `lr_0` trains Coordinated LR on data output by `psi_0`\n",
+ " - `lr_1` predicts with model from `lr_0`\n",
+ " - `evaluation_0` consumes `lr_0`'s prediction result on training data\n",
+ "\n",
+ "Then compile our pipeline to make it ready for submission."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.add_task(psi_0)\n",
+ "pipeline.add_task(lr_0)\n",
+ "pipeline.add_task(lr_1)\n",
+ "pipeline.add_task(evaluation_0)\n",
+ "\n",
+ "pipeline.compile();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, submit (fit) our pipeline:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Job id is 202308311051324015890\n",
+ "\n",
+ "\u001b[80D\u001b[1A\u001b[KJob is waiting, time elapse: 0:00:00\n",
+ "\u001b[80D\u001b[1A\u001b[KJob is waiting, time elapse: 0:00:01\n",
+ "\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:02\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:03\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:04\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:05\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:06\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:07\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:08\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:09\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:10\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:11\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:12\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:13\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:14\n",
+ "\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:15\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:16\n",
+ "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:17\n",
+ 
"\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:18\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:19\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:20\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:21\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:22\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:23\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:24\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:25\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:26\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:27\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:28\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:29\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:30\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:31\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:32\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:33\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:34\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:35\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:36\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:37\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:38\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:39\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:40\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:41\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:42\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:43\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:44\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:45\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:46\n", + "\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:47\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:48\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:49\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:50\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:51\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:52\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:53\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:54\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:55\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:56\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:57\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:00:58\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_1, time elapse: 0:01:00\n", + "\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:01\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:02\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:03\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:04\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:05\n", + 
"\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:06\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:07\n", + "\u001b[80D\u001b[1A\u001b[KRunning task evaluation_0, time elapse: 0:01:08\n", + "Job is success!!! Job id is 202308311051324015890, response_data={'apply_resource_time': 1693450294353, 'cores': 4, 'create_time': 1693450292411, 'dag': {'dag': {'conf': {'auto_retries': 0, 'computing_partitions': 8, 'cores': None, 'engine': None, 'inheritance': None, 'initiator_party_id': '9999', 'model_id': '202308311051324015890', 'model_version': '0', 'model_warehouse': None, 'priority': None, 'scheduler_party_id': '9999', 'sync_type': 'callback', 'task': None, 'task_cores': None}, 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 'role': 'host'}, {'party_id': ['10000'], 'role': 'arbiter'}], 'party_tasks': {'guest_9999': {'conf': None, 'parties': [{'party_id': ['9999'], 'role': 'guest'}], 'tasks': {'psi_0': {'conf': None, 'inputs': {'data': {'input_data': {'data_warehouse': {'job_id': None, 'name': 'breast_hetero_guest', 'namespace': 'experiment', 'output_artifact_key': None, 'producer_task': None, 'roles': ['guest']}}}, 'model': None}, 'parameters': None}}}, 'host_10000': {'conf': None, 'parties': [{'party_id': ['10000'], 'role': 'host'}], 'tasks': {'psi_0': {'conf': None, 'inputs': {'data': {'input_data': {'data_warehouse': {'job_id': None, 'name': 'breast_hetero_host', 'namespace': 'experiment', 'output_artifact_key': None, 'producer_task': None, 'roles': ['host']}}}, 'model': None}, 'parameters': None}}}}, 'stage': 'train', 'tasks': {'evaluation_0': {'component_ref': 'evaluation', 'conf': None, 'dependent_tasks': ['lr_0'], 'inputs': {'data': {'input_data': {'task_output_artifact': [{'output_artifact_key': 'train_output_data', 'producer_task': 'lr_0', 'roles': ['guest']}]}}, 'model': None}, 'parameters': {'default_eval_setting': 'binary', 'label_column_name': None, 'metrics': None, 'predict_column_name': None}, 'parties': [{'party_id': ['9999'], 'role': 'guest'}], 'stage': 'default'}, 'lr_0': {'component_ref': 'coordinated_lr', 'conf': None, 'dependent_tasks': ['psi_0'], 'inputs': {'data': {'train_data': {'task_output_artifact': {'output_artifact_key': 'output_data', 'producer_task': 'psi_0', 'roles': ['guest', 'host']}}}, 'model': {}}, 'parameters': {'batch_size': None, 'early_stop': 'diff', 'epochs': 5, 'init_param': {'fit_intercept': True, 'method': 'zeros'}, 'learning_rate_scheduler': {'method': 'linear', 'scheduler_params': {'start_factor': 0.7, 'total_iters': 100}}, 'optimizer': {'alpha': 0.001, 'method': 'SGD', 'optimizer_params': {'lr': 0.1}, 'penalty': 'l2'}, 'output_cv_data': True, 'threshold': 0.5, 'tol': 0.0001}, 'parties': None, 'stage': None}, 'lr_1': {'component_ref': 'coordinated_lr', 'conf': None, 'dependent_tasks': ['lr_0', 'psi_0'], 'inputs': {'data': {'test_data': {'task_output_artifact': {'output_artifact_key': 'output_data', 'producer_task': 'psi_0', 'roles': ['guest', 'host']}}}, 'model': {'input_model': {'task_output_artifact': {'output_artifact_key': 'output_model', 'producer_task': 'lr_0', 'roles': ['guest', 'host']}}}}, 'parameters': {'batch_size': None, 'early_stop': 'diff', 'epochs': 20, 'output_cv_data': True, 'threshold': 0.5, 'tol': 0.0001}, 'parties': None, 'stage': 'predict'}, 'psi_0': {'component_ref': 'psi', 'conf': None, 'dependent_tasks': None, 'inputs': {'data': {}, 'model': None}, 'parameters': {}, 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 
'role': 'host'}], 'stage': 'default'}}}, 'schema_version': '2.0.0.alpha'}, 'description': '', 'elapsed': 66740, 'end_time': 1693450361127, 'engine_name': 'standalone', 'inheritance': {}, 'initiator_party_id': '9999', 'job_id': '202308311051324015890', 'memory': 0, 'model_id': '202308311051324015890', 'model_version': '0', 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 'role': 'host'}, {'party_id': ['10000'], 'role': 'arbiter'}], 'party_id': '9999', 'progress': 100, 'remaining_cores': 4, 'remaining_memory': 0, 'resource_in_use': False, 'return_resource_time': 1693450361094, 'role': 'guest', 'scheduler_party_id': '9999', 'start_time': 1693450294387, 'status': 'success', 'status_code': None, 'tag': 'job_end', 'update_time': 1693450361127, 'user_name': ''}\n", + "Total time: 0:01:09\n" + ] + } + ], + "source": [ + "pipeline.fit();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once training is done, data and model output from trained components may be queried through pipeline api. " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
extend_sididlabelpredict_scorepredict_resultpredict_detailtype
0a41979464da4e859ce5f594b3da91582013310.54536363775301791{'0': 0.4546363622469821, '1': 0.5453636377530...train_set
1a41979464da4e859ce5f594b3da915822226200.285892600379450360{'0': 0.7141073996205496, '1': 0.2858926003794...train_set
2a41979464da4e859ce5f594b3da915827611610.75894020809434491{'0': 0.24105979190565507, '1': 0.758940208094...train_set
3a41979464da4e859ce5f594b3da9158211514010.8379348211028451{'0': 0.162065178897155, '1': 0.837934821102845}train_set
4a41979464da4e859ce5f594b3da9158216017410.8197902484828751{'0': 0.18020975151712504, '1': 0.819790248482...train_set
\n", + "
" + ], + "text/plain": [ + " extend_sid id label predict_score \\\n", + "0 a41979464da4e859ce5f594b3da915820 133 1 0.5453636377530179 \n", + "1 a41979464da4e859ce5f594b3da9158222 262 0 0.28589260037945036 \n", + "2 a41979464da4e859ce5f594b3da9158276 116 1 0.7589402080943449 \n", + "3 a41979464da4e859ce5f594b3da91582115 140 1 0.837934821102845 \n", + "4 a41979464da4e859ce5f594b3da91582160 174 1 0.819790248482875 \n", + "\n", + " predict_result predict_detail type \n", + "0 1 {'0': 0.4546363622469821, '1': 0.5453636377530... train_set \n", + "1 0 {'0': 0.7141073996205496, '1': 0.2858926003794... train_set \n", + "2 1 {'0': 0.24105979190565507, '1': 0.758940208094... train_set \n", + "3 1 {'0': 0.162065178897155, '1': 0.837934821102845} train_set \n", + "4 1 {'0': 0.18020975151712504, '1': 0.819790248482... train_set " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lr_0_data = pipeline.get_task_info(\"lr_0\").get_output_data()[\"train_output_data\"]\n", + "import pandas as pd\n", + "pd.DataFrame(lr_0_data).head()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output_model': {'data': {'estimator': {'end_epoch': 5,\n", + " 'fit_intercept': True,\n", + " 'is_converged': False,\n", + " 'lr_scheduler': {'lr_params': {'start_factor': 0.7, 'total_iters': 100},\n", + " 'lr_scheduler': {'_get_lr_called_within_step': False,\n", + " '_last_lr': [0.07119999999999999],\n", + " '_step_count': 5,\n", + " 'base_lrs': [0.1],\n", + " 'end_factor': 1.0,\n", + " 'last_epoch': 4,\n", + " 'start_factor': 0.7,\n", + " 'total_iters': 100,\n", + " 'verbose': False},\n", + " 'method': 'linear'},\n", + " 'optimizer': {'alpha': 0.001,\n", + " 'l1_penalty': False,\n", + " 'l2_penalty': True,\n", + " 'method': 'sgd',\n", + " 'model_parameter': [[0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0],\n", + " [0.0]],\n", + " 'model_parameter_dtype': 'float32',\n", + " 'optim_param': {'lr': 0.1},\n", + " 'optimizer': {'param_groups': [{'dampening': 0,\n", + " 'differentiable': False,\n", + " 'foreach': None,\n", + " 'initial_lr': 0.1,\n", + " 'lr': 0.07119999999999999,\n", + " 'maximize': False,\n", + " 'momentum': 0,\n", + " 'nesterov': False,\n", + " 'params': [0],\n", + " 'weight_decay': 0}],\n", + " 'state': {}}},\n", + " 'param': {'coef_': [[-0.0878903686629256],\n", + " [-0.05677242358584973],\n", + " [-0.08771869885341368],\n", + " [-0.08136158941522312],\n", + " [-0.04950030091279235],\n", + " [-0.06369604907729508],\n", + " [-0.07172871180928618],\n", + " [-0.08904661502230068],\n", + " [-0.04913537990226004],\n", + " [-0.03418310333218406]],\n", + " 'dtype': 'float64',\n", + " 'intercept_': [0.04341809752512136]}}},\n", + " 'meta': {'batch_size': None,\n", + " 'epochs': 5,\n", + " 'init_param': {'fill_val': 0.0,\n", + " 'fit_intercept': True,\n", + " 'method': 'zeros',\n", + " 'random_state': None},\n", + " 'labels': [0, 1],\n", + " 'learning_rate_param': {'method': 'linear',\n", + " 'scheduler_params': {'start_factor': 0.7, 'total_iters': 100}},\n", + " 'optimizer_param': {'alpha': 0.001,\n", + " 'method': 'sgd',\n", + " 'optimizer_params': {'lr': 0.1},\n", + " 'penalty': 'l2'},\n", + " 'ovr': False,\n", + " 'threshold': 0.5}}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lr_0_model = 
pipeline.get_task_info(\"lr_0\").get_output_model()\n", + "lr_0_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To run prediction, trained components should first be deployed." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.deploy([psi_0, lr_0]);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, get deployed pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "deployed_pipeline = pipeline.get_deployed_pipeline()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify data input for predict pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=\"breast_hetero_guest\",\n", + " namespace=\"experiment\"))\n", + "deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=\"breast_hetero_host\",\n", + " namespace=\"experiment\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add components to predict pipeline in order of execution:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "predict_pipeline = FateFlowPipeline()\n", + "predict_pipeline.add_task(deployed_pipeline)\n", + "predict_pipeline.compile();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, run prediction job" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job id is 202308311054193818250\n", + "\n", + "\u001b[80D\u001b[1A\u001b[KJob is waiting, time elapse: 0:00:00\n", + "\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:01\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:02\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:03\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:04\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:05\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:06\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:07\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:08\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:09\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:10\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:11\n", + "\u001b[80D\u001b[1A\u001b[KRunning task psi_0, time elapse: 0:00:12\n", + "\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:13\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:14\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:15\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:16\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:17\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:18\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:19\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:20\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:21\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 
0:00:22\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:23\n", + "\u001b[80D\u001b[1A\u001b[KRunning task lr_0, time elapse: 0:00:24\n", + "Job is success!!! Job id is 202308311054193818250, response_data={'apply_resource_time': 1693450459527, 'cores': 4, 'create_time': 1693450459392, 'dag': {'dag': {'conf': {'auto_retries': 0, 'computing_partitions': 8, 'cores': None, 'engine': None, 'inheritance': None, 'initiator_party_id': '9999', 'model_id': '202308311054193818250', 'model_version': '0', 'model_warehouse': {'model_id': '202308311051324015890', 'model_version': '0'}, 'priority': None, 'scheduler_party_id': '9999', 'sync_type': 'callback', 'task': None, 'task_cores': None}, 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 'role': 'host'}, {'party_id': ['10000'], 'role': 'arbiter'}], 'party_tasks': {'guest_9999': {'conf': None, 'parties': [{'party_id': ['9999'], 'role': 'guest'}], 'tasks': {'psi_0': {'conf': None, 'inputs': {'data': {'input_data': {'data_warehouse': {'job_id': None, 'name': 'breast_hetero_guest', 'namespace': 'experiment', 'output_artifact_key': None, 'producer_task': None, 'roles': ['guest']}}}, 'model': None}, 'parameters': None}}}, 'host_10000': {'conf': None, 'parties': [{'party_id': ['10000'], 'role': 'host'}], 'tasks': {'psi_0': {'conf': None, 'inputs': {'data': {'input_data': {'data_warehouse': {'job_id': None, 'name': 'breast_hetero_host', 'namespace': 'experiment', 'output_artifact_key': None, 'producer_task': None, 'roles': ['host']}}}, 'model': None}, 'parameters': None}}}}, 'stage': 'predict', 'tasks': {'lr_0': {'component_ref': 'coordinated_lr', 'conf': None, 'dependent_tasks': ['psi_0'], 'inputs': {'data': {'test_data': {'task_output_artifact': {'output_artifact_key': 'output_data', 'producer_task': 'psi_0', 'roles': ['guest', 'host']}}}, 'model': {'input_model': {'model_warehouse': {'output_artifact_key': 'output_model', 'producer_task': 'lr_0', 'roles': ['guest', 'host']}}}}, 'parameters': {'batch_size': None, 'early_stop': 'diff', 'epochs': 5, 'init_param': {'fit_intercept': True, 'method': 'zeros'}, 'learning_rate_scheduler': {'method': 'linear', 'scheduler_params': {'start_factor': 0.7, 'total_iters': 100}}, 'optimizer': {'alpha': 0.001, 'method': 'SGD', 'optimizer_params': {'lr': 0.1}, 'penalty': 'l2'}, 'output_cv_data': True, 'threshold': 0.5, 'tol': 0.0001}, 'parties': None, 'stage': None}, 'psi_0': {'component_ref': 'psi', 'conf': None, 'dependent_tasks': None, 'inputs': {'data': {}, 'model': None}, 'parameters': {}, 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 'role': 'host'}], 'stage': 'default'}}}, 'schema_version': '2.0.0.alpha'}, 'description': '', 'elapsed': 25170, 'end_time': 1693450484709, 'engine_name': 'standalone', 'inheritance': {}, 'initiator_party_id': '9999', 'job_id': '202308311054193818250', 'memory': 0, 'model_id': '202308311054193818250', 'model_version': '0', 'parties': [{'party_id': ['9999'], 'role': 'guest'}, {'party_id': ['10000'], 'role': 'host'}, {'party_id': ['10000'], 'role': 'arbiter'}], 'party_id': '9999', 'progress': 100, 'remaining_cores': 4, 'remaining_memory': 0, 'resource_in_use': False, 'return_resource_time': 1693450484675, 'role': 'guest', 'scheduler_party_id': '9999', 'start_time': 1693450459539, 'status': 'success', 'status_code': None, 'tag': 'job_end', 'update_time': 1693450484709, 'user_name': ''}\n", + "Total time: 0:00:25\n" + ] + } + ], + "source": [ + "predict_pipeline.predict();" + ] + } + ], + "metadata": { + 
"interpreter": { + "hash": "ad4309918fa4cd1705b305e369b2f64d901b1851e9144aef7b9b07ea3efcb1bb" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tutorial/pipeline_tutorial_transform_local_file_to_dataframe.ipynb b/doc/tutorial/pipeline_tutorial_transform_local_file_to_dataframe.ipynb new file mode 100644 index 0000000000..cdad85c8ed --- /dev/null +++ b/doc/tutorial/pipeline_tutorial_transform_local_file_to_dataframe.ipynb @@ -0,0 +1,222 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pipeline Data Transform Tutorial " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### install" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Pipeline` is distributed along with [fate_client](https://pypi.org/project/fate-client/).\n", + "\n", + "```bash\n", + "pip install fate_client\n", + "```\n", + "\n", + "To use Pipeline, we need to first specify which `FATE Flow Service` to connect to. Once `fate_client` installed, one can find an cmd enterpoint name `pipeline`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pipeline --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Assume we have a `FATE Flow Service` in 127.0.0.1:9380(defaults in standalone), then exec" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pipeline configuration succeeded.\n" + ] + } + ], + "source": [ + "!pipeline init --ip 127.0.0.1 --port 9380" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### transform data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Before start a modeling task, local data to be used should be transformed into dataframe. \n", + " Typically, a party is usually a cluster which include multiple nodes. Thus, when we upload these data, the data will be allocated to those nodes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from fate_client.pipeline import FateFlowPipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Make a `pipeline` instance:\n", + "\n", + " - initiator: \n", + " * role: local\n", + " * party: 0\n", + " - roles:\n", + " * local: 0\n", + "\n", + "Note that data transformation runs as a local job, so the site role is set to local and the party id to 0 in the next cell.\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = FateFlowPipeline().set_roles(local=\"0\")\n", + "pipeline.set_site_role(\"local\")\n", + "pipeline.set_site_party_id(\"0\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define data meta; note that if the local file already contains a sample id column, `sample_id_name` should also be specified." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "meta = {'delimiter': ',',\n", + " 'dtype': 'float32',\n", + " 'input_format': 'dense',\n", + " 'label_type': 'int32',\n", + " 'label_name': 'y',\n", + " 'match_id_name': 'id',\n", + " 'match_id_range': 0,\n", + " 'tag_value_delimiter': ':',\n", + " 'tag_with_value': False,\n", + " 'weight_type': 'float32'}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the number of partitions for data storage" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "partitions = 4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the table name and namespace, which will be referenced in FATE job configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "data_guest = {\"name\": \"breast_hetero_guest\", \"namespace\": f\"experiment\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we add the data to be uploaded" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then start transforming data. Function `transform_local_file_to_dataframe` will initialize two tasks: the first task uploads local data, while the second transforms uploaded data into a dataframe usable by FATE tasks."
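The cell that follows transforms the guest-side file; a host-side transform is analogous. The sketch below is an assumption-laden illustration only: the host CSV path is hypothetical, and its meta drops the label fields since the host file carries no label.

```python
# Sketch only: host-side counterpart of the guest transform shown in the next cell.
# Assumes the same local `pipeline` and `partitions` defined above, and that the host CSV
# has no label column, so `label_name`/`label_type` are omitted from meta.
import os

host_meta = {'delimiter': ',', 'dtype': 'float32', 'input_format': 'dense',
             'match_id_name': 'id', 'match_id_range': 0,
             'tag_value_delimiter': ':', 'tag_with_value': False,
             'weight_type': 'float32'}
data_host = {"name": "breast_hetero_host", "namespace": "experiment"}

data_base = "work_space/FATE/"  # hypothetical project root, same assumption as the guest cell
pipeline.transform_local_file_to_dataframe(
    file=os.path.join(data_base, "examples/data/breast_hetero_host.csv"),
    meta=host_meta, head=True, extend_sid=True,
    name=data_host["name"], namespace=data_host["namespace"],
    partitions=partitions)
```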
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "data_base = \"work_space/FATE/\"\n", + "pipeline.transform_local_file_to_dataframe(file=os.path.join(data_base, \"examples/data/breast_hetero_guest.csv\"),\n", + " meta=meta, head=True, extend_sid=True,\n", + " name=data_guest[\"name\"], # name\n", + " namespace=data_guest[\"namespace\"], # namespace\n", + " partitions=partitions)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/eggroll b/eggroll index a06f9a15f7..0f8e33dcdb 160000 --- a/eggroll +++ b/eggroll @@ -1 +1 @@ -Subproject commit a06f9a15f742c85d84aed0a86631c19f54d382db +Subproject commit 0f8e33dcdbfe0a5184d43b4bf7a7f14e2d635bdd diff --git a/examples/benchmark_performance/coordinated_lr/breast_config.yaml b/examples/benchmark_performance/coordinated_lr/breast_config.yaml new file mode 100644 index 0000000000..d827c47236 --- /dev/null +++ b/examples/benchmark_performance/coordinated_lr/breast_config.yaml @@ -0,0 +1,24 @@ +data_guest: "breast_hetero_guest" +data_host: "breast_hetero_host" +idx: "id" +label_name: "y" +epochs: 20 +init_param: + fit_intercept: True + method: "random_uniform" + random_state: 42 +learning_rate_scheduler: + method: "constant" + scheduler_params: + factor: 1.0 + total_iters: 100 +optimizer: + method: "rmsprop" + penalty: "L2" + optimizer_params: + lr: 0.05 + alpha: 0.1 +batch_size: null +early_stop: "diff" +task_cores: 4 +timeout: 3600 \ No newline at end of file diff --git a/examples/benchmark_performance/coordinated_lr/config.yaml b/examples/benchmark_performance/coordinated_lr/config.yaml new file mode 100644 index 0000000000..1c021a7223 --- /dev/null +++ b/examples/benchmark_performance/coordinated_lr/config.yaml @@ -0,0 +1,11 @@ +parties: # parties default id + guest: + - 9999 + host: + - 9998 + - 9999 + arbiter: + - 9998 + +data_base_dir: "" # path to project base where data is located +timeout: 3600 \ No newline at end of file diff --git a/examples/benchmark_performance/coordinated_lr/coordinated_lr_performance.yaml b/examples/benchmark_performance/coordinated_lr/coordinated_lr_performance.yaml new file mode 100644 index 0000000000..81afb73e56 --- /dev/null +++ b/examples/benchmark_performance/coordinated_lr/coordinated_lr_performance.yaml @@ -0,0 +1,39 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 +tasks: + normal-lr: + script: test_lr_sid.py + conf: "./breast_config.yaml" diff --git 
a/examples/benchmark_performance/coordinated_lr/test_lr_sid.py b/examples/benchmark_performance/coordinated_lr/test_lr_sid.py new file mode 100644 index 0000000000..c58721f4d7 --- /dev/null +++ b/examples/benchmark_performance/coordinated_lr/test_lr_sid.py @@ -0,0 +1,102 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../../config.yaml", param="./lr_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + lr_param = { + } + + config_param = { + "epochs": param["epochs"], + "learning_rate_scheduler": param["learning_rate_scheduler"], + "optimizer": param["optimizer"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5 + } + lr_param.update(config_param) + lr_0 = CoordinatedLR("lr_0", + train_data=psi_0.outputs["output_data"], + **lr_param) + lr_1 = CoordinatedLR("lr_1", + test_data=psi_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + metrics=["auc", "binary_precision", "binary_accuracy", "binary_recall"], + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(evaluation_0) + + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + job_id = pipeline.model_info.job_id + return job_id + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", 
default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_performance/hetero_secureboost/default_credit_config.yaml b/examples/benchmark_performance/hetero_secureboost/default_credit_config.yaml new file mode 100644 index 0000000000..03425c6618 --- /dev/null +++ b/examples/benchmark_performance/hetero_secureboost/default_credit_config.yaml @@ -0,0 +1,7 @@ +data_guest: "default_credit_hetero_guest" +data_host: "default_credit_hetero_host" +idx: "id" +label_name: "y" +num_trees: 3 +max_depth: 3 +max_bin: 32 \ No newline at end of file diff --git a/examples/benchmark_performance/hetero_secureboost/hetero_secureboost_performance.yaml b/examples/benchmark_performance/hetero_secureboost/hetero_secureboost_performance.yaml new file mode 100644 index 0000000000..2a6a15d687 --- /dev/null +++ b/examples/benchmark_performance/hetero_secureboost/hetero_secureboost_performance.yaml @@ -0,0 +1,39 @@ +data: + - file: examples/data/default_credit_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/default_credit_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_host + namespace: experiment + role: host_0 +tasks: + hetero-sbt: + script: test_sbt.py + conf: "./default_credit_config.yaml" diff --git a/examples/benchmark_performance/hetero_secureboost/test_sbt.py b/examples/benchmark_performance/hetero_secureboost/test_sbt.py new file mode 100644 index 0000000000..b8bcb7ab2c --- /dev/null +++ b/examples/benchmark_performance/hetero_secureboost/test_sbt.py @@ -0,0 +1,113 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel + + +def main(config="../../config.yaml", param="./sbt_breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + config_param = { + "num_trees": param["num_trees"], + "max_depth": param["max_depth"], + "max_bin": param["max_bin"], + "objective": param.get("objective", "binary:bce"), + } + hetero_sbt_0 = HeteroSecureBoost('sbt_0', train_data=psi_0.outputs['output_data'], num_trees=config_param["num_trees"], + max_bin=config_param["max_bin"], max_depth=config_param["max_depth"], he_param={'kind': 'paillier', 'key_length': 1024}, + objective=config_param["objective"]) + + hetero_sbt_1 = HeteroSecureBoost('sbt_1', + test_data=psi_0.outputs['output_data'], + predict_model_input=hetero_sbt_0.outputs['train_model_output'], + ) + + if config_param['objective'] == 'regression:l2': + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + input_data=[hetero_sbt_0.outputs['train_data_output']], + default_eval_setting='regression', + ) + + + else: + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['auc'], + input_data=[hetero_sbt_0.outputs['train_data_output']] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(hetero_sbt_0) + pipeline.add_task(hetero_sbt_1) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.compile() + pipeline.fit() + + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + return pipeline.model_info.job_id + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./sbt_breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git 
a/examples/benchmark_quality/hetero_secureboost/hetero_sbt_benchmark.yaml b/examples/benchmark_quality/hetero_secureboost/hetero_sbt_benchmark.yaml new file mode 100644 index 0000000000..a84fdddf4c --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/hetero_sbt_benchmark.yaml @@ -0,0 +1,132 @@ +data: + - file: "examples/data/breast_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/breast_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/default_credit_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/default_credit_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/student_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: float64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: student_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/student_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: student_hetero_host + namespace: experiment + role: host_0 +hetero_sbt-binary-0-breast: + local: + script: "./xgb.py" + conf: "./xgb_breast_config.yaml" + FATE-hetero-sbt: + script: "./pipeline_hetero_sbt.py" + conf: "./sbt_breast_config.yaml" + compare_setting: + relative_tol: 0.01 +hetero_sbt-regression: + local: + script: "./xgb.py" + conf: "./xgb_student_config.yaml" + FATE-hetero-sbt: + script: "./pipeline_hetero_sbt.py" + conf: "./sbt_student_config.yaml" + compare_setting: + relative_tol: 0.05 +hetero_sbt-binary-1-default-credit: + local: + script: "./xgb.py" + conf: "./xgb_default_credit_config.yaml" + FATE-hetero-sbt: + script: "./pipeline_hetero_sbt.py" + conf: "./sbt_default_credit_config.yaml" + compare_setting: + relative_tol: 0.01 + + diff --git a/examples/benchmark_quality/hetero_secureboost/pipeline_hetero_sbt.py b/examples/benchmark_quality/hetero_secureboost/pipeline_hetero_sbt.py new file mode 100644 index 0000000000..963072b847 --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/pipeline_hetero_sbt.py 
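The `compare_setting.relative_tol` entries above bound how far the FATE pipeline's metrics may drift from the local xgboost baseline. The snippet below only illustrates that contract (fate_test's own comparison logic may differ):

```python
# Illustration of the relative_tol contract from hetero_sbt_benchmark.yaml; not fate_test code.
import math

def within_relative_tol(local_metrics: dict, fate_metrics: dict, relative_tol: float) -> bool:
    """True if every metric reported by both sides agrees within relative_tol."""
    shared = set(local_metrics) & set(fate_metrics)
    return all(math.isclose(local_metrics[k], fate_metrics[k], rel_tol=relative_tol)
               for k in shared)

# e.g. the binary breast pair with relative_tol 0.01:
print(within_relative_tol({"auc": 0.995}, {"auc": 0.993}, relative_tol=0.01))  # True
```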
@@ -0,0 +1,112 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate import HeteroSecureBoost +from fate_client.pipeline.components.fate import PSI +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel + + +def main(config="../../config.yaml", param="./sbt_breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + config_param = { + "num_trees": param["num_trees"], + "max_depth": param["max_depth"], + "max_bin": param["max_bin"], + "objective": param.get("objective", "binary:bce"), + } + hetero_sbt_0 = HeteroSecureBoost('sbt_0', train_data=psi_0.outputs['output_data'], num_trees=config_param["num_trees"], + max_bin=config_param["max_bin"], max_depth=config_param["max_depth"], he_param={'kind': 'paillier', 'key_length': 1024}, + objective=config_param["objective"]) + + if config_param['objective'] == 'regression:l2': + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + input_data=[hetero_sbt_0.outputs['train_data_output']], + default_eval_setting='regression', + ) + + + else: + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['auc'], + input_data=[hetero_sbt_0.outputs['train_data_output']] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(hetero_sbt_0) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + pipeline.compile() + pipeline.fit() + + result_summary = 
parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./sbt_breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/hetero_secureboost/sbt_breast_config.yaml b/examples/benchmark_quality/hetero_secureboost/sbt_breast_config.yaml new file mode 100644 index 0000000000..f1c35101fa --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/sbt_breast_config.yaml @@ -0,0 +1,7 @@ +data_guest: "breast_hetero_guest" +data_host: "breast_hetero_host" +idx: "id" +label_name: "y" +num_trees: 10 +max_depth: 3 +max_bin: 32 \ No newline at end of file diff --git a/examples/benchmark_quality/hetero_secureboost/sbt_default_credit_config.yaml b/examples/benchmark_quality/hetero_secureboost/sbt_default_credit_config.yaml new file mode 100644 index 0000000000..c8a8369898 --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/sbt_default_credit_config.yaml @@ -0,0 +1,7 @@ +data_guest: "default_credit_hetero_guest" +data_host: "default_credit_hetero_host" +idx: "id" +label_name: "y" +num_trees: 30 +max_depth: 3 +max_bin: 32 \ No newline at end of file diff --git a/examples/benchmark_quality/hetero_secureboost/sbt_student_config.yaml b/examples/benchmark_quality/hetero_secureboost/sbt_student_config.yaml new file mode 100644 index 0000000000..fbb77a06f0 --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/sbt_student_config.yaml @@ -0,0 +1,8 @@ +data_guest: "student_hetero_guest" +data_host: "student_hetero_host" +idx: "id" +label_name: "y" +num_trees: 50 +max_depth: 3 +max_bin: 32 +objective: 'regression:l2' \ No newline at end of file diff --git a/examples/benchmark_quality/hetero_secureboost/xgb.py b/examples/benchmark_quality/hetero_secureboost/xgb.py new file mode 100644 index 0000000000..69f732f41d --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/xgb.py @@ -0,0 +1,89 @@ +import argparse +import os +import xgboost as xgb +import pandas as pd +import math +from sklearn.metrics import roc_auc_score, precision_score, accuracy_score, recall_score, roc_curve, mean_absolute_error, mean_squared_error +from fate_client.pipeline.utils.test_utils import JobConfig + + +def main(config="../../config.yaml", param="./xgb_breast_config.yaml"): + + # obtain config + if isinstance(param, str): + param = JobConfig.load_from_file(param) + assert isinstance(param, dict) + data_guest = param["data_guest"] + data_host = param["data_host"] + idx = param["idx"] + label_name = param["label_name"] + max_depth = param['max_depth'] + max_bin = param['max_bin'] + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + config_param = { + "objective": param.get('objective', 'binary:logistic'), + "learning_rate": param["learning_rate"], + "n_estimators": param["n_estimators"], + 
"max_bin": max_bin, + "max_depth": max_depth, + "tree_method": "hist" + } + + # prepare data + df_guest = pd.read_csv(os.path.join(data_base_dir, data_guest), index_col=idx) + df_host = pd.read_csv(os.path.join(data_base_dir, data_host), index_col=idx) + df = df_guest.join(df_host, rsuffix="host") + print('data shape is {}'.format(df.shape)) + y = df[label_name] + X = df.drop(label_name, axis=1) + + x_train, x_test, y_train, y_test = X, X, y, y # no split here + + # Train the model + if config_param['objective'] == "reg:squarederror": + model = xgb.XGBRegressor(**config_param) + model.fit(x_train, y_train) + y_pred = model.predict(x_test) + + # compute mse rmse mae + mse = mean_squared_error(y_test, y_pred) + rmse = math.sqrt(mse) + mae = mean_absolute_error(y_test, y_pred) + return {}, {"rmse": rmse} + else: + model = xgb.XGBClassifier(**config_param) + model.fit(x_train, y_train) + y_pred = model.predict(x_test) + y_prob = model.predict_proba(x_test)[:, 1] + + try: + auc_score = roc_auc_score(y_test, y_prob) + except BaseException: + print("no auc score available") + return + + recall = recall_score(y_test, y_pred, average="macro") + pr = precision_score(y_test, y_pred, average="macro") + acc = accuracy_score(y_test, y_pred) + fpr, tpr, thresholds = roc_curve(y_test, y_prob) + + ks = max(tpr - fpr) + result = {"auc": auc_score, "recall": recall, "precision": pr, "accuracy": acc} + print(result) + return {}, result + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY XGBoost JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./xgb_breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/hetero_secureboost/xgb_breast_config.yaml b/examples/benchmark_quality/hetero_secureboost/xgb_breast_config.yaml new file mode 100644 index 0000000000..6e99c53364 --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/xgb_breast_config.yaml @@ -0,0 +1,8 @@ +data_guest: "examples/data/breast_hetero_guest.csv" +data_host: "examples/data/breast_hetero_host.csv" +idx: "id" +label_name: "y" +max_depth: 3 +max_bin: 32 +learning_rate: 0.3 +n_estimators: 30 diff --git a/examples/benchmark_quality/hetero_secureboost/xgb_default_credit_config.yaml b/examples/benchmark_quality/hetero_secureboost/xgb_default_credit_config.yaml new file mode 100644 index 0000000000..587b7c1ac6 --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/xgb_default_credit_config.yaml @@ -0,0 +1,8 @@ +data_guest: "examples/data/default_credit_hetero_guest.csv" +data_host: "examples/data/default_credit_hetero_host.csv" +idx: "id" +label_name: "y" +max_depth: 3 +max_bin: 32 +learning_rate: 0.3 +n_estimators: 30 diff --git a/examples/benchmark_quality/hetero_secureboost/xgb_student_config.yaml b/examples/benchmark_quality/hetero_secureboost/xgb_student_config.yaml new file mode 100644 index 0000000000..f0923182ce --- /dev/null +++ b/examples/benchmark_quality/hetero_secureboost/xgb_student_config.yaml @@ -0,0 +1,9 @@ +data_guest: "examples/data/student_hetero_guest.csv" +data_host: "examples/data/student_hetero_host.csv" +idx: "id" +label_name: "y" +max_depth: 3 +max_bin: 32 +learning_rate: 0.3 +n_estimators: 50 +objective: "reg:squarederror" diff --git a/examples/benchmark_quality/homo_nn/fed_nn_breast_config.yaml b/examples/benchmark_quality/homo_nn/fed_nn_breast_config.yaml new file 
mode 100644 index 0000000000..44c86553bd --- /dev/null +++ b/examples/benchmark_quality/homo_nn/fed_nn_breast_config.yaml @@ -0,0 +1,10 @@ +data_guest: "breast_homo_guest" +data_host: "breast_homo_host" +data_test: "breast_homo_test" +idx: "id" +label_name: "y" +in_feat: 30 +out_feat: 10 +lr: 0.01 +batch_size: 32 +epochs: 50 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/fed_nn_default_credit_config.yaml b/examples/benchmark_quality/homo_nn/fed_nn_default_credit_config.yaml new file mode 100644 index 0000000000..cc88f6f683 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/fed_nn_default_credit_config.yaml @@ -0,0 +1,10 @@ +data_guest: "default_credit_homo_guest" +data_host: "default_credit_homo_host" +data_test: "default_credit_homo_test" +idx: "id" +label_name: "y" +in_feat: 23 +out_feat: 10 +lr: 0.01 +batch_size: 32 +epochs: 50 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/fed_nn_give_credit_config.yaml b/examples/benchmark_quality/homo_nn/fed_nn_give_credit_config.yaml new file mode 100644 index 0000000000..1a0a6cb915 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/fed_nn_give_credit_config.yaml @@ -0,0 +1,10 @@ +data_guest: "give_credit_homo_guest" +data_host: "give_credit_homo_host" +data_test: "give_credit_homo_test" +idx: "id" +label_name: "y" +in_feat: 10 +out_feat: 4 +lr: 0.001 +batch_size: 128 +epochs: 100 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/fed_nn_student_config.yaml b/examples/benchmark_quality/homo_nn/fed_nn_student_config.yaml new file mode 100644 index 0000000000..b234ac3e6f --- /dev/null +++ b/examples/benchmark_quality/homo_nn/fed_nn_student_config.yaml @@ -0,0 +1,10 @@ +data_guest: "student_homo_guest" +data_host: "student_homo_host" +data_test: "student_homo_test" +idx: "id" +label_name: "y" +in_feat: 13 +out_feat: 8 +lr: 0.01 +batch_size: 1024 +epochs: 60 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/fed_nn_vehicle_config.yaml b/examples/benchmark_quality/homo_nn/fed_nn_vehicle_config.yaml new file mode 100644 index 0000000000..01417a7bec --- /dev/null +++ b/examples/benchmark_quality/homo_nn/fed_nn_vehicle_config.yaml @@ -0,0 +1,11 @@ +data_guest: "vehicle_scale_homo_guest" +data_host: "vehicle_scale_homo_host" +data_test: "vehicle_scale_homo_test" +idx: "id" +label_name: "y" +in_feat: 18 +out_feat: 8 +lr: 0.01 +batch_size: 32 +epochs: 110 +class_num: 4 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/homo_nn_benchmark.yaml b/examples/benchmark_quality/homo_nn/homo_nn_benchmark.yaml new file mode 100644 index 0000000000..ecd2c3de8c --- /dev/null +++ b/examples/benchmark_quality/homo_nn/homo_nn_benchmark.yaml @@ -0,0 +1,216 @@ +data: + - file: "examples/data/breast_homo_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/breast_homo_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_homo_host + namespace: experiment + role: host_0 + - file: "examples/data/default_credit_homo_guest.csv" + 
meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/default_credit_homo_host_1.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_homo_host + namespace: experiment + role: host_0 + - file: "examples/data/give_credit_homo_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: give_credit_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/give_credit_homo_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: give_credit_homo_host + namespace: experiment + role: host_0 + - file: "examples/data/vehicle_scale_homo_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: vehicle_scale_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/vehicle_scale_homo_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: vehicle_scale_homo_host + namespace: experiment + role: host_0 + - file: "examples/data/student_homo_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: float64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: student_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/student_homo_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: student_homo_host + namespace: experiment + role: host_0 +homo_nn-binary-0-breast: + local: + script: "./local_nn_binary.py" + conf: "./local_nn_breast_config.yaml" + FATE-homo-nn: + script: "./pipeline_nn_binary.py" + conf: "./fed_nn_breast_config.yaml" + compare_setting: + relative_tol: 0.01 +homo_nn-binary-1-default-credit: + local: + script: "./local_nn_binary.py" + conf: "./local_nn_default_credit_config.yaml" + FATE-homo-nn: + script: "./pipeline_nn_binary.py" + conf: "./fed_nn_default_credit_config.yaml" + compare_setting: + relative_tol: 0.01 +homo_nn-binary-2-give-credit: + local: + script: "./local_nn_binary.py" + conf: 
"./local_nn_give_credit_config.yaml" + FATE-homo-nn: + script: "./pipeline_nn_binary.py" + conf: "./fed_nn_give_credit_config.yaml" + compare_setting: + relative_tol: 0.01 +homo_nn-regression-student: + local: + script: "./local_nn_regression.py" + conf: "./local_nn_student_config.yaml" + FATE-homo-nn: + script: "./pipeline_nn_regression.py" + conf: "./fed_nn_student_config.yaml" + compare_setting: + relative_tol: 0.05 +homo_nn-multi-vehicle: + local: + script: "./local_nn_multi.py" + conf: "./local_nn_vehicle_config.yaml" + FATE-homo-nn: + script: "./pipeline_nn_multi.py" + conf: "./fed_nn_vehicle_config.yaml" + compare_setting: + relative_tol: 0.05 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_binary.py b/examples/benchmark_quality/homo_nn/local_nn_binary.py new file mode 100644 index 0000000000..7ac08241a0 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_binary.py @@ -0,0 +1,103 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import torch +from torch import nn, optim +from torch.nn import Sequential +import pandas as pd +import argparse +from fate_client.pipeline.utils import test_utils +from torch.utils.data import DataLoader, TensorDataset +from fate_client.pipeline.utils.test_utils import JobConfig +from sklearn.metrics import roc_auc_score +import tqdm + + +seed = 114514 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + +def main(config="../../config.yaml", param="./local_nn_breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + lr = param.get('lr') + idx = param.get('idx') + label_name = param.get('label_name') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_data = pd.read_csv(os.path.join(data_base_dir, guest_data_table), index_col=idx) + host_data = pd.read_csv(os.path.join(data_base_dir, host_data_table), index_col=idx) + + X = pd.concat([guest_data, host_data], ignore_index=True) + y = X.pop(label_name).values + + X = torch.tensor(X.values, dtype=torch.float32) + y = torch.tensor(y, dtype=torch.float32).view(-1, 1) + dataset = TensorDataset(X, y) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + model = Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1), + nn.Sigmoid() + ) + criterion = nn.BCELoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + for epoch in tqdm.tqdm(range(epochs)): + for batch_X, batch_y 
in dataloader: + optimizer.zero_grad() + outputs = model(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + with torch.no_grad(): + y_train_pred = model(X).numpy() + + auc_train = roc_auc_score(y.numpy(), y_train_pred) + print('auc is {}'.format(auc_train)) + + return {}, {'auc': auc_train} + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./local_nn_breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_breast_config.yaml b/examples/benchmark_quality/homo_nn/local_nn_breast_config.yaml new file mode 100644 index 0000000000..c5a7f09b68 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_breast_config.yaml @@ -0,0 +1,10 @@ +data_guest: "examples/data/breast_homo_guest.csv" +data_host: "examples/data/breast_homo_host.csv" +data_test: "examples/data/breast_homo_test.csv" +idx: "id" +label_name: "y" +in_feat: 30 +out_feat: 10 +lr: 0.01 +batch_size: 32 +epochs: 50 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_default_credit_config.yaml b/examples/benchmark_quality/homo_nn/local_nn_default_credit_config.yaml new file mode 100644 index 0000000000..4bbdd7059c --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_default_credit_config.yaml @@ -0,0 +1,10 @@ +data_guest: "examples/data/default_credit_homo_guest.csv" +data_host: "examples/data/default_credit_homo_host_1.csv" +data_test: "examples/data/default_credit_homo_test.csv" +idx: "id" +label_name: "y" +in_feat: 23 +out_feat: 10 +lr: 0.01 +batch_size: 32 +epochs: 50 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_give_credit_config.yaml b/examples/benchmark_quality/homo_nn/local_nn_give_credit_config.yaml new file mode 100644 index 0000000000..7bdbd994a6 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_give_credit_config.yaml @@ -0,0 +1,10 @@ +data_guest: "examples/data/give_credit_homo_guest.csv" +data_host: "examples/data/give_credit_homo_host.csv" +data_test: "examples/data/give_credit_homo_test.csv" +idx: "id" +label_name: "y" +in_feat: 10 +out_feat: 4 +lr: 0.001 +batch_size: 128 +epochs: 100 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_multi.py b/examples/benchmark_quality/homo_nn/local_nn_multi.py new file mode 100644 index 0000000000..4b83d50051 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_multi.py @@ -0,0 +1,103 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
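For reference, each benchmark pair above couples a plain local training script with its FATE pipeline counterpart; both return a metric dictionary (for the binary tasks, {'auc': ...}), and the pair is considered consistent when matching metrics agree within the relative_tol given under compare_setting. The check itself is performed by the fate_test tooling; the snippet below is only an illustrative sketch of that tolerance semantics, with made-up metric values.

import math

local_metrics = {"auc": 0.9931}   # hypothetical output of local_nn_binary.py
fate_metrics = {"auc": 0.9918}    # hypothetical output of pipeline_nn_binary.py
relative_tol = 0.01               # compare_setting.relative_tol from the benchmark yaml

for name, local_value in local_metrics.items():
    fate_value = fate_metrics[name]
    agree = math.isclose(local_value, fate_value, rel_tol=relative_tol)
    print(f"{name}: local={local_value}, fate={fate_value}, within tolerance={agree}")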
+# +import os +import torch +from torch import nn, optim +import pandas as pd +import argparse +from fate_client.pipeline.utils import test_utils +from torch.utils.data import DataLoader, TensorDataset +from fate_client.pipeline.utils.test_utils import JobConfig +from fate.ml.nn.model_zoo.multi_model import Multi +from fate.ml.evaluation.classification import MultiAccuracy, MultiPrecision, MultiRecall +import tqdm + +seed = 114514 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + +def main(config="../../config.yaml", param="", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + lr = param.get('lr') + idx = param.get('idx') + label_name = param.get('label_name') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_data = pd.read_csv(os.path.join(data_base_dir, guest_data_table), index_col=idx) + host_data = pd.read_csv(os.path.join(data_base_dir, host_data_table), index_col=idx) + + X = pd.concat([guest_data, host_data], ignore_index=True) + y = X.pop(label_name).values + + X = torch.tensor(X.values, dtype=torch.float32) + y = torch.tensor(y, dtype=torch.float32).flatten().long() + dataset = TensorDataset(X, y) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + model = Multi(in_feat, class_num=out_feat) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + for epoch in tqdm.tqdm(range(epochs)): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = model(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + with torch.no_grad(): + y_train_pred = model(X).numpy() + + # compute accuracy + acc = MultiAccuracy()(y_train_pred, y).get_raw_data() + # compute precision + precision = MultiPrecision()(y_train_pred, y).get_raw_data() + # compute recall + recall = MultiRecall()(y_train_pred, y).get_raw_data() + + result = {"multi_accuracy": float(acc)} + + return {}, result + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_regression.py b/examples/benchmark_quality/homo_nn/local_nn_regression.py new file mode 100644 index 0000000000..90748ca5bc --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_regression.py @@ -0,0 +1,101 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
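local_nn_multi.py above evaluates with FATE's own metric classes (MultiAccuracy, MultiPrecision, MultiRecall) but only returns accuracy. As a sanity check, comparable numbers can be computed with scikit-learn, mirroring the macro-averaged metrics used later in sklearn-lr-multi.py; macro averaging is an assumption about what the FATE classes report, and the arrays below are hypothetical stand-ins for the tensors computed in the script.

import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

y_true = np.array([0, 2, 1, 2, 0, 1])   # integer class labels
y_score = np.random.rand(6, 3)          # per-class model outputs
y_pred = y_score.argmax(axis=1)

metrics = {
    "multi_accuracy": accuracy_score(y_true, y_pred),
    "multi_precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
    "multi_recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
}
print(metrics)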
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import torch +from torch import nn, optim +from torch.nn import Sequential +import pandas as pd +import argparse +from fate_client.pipeline.utils import test_utils +from torch.utils.data import DataLoader, TensorDataset +from fate_client.pipeline.utils.test_utils import JobConfig +from sklearn.metrics import mean_squared_error, mean_absolute_error +import tqdm + +seed = 114514 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + + +def main(config="../../config.yaml", param="", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + lr = param.get('lr') + idx = param.get('idx') + label_name = param.get('label_name') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_data = pd.read_csv(os.path.join(data_base_dir, guest_data_table), index_col=idx) + host_data = pd.read_csv(os.path.join(data_base_dir, host_data_table), index_col=idx) + + X = pd.concat([guest_data, host_data], ignore_index=True) + y = X.pop(label_name).values + + X = torch.tensor(X.values, dtype=torch.float32) + y = torch.tensor(y, dtype=torch.float32).view(-1, 1) + dataset = TensorDataset(X, y) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + model = Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1) + ) + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + for epoch in tqdm.tqdm(range(epochs)): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = model(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + with torch.no_grad(): + y_train_pred = model(X).numpy() + + mse = mean_squared_error(y, y_train_pred) + rmse = mse ** 0.5 + return {}, {'rmse': rmse} + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_student_config.yaml b/examples/benchmark_quality/homo_nn/local_nn_student_config.yaml new file mode 100644 index 0000000000..538ab69fb2 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_student_config.yaml @@ -0,0 +1,10 @@ +data_guest: "examples/data/student_homo_guest.csv" +data_host: "examples/data/student_homo_host.csv" +data_test: "examples/data/student_homo_test.csv" +idx: "id" +label_name: "y" +in_feat: 13 
+out_feat: 8 +lr: 0.01 +batch_size: 1024 +epochs: 100 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/local_nn_vehicle_config.yaml b/examples/benchmark_quality/homo_nn/local_nn_vehicle_config.yaml new file mode 100644 index 0000000000..4c53e4aa6e --- /dev/null +++ b/examples/benchmark_quality/homo_nn/local_nn_vehicle_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/vehicle_scale_homo_guest.csv" +data_host: "examples/data/vehicle_scale_homo_host.csv" +data_test: "examples/data/vehicle_scale_homo_test.csv" +idx: "id" +label_name: "y" +in_feat: 18 +out_feat: 8 +lr: 0.01 +batch_size: 32 +epochs: 100 +class_num: 4 \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/pipeline_nn_binary.py b/examples/benchmark_quality/homo_nn/pipeline_nn_binary.py new file mode 100644 index 0000000000..87692baa16 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/pipeline_nn_binary.py @@ -0,0 +1,128 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + + +def main(config="../../config.yaml", param="./fed_nn_breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + lr = param.get('lr') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + test_data_table = param.get("data_test") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + test_data = {"name": test_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1), + nn.Sigmoid() + ), 
+ loss=nn.BCELoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='binary' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_1 = HomoNN( + 'nn_1', + test_data=DataWarehouseChannel(name=test_data["name"], namespace=test_data["namespace"]), + predict_model_input=homo_nn_0.outputs['train_model_output'] + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['auc'], + input_data=[homo_nn_1.outputs['predict_data_output'], homo_nn_0.outputs['train_data_output']] + ) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.add_task(homo_nn_0) + pipeline.add_task(homo_nn_1) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + print(pipeline.get_task_info("eval_0").get_output_metric()) + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./fed_nn_breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/pipeline_nn_multi.py b/examples/benchmark_quality/homo_nn/pipeline_nn_multi.py new file mode 100644 index 0000000000..342e237990 --- /dev/null +++ b/examples/benchmark_quality/homo_nn/pipeline_nn_multi.py @@ -0,0 +1,121 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
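Both halves of a benchmark pair expose the same main(config, param) entry point, so a single pair can also be run by hand outside the benchmark driver. A minimal sketch, assuming the working directory is examples/benchmark_quality/homo_nn and, for the pipeline half, that a reachable FATE Flow service already has the breast_homo_* tables uploaded:

import local_nn_binary
import pipeline_nn_binary

_, local_metrics = local_nn_binary.main(
    config="../../config.yaml", param="./local_nn_breast_config.yaml")
_, fate_metrics = pipeline_nn_binary.main(
    config="../../config.yaml", param="./fed_nn_breast_config.yaml")
print(local_metrics, fate_metrics)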
+# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.loader import ModelLoader +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + + +def main(config="../../config.yaml", param="", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + class_num = param.get('class_num') + lr = param.get('lr') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + test_data_table = param.get("data_test") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + test_data = {"name": test_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=ModelLoader('multi_model', 'Multi', feat=in_feat, class_num=class_num), + loss=nn.CrossEntropyLoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='multi' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_1 = HomoNN( + 'nn_1', + test_data=DataWarehouseChannel(name=test_data["name"], namespace=test_data["namespace"]), + predict_model_input=homo_nn_0.outputs['train_model_output'] + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + default_eval_setting='multi', + runtime_roles=['guest'], + input_data=[homo_nn_1.outputs['predict_data_output'], homo_nn_0.outputs['train_data_output']] + ) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.add_task(homo_nn_0) + pipeline.add_task(homo_nn_1) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if 
__name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/homo_nn/pipeline_nn_regression.py b/examples/benchmark_quality/homo_nn/pipeline_nn_regression.py new file mode 100644 index 0000000000..32e17f41de --- /dev/null +++ b/examples/benchmark_quality/homo_nn/pipeline_nn_regression.py @@ -0,0 +1,127 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + + +def main(config="../../config.yaml", param="", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + epochs = param.get('epochs') + batch_size = param.get('batch_size') + in_feat = param.get('in_feat') + out_feat = param.get('out_feat') + lr = param.get('lr') + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + test_data_table = param.get("data_test") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + test_data = {"name": test_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1) + ), + loss=nn.MSELoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='regression' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_1 = HomoNN( + 'nn_1', + 
test_data=DataWarehouseChannel(name=test_data["name"], namespace=test_data["namespace"]), + predict_model_input=homo_nn_0.outputs['train_model_output'] + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + default_eval_setting='regression', + runtime_roles=['guest'], + input_data=[homo_nn_1.outputs['predict_data_output'], homo_nn_0.outputs['train_data_output']] + ) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.add_task(homo_nn_0) + pipeline.add_task(homo_nn_1) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + print(pipeline.get_task_info("eval_0").get_output_metric()) + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="") + args = parser.parse_args() + main(args.config, args.param) \ No newline at end of file diff --git a/examples/benchmark_quality/linr/fate-linr.py b/examples/benchmark_quality/linr/fate-linr.py new file mode 100644 index 0000000000..d35c19ad38 --- /dev/null +++ b/examples/benchmark_quality/linr/fate-linr.py @@ -0,0 +1,119 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
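Note the two different "torch" namespaces in these scripts: the local baselines build real torch objects, while the pipeline scripts import declarative stand-ins from fate_client.pipeline.components.fate.nn.torch, which is why optim.Adam(lr=lr) is constructed without a parameter list (the runner presumably instantiates the real optimizer on each party; that behaviour is an assumption about fate_client internals). A side-by-side sketch, with the 13/0.01 values chosen only for illustration:

# real torch, as in the local_nn_* scripts
from torch import nn as torch_nn, optim as torch_optim
model = torch_nn.Linear(13, 1)
optimizer = torch_optim.Adam(model.parameters(), lr=0.01)   # needs model parameters

# declarative configs, as in the pipeline_nn_* scripts
from fate_client.pipeline.components.fate.nn.torch import nn, optim
loss_conf = nn.MSELoss()
optimizer_conf = optim.Adam(lr=0.01)                        # no parameters argument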
+# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_test.utils import parse_summary_result + + +def main(config="../../config.yaml", param="./linr_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_train_data = {"name": "motor_hetero_guest", "namespace": f"experiment{namespace}"} + host_train_data = {"name": "motor_hetero_host", "namespace": f"experiment{namespace}"} + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + linr_param = { + } + + config_param = { + "epochs": param["epochs"], + "learning_rate_scheduler": param["learning_rate_scheduler"], + "optimizer": param["optimizer"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5 + } + linr_param.update(config_param) + linr_0 = CoordinatedLinR("linr_0", + train_data=psi_0.outputs["output_data"], + **config_param) + """linr_1 = CoordinatedLinR("linr_1", + test_data=psi_0.outputs["output_data"], + input_model=linr_0.outputs["output_model"])""" + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + metrics=["r2_score", + "mse", + "rmse"], + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + # pipeline.add_task(linr_1) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + """linr_0_data = pipeline.get_task_info("linr_0").get_output_data()["train_output_data"] + linr_1_data = pipeline.get_task_info("linr_1").get_output_data()["test_output_data"] + linr_0_score = extract_data(linr_0_data, "predict_result") + linr_0_label = extract_data(linr_0_data, "motor_speed") + linr_1_score = extract_data(linr_1_data, "predict_result") + linr_1_label = extract_data(linr_1_data, "motor_speed") + linr_0_score_label = extract_data(linr_0_data, "predict_result", keep_id=True) + linr_1_score_label = extract_data(linr_1_data, "predict_result", keep_id=True)""" + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", 
default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml b/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml new file mode 100644 index 0000000000..d3089cc26b --- /dev/null +++ b/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml @@ -0,0 +1,45 @@ +data: + - file: examples/data/motor_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: float64 + label_name: motor_speed + match_id_name: "idx" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/motor_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "idx" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_host + namespace: experiment + role: host_0 + +hetero_linr: + local: + script: "./local-linr.py" + conf: "./linr_sklearn_config.yaml" + FATE-hetero-linr: + script: "./fate-linr.py" + conf: "./linr_config.yaml" + compare_setting: + relative_tol: 0.01 diff --git a/examples/benchmark_quality/linr/linr_config.yaml b/examples/benchmark_quality/linr/linr_config.yaml new file mode 100644 index 0000000000..13f5199e90 --- /dev/null +++ b/examples/benchmark_quality/linr/linr_config.yaml @@ -0,0 +1,22 @@ +data_guest: "examples/data/motor_hetero_guest.csv" +data_host: "examples/data/motor_hetero_host.csv" +label_name: "motor_speed" +penalty: "L2" +epochs: 10 +init_param: + fit_intercept: True + method: "zeros" + random_state: 42 +learning_rate_scheduler: + method: "constant" + scheduler_params: + factor: 1.0 + total_iters: 100 +optimizer: + method: "sgd" + penalty: "L2" + optimizer_params: + lr: 0.13 + alpha: 0.01 +batch_size: 100 +early_stop: "diff" diff --git a/examples/benchmark_quality/linr/linr_sklearn_config.yaml b/examples/benchmark_quality/linr/linr_sklearn_config.yaml new file mode 100644 index 0000000000..38a15edc00 --- /dev/null +++ b/examples/benchmark_quality/linr/linr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/motor_hetero_guest.csv" +data_host: "examples/data/motor_hetero_host.csv" +label_name: "motor_speed" +penalty: "L2" +idx: "idx" +epochs: 20 +fit_intercept: True +method: "rmsprop" +eta0: 0.1 +alpha: 0.5 +batch_size: 5000 diff --git a/examples/benchmark_quality/linr/local-linr.py b/examples/benchmark_quality/linr/local-linr.py new file mode 100644 index 0000000000..bffafbb524 --- /dev/null +++ b/examples/benchmark_quality/linr/local-linr.py @@ -0,0 +1,72 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +import os + +import numpy as np +import pandas +from fate_client.pipeline.utils.test_utils import JobConfig +from sklearn.linear_model import SGDRegressor +from sklearn.metrics import mean_squared_error, r2_score, explained_variance_score + + +def main(config="../../config.yaml", param="./linr_sklearn_config.yaml"): + # obtain config + if isinstance(param, str): + param = JobConfig.load_from_file(param) + data_guest = param["data_guest"] + data_host = param["data_host"] + idx = param["idx"] + label_name = param["label_name"] + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + # prepare data + df_guest = pandas.read_csv(os.path.join(data_base_dir, data_guest), index_col=idx) + df_host = pandas.read_csv(os.path.join(data_base_dir, data_host), index_col=idx) + df = df_guest.join(df_host, rsuffix="host") + y = df[label_name] + X = df.drop(label_name, axis=1) + lm = SGDRegressor(loss="squared_error", penalty=param["penalty"], random_state=42, + fit_intercept=True, max_iter=param["epochs"], average=param["batch_size"]) + lm_fit = lm.fit(X, y) + y_pred = lm_fit.predict(X) + + mse = mean_squared_error(y, y_pred) + rmse = np.sqrt(mse) + r2 = r2_score(y, y_pred) + explained_var = explained_variance_score(y, y_pred) + metric_summary = {"r2_score": r2, + "mse": mse, + "rmse": rmse} + data_summary = {} + return data_summary, metric_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY LOCAL JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./linr_sklearn_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/breast_config.yaml b/examples/benchmark_quality/lr/breast_config.yaml new file mode 100644 index 0000000000..3d1747cc04 --- /dev/null +++ b/examples/benchmark_quality/lr/breast_config.yaml @@ -0,0 +1,22 @@ +data_guest: "breast_hetero_guest" +data_host: "breast_hetero_host" +idx: "id" +label_name: "y" +epochs: 20 +init_param: + fit_intercept: True + method: "random_uniform" + random_state: 42 +learning_rate_scheduler: + method: "constant" + scheduler_params: + factor: 0.5 + total_iters: 5 +optimizer: + method: "rmsprop" + penalty: "l2" + optimizer_params: + lr: 0.15 + alpha: 0.01 +batch_size: 240 +early_stop: "diff" \ No newline at end of file diff --git a/examples/benchmark_quality/lr/breast_lr_sklearn_config.yaml b/examples/benchmark_quality/lr/breast_lr_sklearn_config.yaml new file mode 100644 index 0000000000..e7fc0c17d4 --- /dev/null +++ b/examples/benchmark_quality/lr/breast_lr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/breast_hetero_guest.csv" +data_host: "examples/data/breast_hetero_host.csv" +idx: "id" +label_name: "y" +epochs: 30 +fit_intercept: True +method: "rmsprop" +penalty: "L2" +eta0: 0.1 +alpha: 0.05 +batch_size: 5000 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/default_credit_config.yaml b/examples/benchmark_quality/lr/default_credit_config.yaml new file mode 100644 index 0000000000..3bfbd67760 --- /dev/null +++ b/examples/benchmark_quality/lr/default_credit_config.yaml @@ -0,0 +1,22 @@ +data_guest: "default_credit_hetero_guest" +data_host: "default_credit_hetero_host" +idx: "id" +label_name: "y" +epochs: 16 +init_param: + 
fit_intercept: True + method: "zeros" + random_state: 42 +learning_rate_scheduler: + method: "linear" + scheduler_params: + start_factor: 0.7 + total_iters: 1000 +optimizer: + method: "rmsprop" + penalty: "L2" + alpha: 0.1 + optimizer_params: + lr: 0.12 +batch_size: 10000 +early_stop: "diff" \ No newline at end of file diff --git a/examples/benchmark_quality/lr/default_credit_lr_sklearn_config.yaml b/examples/benchmark_quality/lr/default_credit_lr_sklearn_config.yaml new file mode 100644 index 0000000000..df9387503d --- /dev/null +++ b/examples/benchmark_quality/lr/default_credit_lr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/default_credit_hetero_guest.csv" +data_host: "examples/data/default_credit_hetero_host.csv" +idx: "id" +label_name: "y" +epochs: 30 +fit_intercept: True +method: "rmsprop" +penalty: "L2" +eta0: 0.1 +alpha: 0.1 +batch_size: 5000 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/epsilon_5k_config.yaml b/examples/benchmark_quality/lr/epsilon_5k_config.yaml new file mode 100644 index 0000000000..034d61378c --- /dev/null +++ b/examples/benchmark_quality/lr/epsilon_5k_config.yaml @@ -0,0 +1,22 @@ +data_guest: "epsilon_5k_hetero_guest" +data_host: "epsilon_5k_hetero_host" +idx: "id" +label_name: "y" +epochs: 8 +batch_size: 2200 +init_param: + fit_intercept: True + method: "random" + random_state: 42 +learning_rate_scheduler: + method: "linear" + scheduler_params: + start_factor: 0.7 + total_iters: 1000 +optimizer: + method: "adam" + penalty: "L2" + alpha: 0.0001 + optimizer_params: + lr: 0.43 +early_stop: "diff" \ No newline at end of file diff --git a/examples/benchmark_quality/lr/epsilon_5k_lr_sklearn_config.yaml b/examples/benchmark_quality/lr/epsilon_5k_lr_sklearn_config.yaml new file mode 100644 index 0000000000..7559f0bfa6 --- /dev/null +++ b/examples/benchmark_quality/lr/epsilon_5k_lr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/epsilon_5k_hetero_guest.csv" +data_host: "examples/data/epsilon_5k_hetero_host.csv" +idx: "id" +label_name: "y" +epochs: 10 +fit_intercept: True +method: "rmsprop" +penalty: "L2" +eta0: 0.1 +alpha: 0.001 +batch_size: 5000 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/give_credit_config.yaml b/examples/benchmark_quality/lr/give_credit_config.yaml new file mode 100644 index 0000000000..6f8656132b --- /dev/null +++ b/examples/benchmark_quality/lr/give_credit_config.yaml @@ -0,0 +1,21 @@ +data_guest: "give_credit_hetero_guest" +data_host: "give_credit_hetero_host" +idx: "id" +label_name: "y" +epochs: 16 +init_param: + fit_intercept: True + method: "zeros" +learning_rate_scheduler: + method: "linear" + scheduler_params: + start_factor: 0.71 + total_iters: 1000 +optimizer: + method: "rmsprop" + penalty: "L1" + alpha: 0.01 + optimizer_params: + lr: 0.25 +batch_size: null +early_stop: "diff" \ No newline at end of file diff --git a/examples/benchmark_quality/lr/give_credit_lr_sklearn_config.yaml b/examples/benchmark_quality/lr/give_credit_lr_sklearn_config.yaml new file mode 100644 index 0000000000..4dcb136b99 --- /dev/null +++ b/examples/benchmark_quality/lr/give_credit_lr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/give_credit_hetero_guest.csv" +data_host: "examples/data/give_credit_hetero_host.csv" +idx: "id" +label_name: "y" +epochs: 30 +fit_intercept: True +method: "rmsprop" +penalty: "L2" +eta0: 0.1 +alpha: 0.5 +batch_size: 5000 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/lr_benchmark.yaml 
b/examples/benchmark_quality/lr/lr_benchmark.yaml new file mode 100644 index 0000000000..294c9264c1 --- /dev/null +++ b/examples/benchmark_quality/lr/lr_benchmark.yaml @@ -0,0 +1,217 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/default_credit_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/default_credit_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: default_credit_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/give_credit_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: give_credit_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/give_credit_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: give_credit_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/epsilon_5k_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: epsilon_5k_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/epsilon_5k_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: epsilon_5k_hetero_host + namespace: experiment + role: host_0 + - file: "examples/data/vehicle_scale_hetero_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: 
vehicle_scale_hetero_guest + namespace: experiment + role: guest_0 + - file: "examples/data/vehicle_scale_hetero_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partition: 4 + extend_sid: true + table_name: vehicle_scale_hetero_host + namespace: experiment + role: host_0 +hetero_lr-binary-0-breast: + local: + script: "./sklearn-lr-binary.py" + conf: "./breast_lr_sklearn_config.yaml" + FATE-hetero-lr: + script: "./pipeline-lr-binary.py" + conf: "./breast_config.yaml" + compare_setting: + relative_tol: 0.01 +hetero_lr-binary-1-default-credit: + local: + script: "./sklearn-lr-binary.py" + conf: "./default_credit_lr_sklearn_config.yaml" + FATE-hetero-lr: + script: "./pipeline-lr-binary.py" + conf: "./default_credit_config.yaml" + compare_setting: + relative_tol: 0.01 +hetero_lr-binary-2-epsilon-5k: + local: + script: "./sklearn-lr-binary.py" + conf: "./epsilon_5k_lr_sklearn_config.yaml" + FATE-hetero-lr: + script: "./pipeline-lr-binary.py" + conf: "./epsilon_5k_config.yaml" + compare_setting: + relative_tol: 0.01 +hetero_lr-binary-3-give-credit: + local: + script: "./sklearn-lr-binary.py" + conf: "./give_credit_lr_sklearn_config.yaml" + FATE-hetero-lr: + script: "./pipeline-lr-binary.py" + conf: "./give_credit_config.yaml" + compare_setting: + relative_tol: 0.01 +multi-vehicle: + local: + script: "./sklearn-lr-multi.py" + conf: "./vehicle_lr_sklearn_config.yaml" + FATE-hetero-lr: + script: "./pipeline-lr-multi.py" + conf: "./vehicle_config.yaml" + compare_setting: + relative_tol: 0.01 + diff --git a/examples/benchmark_quality/lr/pipeline-lr-binary.py b/examples/benchmark_quality/lr/pipeline-lr-binary.py new file mode 100644 index 0000000000..6afb78b25c --- /dev/null +++ b/examples/benchmark_quality/lr/pipeline-lr-binary.py @@ -0,0 +1,118 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse + +from fate_test.utils import parse_summary_result + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../../config.yaml", param="./breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + lr_param = { + } + + config_param = { + "epochs": param["epochs"], + "learning_rate_scheduler": param["learning_rate_scheduler"], + "optimizer": param["optimizer"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5 + } + lr_param.update(config_param) + lr_0 = CoordinatedLR("lr_0", + train_data=psi_0.outputs["output_data"], + **lr_param) + lr_1 = CoordinatedLR("lr_1", + test_data=psi_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + metrics=["auc", "binary_precision", "binary_accuracy", "binary_recall"], + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + pipeline.compile() + pipeline.fit() + + """lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] + lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] + lr_0_score = extract_data(lr_0_data, "predict_result") + lr_0_label = extract_data(lr_0_data, "y") + lr_1_score = extract_data(lr_1_data, "predict_result") + lr_1_label = extract_data(lr_1_data, "y") + lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) + lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True)""" + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", 
default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/pipeline-lr-multi.py b/examples/benchmark_quality/lr/pipeline-lr-multi.py new file mode 100644 index 0000000000..0741ef0e06 --- /dev/null +++ b/examples/benchmark_quality/lr/pipeline-lr-multi.py @@ -0,0 +1,114 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_test.utils import extract_data, parse_summary_result + + +def main(config="../../config.yaml", param="./vehicle_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + lr_param = { + } + + config_param = { + "epochs": param["epochs"], + "learning_rate_scheduler": param["learning_rate_scheduler"], + "optimizer": param["optimizer"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5, + } + lr_param.update(config_param) + lr_0 = CoordinatedLR("lr_0", + train_data=psi_0.outputs["output_data"], + **config_param) + lr_1 = CoordinatedLR("lr_1", + test_data=psi_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) + + evaluation_0 = Evaluation('evaluation_0', + runtime_roles=['guest'], + input_data=lr_0.outputs["train_output_data"], + predict_column_name='predict_result', + metrics=['multi_recall', 'multi_accuracy', 'multi_precision']) + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(evaluation_0) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + 
pipeline.compile() + pipeline.fit() + + lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] + lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) + lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True) + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./vehicle_config.yaml") + + args = parser.parse_args() + if args.config is not None: + main(args.config, args.param) + else: + main() diff --git a/examples/benchmark_quality/lr/sklearn-lr-binary.py b/examples/benchmark_quality/lr/sklearn-lr-binary.py new file mode 100644 index 0000000000..51e463df94 --- /dev/null +++ b/examples/benchmark_quality/lr/sklearn-lr-binary.py @@ -0,0 +1,92 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +import os + +import pandas +from fate_client.pipeline.utils.test_utils import JobConfig +from sklearn.linear_model import SGDClassifier +from sklearn.metrics import roc_auc_score, precision_score, accuracy_score, recall_score, roc_curve + + +def main(config="../../config.yaml", param="./breast_lr_sklearn_config.yaml"): + # obtain config + if isinstance(param, str): + param = JobConfig.load_from_file(param) + assert isinstance(param, dict) + data_guest = param["data_guest"] + data_host = param["data_host"] + idx = param["idx"] + label_name = param["label_name"] + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + print(f"config: {config}") + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + config_param = { + "penalty": param["penalty"], + "max_iter": param["epochs"], + "alpha": param["alpha"], + "learning_rate": "optimal", + "eta0": param["eta0"], + "random_state": 105 + } + + # prepare data + df_guest = pandas.read_csv(os.path.join(data_base_dir, data_guest), index_col=idx) + df_host = pandas.read_csv(os.path.join(data_base_dir, data_host), index_col=idx) + df = df_guest.join(df_host, rsuffix="host") + y = df[label_name] + X = df.drop(label_name, axis=1) + + # x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0) + x_train, x_test, y_train, y_test = X, X, y, y + + # lm = LogisticRegression(max_iter=20) + lm = SGDClassifier(loss="log", **config_param) + lm_fit = lm.fit(x_train, y_train) + y_pred = lm_fit.predict(x_test) + y_prob = lm_fit.predict_proba(x_test)[:, 1] + try: + auc_score = roc_auc_score(y_test, y_prob) + except BaseException: + print(f"no auc score available") + return + recall = recall_score(y_test, y_pred, average="macro") + pr = precision_score(y_test, y_pred, average="macro") + acc = accuracy_score(y_test, y_pred) + # y_predict_proba = est.predict_proba(X_test)[:, 1] + fpr, tpr, thresholds = roc_curve(y_test, y_prob) + + ks = max(tpr - fpr) + result = {"auc": auc_score, "recall": recall, "precision": pr, "accuracy": acc} + print(result) + # print(f"coef_: {lm_fit.coef_}, intercept_: {lm_fit.intercept_}, n_iter: {lm_fit.n_iter_}") + return {}, result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY SKLEARN JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_lr_sklearn_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/sklearn-lr-multi.py b/examples/benchmark_quality/lr/sklearn-lr-multi.py new file mode 100644 index 0000000000..b56fc80dce --- /dev/null +++ b/examples/benchmark_quality/lr/sklearn-lr-multi.py @@ -0,0 +1,82 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
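One environment caveat for the sklearn baselines above: SGDClassifier's loss="log" spelling was deprecated in scikit-learn 1.1 and removed in 1.3 in favour of loss="log_loss", so these scripts assume an older scikit-learn. A sketch of the same breast baseline written against the newer spelling, reusing the hyper-parameters from breast_lr_sklearn_config.yaml:

from sklearn.linear_model import SGDClassifier

lm = SGDClassifier(loss="log_loss",        # "log" on scikit-learn < 1.1
                   penalty="l2", alpha=0.05, eta0=0.1,
                   learning_rate="optimal", max_iter=30, random_state=105)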
+# + +import argparse +import os + +import pandas +from fate_client.pipeline.utils.test_utils import JobConfig +from sklearn.linear_model import SGDClassifier +from sklearn.metrics import precision_score, accuracy_score, recall_score + + +def main(config="../../config.yaml", param="./vehicle_lr_sklearn_config.yaml"): + # obtain config + if isinstance(param, str): + param = JobConfig.load_from_file(param) + assert isinstance(param, dict) + data_guest = param["data_guest"] + data_host = param["data_host"] + + idx = param["idx"] + label_name = param["label_name"] + + if isinstance(config, str): + config = JobConfig.load_from_file(config) + data_base_dir = config["data_base_dir"] + else: + data_base_dir = config.data_base_dir + + config_param = { + "penalty": param["penalty"], + "max_iter": param["epochs"], + "alpha": param["alpha"], + "learning_rate": "optimal", + "eta0": param["eta0"], + "random_state": 105 + } + + # prepare data + df_guest = pandas.read_csv(os.path.join(data_base_dir, data_guest), index_col=idx) + df_host = pandas.read_csv(os.path.join(data_base_dir, data_host), index_col=idx) + + df = df_guest.join(df_host, rsuffix="host") + y = df[label_name] + X = df.drop(label_name, axis=1) + # lm = LogisticRegression(max_iter=20) + lm = SGDClassifier(loss="log", **config_param, shuffle=False) + lm_fit = lm.fit(X, y) + y_pred = lm_fit.predict(X) + + recall = recall_score(y, y_pred, average="macro") + pr = precision_score(y, y_pred, average="macro") + acc = accuracy_score(y, y_pred) + + result = {"multi_accuracy": acc, + "multi_precision": pr, + "multi_recall": recall} + print(result) + return {}, result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY SKLEARN JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./vehicle_lr_sklearn_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/vehicle_config.yaml b/examples/benchmark_quality/lr/vehicle_config.yaml new file mode 100644 index 0000000000..fdbfadf47d --- /dev/null +++ b/examples/benchmark_quality/lr/vehicle_config.yaml @@ -0,0 +1,24 @@ +data_guest: "vehicle_scale_hetero_guest" +data_host: "vehicle_scale_hetero_host" +idx: "id" +label_name: "y" +epochs: 10 +init_param: + fit_intercept: True + method: "random_uniform" + random_state: 42 +learning_rate_scheduler: + method: "linear" + scheduler_params: + start_factor: 0.7 + total_iters: 800 +optimizer: + method: "adam" + penalty: "L2" + alpha: 0.000001 + optimizer_params: + lr: 0.3 +batch_size: 18 +early_stop: "diff" +task_cores: null +timeout: 3600 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/vehicle_lr_sklearn_config.yaml b/examples/benchmark_quality/lr/vehicle_lr_sklearn_config.yaml new file mode 100644 index 0000000000..bf3001f665 --- /dev/null +++ b/examples/benchmark_quality/lr/vehicle_lr_sklearn_config.yaml @@ -0,0 +1,11 @@ +data_guest: "examples/data/vehicle_scale_hetero_guest.csv" +data_host: "examples/data/vehicle_scale_hetero_host.csv" +idx: "id" +label_name: "y" +epochs: 30 +fit_intercept: True +method: "rmsprop" +penalty: "L2" +eta0: 0.1 +alpha: 0.001 +batch_size: 5000 \ No newline at end of file diff --git a/examples/config.yaml b/examples/config.yaml index cd8658ab09..d6aba4704a 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -7,6 +7,4 @@ parties: # parties default id arbiter: - 10000 
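sklearn-lr-multi.py above reports macro-averaged accuracy, precision and recall on the joined vehicle_scale data. A standalone illustration of those three metrics on toy multi-class predictions:

# Standalone illustration of the macro-averaged metrics reported by sklearn-lr-multi.py;
# toy predictions only.
from sklearn.metrics import accuracy_score, precision_score, recall_score

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]
print({"multi_accuracy": accuracy_score(y_true, y_pred),
       "multi_precision": precision_score(y_true, y_pred, average="macro"),
       "multi_recall": recall_score(y_true, y_pred, average="macro")})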
-work_mode: 0 # 0 for standalone, or 1 for cluster - -data_base_dir: "/data/projects/fate" # pa th to project base where data is located \ No newline at end of file +data_base_dir: "~/FATE/FATE-2.0-pure/FATE/" \ No newline at end of file diff --git a/examples/pipeline/config.yaml b/examples/pipeline/config.yaml new file mode 100644 index 0000000000..394a5b7802 --- /dev/null +++ b/examples/pipeline/config.yaml @@ -0,0 +1,10 @@ +parties: # parties default id + guest: + - 9999 + host: + - 9998 + - 9999 + arbiter: + - 9998 + +data_base_dir: "" # path to project base where data is located \ No newline at end of file diff --git a/examples/pipeline/coordinated_linr/coordinated_linr_testsuite.yaml b/examples/pipeline/coordinated_linr/coordinated_linr_testsuite.yaml new file mode 100644 index 0000000000..7f190c4345 --- /dev/null +++ b/examples/pipeline/coordinated_linr/coordinated_linr_testsuite.yaml @@ -0,0 +1,60 @@ +data: + - file: examples/data/motor_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: float64 + label_name: motor_speed + match_id_name: idx + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/motor_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: idx + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/motor_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: idx + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_host + namespace: experiment + role: host_1 +tasks: + normal-linr: + script: test_linr.py + linr-cv: + script: test_linr_cv.py + linr-warm-start: + script: test_linr_warm_start.py + linr-multi-host: + script: test_linr_multi_host.py diff --git a/examples/pipeline/coordinated_linr/test_linr.py b/examples/pipeline/coordinated_linr/test_linr.py new file mode 100644 index 0000000000..5d388a5d10 --- /dev/null +++ b/examples/pipeline/coordinated_linr/test_linr.py @@ -0,0 +1,87 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
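coordinated_linr_testsuite.yaml above declares both the CSV files to upload (with their table names and namespaces) and the task scripts to run. A small sketch that only enumerates those declarations, assuming PyYAML is available; the actual upload and scheduling is done by fate_test:

# Enumerate the upload entries and task scripts declared by a testsuite yaml such as
# coordinated_linr_testsuite.yaml above; illustration only, fate_test does the real work.
import yaml

with open("coordinated_linr_testsuite.yaml") as f:
    suite = yaml.safe_load(f)

for entry in suite["data"]:
    print(f'{entry["role"]}: {entry["file"]} -> {entry["namespace"]}.{entry["table_name"]}')

for name, spec in suite["tasks"].items():
    print(f'{name}: {spec["script"]}')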
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = CoordinatedLinR("linr_0", + epochs=10, + batch_size=100, + optimizer={"method": "rmsprop", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True}, + train_data=psi_0.outputs["output_data"]) + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, linr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_linr/test_linr_cv.py b/examples/pipeline/coordinated_linr/test_linr_cv.py new file mode 100644 index 0000000000..2b1f4e87fb --- /dev/null +++ b/examples/pipeline/coordinated_linr/test_linr_cv.py @@ -0,0 +1,65 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
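After pipeline.fit() in test_linr.py above, per-task outputs can be pulled through get_task_info, the same getters that appear in the benchmark script and in the commented-out prints across these examples. A short continuation sketch (it assumes the pipeline object from main() above):

# Continuation sketch for test_linr.py; assumes `pipeline` has already been fit.
linr_info = pipeline.get_task_info("linr_0")
train_scores = linr_info.get_output_data()    # predictions on the training data
model = linr_info.get_output_model()          # fitted model parameters
metrics = pipeline.get_task_info("evaluation_0").get_output_metric()
print(type(train_scores), type(model), type(metrics))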
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = CoordinatedLinR("linr_0", + epochs=10, + batch_size=None, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True}, + cv_data=psi_0.outputs["output_data"], + cv_param={"n_splits": 3}) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_linr/test_linr_multi_host.py b/examples/pipeline/coordinated_linr/test_linr_multi_host.py new file mode 100644 index 0000000000..425adc5113 --- /dev/null +++ b/examples/pipeline/coordinated_linr/test_linr_multi_host.py @@ -0,0 +1,93 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
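test_linr_cv.py above swaps train_data for cv_data plus cv_param={"n_splits": 3}, so the folds are generated and evaluated inside the federated job and no separate Evaluation task is added. For intuition only, a local scikit-learn analogue of a 3-fold split:

# Local analogue of 3-fold cross-validation; the federated job performs the equivalent
# split on the intersected data, this snippet is just for intuition.
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(30).reshape(15, 2)
for i, (train_idx, val_idx) in enumerate(KFold(n_splits=3, shuffle=True, random_state=0).split(X)):
    print(f"fold {i}: {len(train_idx)} train rows, {len(val_idx)} validation rows")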
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"{namespace}experiment")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"{namespace}experiment")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"{namespace}experiment")) + linr_0 = CoordinatedLinR("linr_0", + epochs=5, + batch_size=None, + early_stop="weight_diff", + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.1}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, linr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"{namespace}experiment")) + deployed_pipeline.psi_0.hosts[[0, 1]].component_setting( + input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"{namespace}experiment")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict linr_0 data: {pipeline.get_task_info('linr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_linr/test_linr_warm_start.py b/examples/pipeline/coordinated_linr/test_linr_warm_start.py new file mode 100644 index 0000000000..198dcf6341 --- /dev/null +++ b/examples/pipeline/coordinated_linr/test_linr_warm_start.py @@ -0,0 +1,97 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = CoordinatedLinR("linr_0", + epochs=4, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + linr_1 = CoordinatedLinR("linr_1", train_data=psi_0.outputs["output_data"], + warm_start_model=linr_0.outputs["output_model"], + epochs=2, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + ) + + linr_2 = CoordinatedLinR("linr_2", epochs=6, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=[linr_1.outputs["train_output_data"], linr_2.outputs["train_output_data"]]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.add_task(linr_1) + pipeline.add_task(linr_2) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + # print(f"linr_1 model: {pipeline.get_task_info('linr_1').get_output_model()}") + # print(f"train linr_1 data: {pipeline.get_task_info('linr_1').get_output_data()}") + + # print(f"linr_2 model: {pipeline.get_task_info('linr_2').get_output_model()}") + # print(f"train linr_2 data: {pipeline.get_task_info('linr_2').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, 
namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/coordinated_lr_testsuite.yaml b/examples/pipeline/coordinated_lr/coordinated_lr_testsuite.yaml new file mode 100644 index 0000000000..224880a88b --- /dev/null +++ b/examples/pipeline/coordinated_lr/coordinated_lr_testsuite.yaml @@ -0,0 +1,98 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_1 + - file: examples/data/vehicle_scale_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: vehicle_scale_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/vehicle_scale_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: vehicle_scale_hetero_host + namespace: experiment + role: host_0 +tasks: + normal-lr: + script: test_lr.py + lr-cv: + script: test_lr_cv.py + lr-validate: + script: test_lr_validate.py + lr-warm-start: + script: test_lr_warm_start.py + lr-multi-class: + script: test_lr_multi_class.py + lr-multi-host: + script: test_lr_multi_host.py diff --git a/examples/pipeline/coordinated_lr/test_lr.py b/examples/pipeline/coordinated_lr/test_lr.py new file mode 100644 index 0000000000..5159422457 --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr.py @@ -0,0 +1,92 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
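coordinated_lr_testsuite.yaml above lists six task scripts; fate_test normally drives them, but each script also runs on its own since they all accept --config and --namespace. A hypothetical driver loop over the declared tasks, shown only to make the yaml/task relationship concrete:

# Hypothetical driver: run every task script declared in coordinated_lr_testsuite.yaml.
# fate_test is the supported way to run suites; this is illustration only.
import subprocess
import sys
import yaml

with open("coordinated_lr_testsuite.yaml") as f:
    suite = yaml.safe_load(f)

for name, spec in suite["tasks"].items():
    print(f"running {name}: {spec['script']}")
    subprocess.run([sys.executable, spec["script"], "--config", "../config.yaml"], check=True)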
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = CoordinatedLR("lr_0", + epochs=10, + batch_size=300, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "linear", "scheduler_params": {"start_factor": 0.7, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + pipeline.deploy([psi_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/test_lr_cv.py b/examples/pipeline/coordinated_lr/test_lr_cv.py new file mode 100644 index 0000000000..456241fd10 --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr_cv.py @@ -0,0 +1,65 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = CoordinatedLR("lr_0", + epochs=2, + batch_size=None, + optimizer={"method": "sgd", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True}, + cv_data=psi_0.outputs["output_data"], + cv_param={"n_splits": 3}) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/test_lr_multi_class.py b/examples/pipeline/coordinated_lr/test_lr_multi_class.py new file mode 100644 index 0000000000..c161feec50 --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr_multi_class.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = CoordinatedLR("lr_0", + epochs=10, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.21}, "penalty": "L1", + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "linear", "scheduler_params": {"start_factor": 0.7, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="multi", + predict_column_name='predict_result', + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/test_lr_multi_host.py b/examples/pipeline/coordinated_lr/test_lr_multi_host.py new file mode 100644 index 0000000000..4eabff90d9 --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr_multi_host.py @@ -0,0 +1,93 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"{namespace}experiment")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"{namespace}experiment")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"{namespace}experiment")) + lr_0 = CoordinatedLR("lr_0", + epochs=5, + batch_size=None, + early_stop="weight_diff", + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.1}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"{namespace}experiment")) + deployed_pipeline.psi_0.hosts[[0, 1]].component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"{namespace}experiment")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/test_lr_validate.py b/examples/pipeline/coordinated_lr/test_lr_validate.py new file mode 100644 index 0000000000..75a5bc7d4c --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr_validate.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, DataSplit +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + data_split_0 = DataSplit("data_split_0", + train_size=0.8, + validate_size=0.2, + input_data=psi_0.outputs["output_data"]) + lr_0 = CoordinatedLR("lr_0", + epochs=10, + batch_size=300, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.21}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=data_split_0.outputs["train_output_data"], + validate_data=data_split_0.outputs["validate_output_data"], + learning_rate_scheduler={"method": "linear", "scheduler_params": {"start_factor": 0.7, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(data_split_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/coordinated_lr/test_lr_warm_start.py b/examples/pipeline/coordinated_lr/test_lr_warm_start.py new file mode 100644 index 0000000000..674df045db --- /dev/null +++ b/examples/pipeline/coordinated_lr/test_lr_warm_start.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
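test_lr_validate.py above carves the intersected data into an 80/20 train/validate split with DataSplit and feeds both parts to CoordinatedLR, so validation metrics are available during training. A local, non-federated analogue of that split (DataSplit performs it consistently across guest and host, train_test_split obviously does not):

# Local analogue of the 80/20 train/validate split performed by data_split_0 above.
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(40).reshape(20, 2)
y = np.array([0, 1] * 10)
X_train, X_val, y_train, y_val = train_test_split(X, y, train_size=0.8, random_state=42)
print(X_train.shape, X_val.shape)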
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = CoordinatedLR("lr_0", + epochs=4, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}, + "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + lr_1 = CoordinatedLR("lr_1", train_data=psi_0.outputs["output_data"], + warm_start_model=lr_0.outputs["output_model"], + epochs=2, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}}, + ) + + lr_2 = CoordinatedLR("lr_2", epochs=6, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.01}}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate_scheduler={"method": "constant", "scheduler_params": {"factor": 1.0, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=[lr_1.outputs["train_output_data"], lr_2.outputs["train_output_data"]]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(lr_2) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + # print(f"lr_1 model: {pipeline.get_task_info('lr_1').get_output_model()}") + # print(f"train lr_1 data: {pipeline.get_task_info('lr_1').get_output_data()}") + + # print(f"lr_2 model: {pipeline.get_task_info('lr_2').get_output_model()}") + # print(f"train lr_2 data: {pipeline.get_task_info('lr_2').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git 
a/examples/pipeline/data_split/data_split_lr_testsuite.yaml b/examples/pipeline/data_split/data_split_lr_testsuite.yaml new file mode 100644 index 0000000000..b88247f5db --- /dev/null +++ b/examples/pipeline/data_split/data_split_lr_testsuite.yaml @@ -0,0 +1,58 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_1 +tasks: + data-split: + script: test_data_split.py + data-split-stratified: + script: test_data_split_stratified.py + data-split-multi-host: + script: test_data_split_multi_host.py diff --git a/examples/pipeline/data_split/test_data_split.py b/examples/pipeline/data_split/test_data_split.py new file mode 100644 index 0000000000..c2913953a5 --- /dev/null +++ b/examples/pipeline/data_split/test_data_split.py @@ -0,0 +1,91 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import DataSplit, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + data_split_0 = DataSplit("data_split_0", + train_size=0.6, + validate_size=0.1, + test_size=None, + stratified=True, + input_data=psi_0.outputs["output_data"]) + + data_split_1 = DataSplit("data_split_1", + train_size=200, + test_size=50, + input_data=psi_0.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(data_split_0) + pipeline.add_task(data_split_1) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("data_split_0").get_output_data()) + """output_data = pipeline.get_task_info("data_split_0").get_output_data() + import pandas as pd + + print(f"data split 0 train size: {pd.DataFrame(output_data['train_output_data']).shape};" + f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}" + f"test size: {pd.DataFrame(output_data['test_output_data']).shape}") + output_data = pipeline.get_task_info("data_split_1").get_output_data() + print(f"data split 1train size: {pd.DataFrame(output_data['train_output_data']).shape};" + f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}" + f"test size: {pd.DataFrame(output_data['test_output_data']).shape}")""" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/data_split/test_data_split_multi_host.py b/examples/pipeline/data_split/test_data_split_multi_host.py new file mode 100644 index 0000000000..3a565cbfc6 --- /dev/null +++ b/examples/pipeline/data_split/test_data_split_multi_host.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
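test_data_split.py above exercises two configurations: fractional sizes (0.6/0.1 with test_size=None) and absolute row counts (200/50). A back-of-the-envelope size check, assuming the breast_hetero intersection yields 569 rows and that an unset size receives the remainder (both are assumptions, not taken from the script):

# Rough expected sizes for the two DataSplit configurations above.
# Assumes 569 intersected rows; where unassigned rows go depends on DataSplit defaults.
n = 569
train_0, validate_0 = int(n * 0.6), int(n * 0.1)
test_0 = n - train_0 - validate_0            # data_split_0: fractions, remainder assumed to be test
train_1, test_1 = 200, 50                    # data_split_1: absolute counts
print((train_0, validate_0, test_0), (train_1, test_1, n - train_1 - test_1))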
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import DataSplit, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_1.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + data_split_0 = DataSplit("data_split_0", + train_size=0.6, + validate_size=0.1, + test_size=None, + hetero_sync=False, + input_data=psi_0.outputs["output_data"]) + + data_split_1 = DataSplit("data_split_1", + train_size=200, + test_size=50, + input_data=psi_0.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(data_split_0) + pipeline.add_task(data_split_1) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/data_split/test_data_split_stratified.py b/examples/pipeline/data_split/test_data_split_stratified.py new file mode 100644 index 0000000000..cb0075e4ab --- /dev/null +++ b/examples/pipeline/data_split/test_data_split_stratified.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import DataSplit, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + data_split_0 = DataSplit("data_split_0", + train_size=0.6, + validate_size=0.0, + test_size=0.4, + stratified=True, + input_data=psi_0.outputs["output_data"]) + + data_split_1 = DataSplit("data_split_1", + train_size=200, + test_size=50, + stratified=True, + hetero_sync=True, + input_data=psi_1.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(data_split_0) + pipeline.add_task(data_split_1) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/feature_scale/scale_testsuite.yaml b/examples/pipeline/feature_scale/scale_testsuite.yaml new file mode 100644 index 0000000000..dfb9771821 --- /dev/null +++ b/examples/pipeline/feature_scale/scale_testsuite.yaml @@ -0,0 +1,42 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 
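test_data_split_stratified.py above requests stratified=True so each split keeps the label proportions, and data_split_1 additionally sets hetero_sync=True, presumably to keep the guest and host splits aligned on the same sample IDs. A local illustration of what stratification preserves (not the FATE implementation):

# Local illustration of stratified splitting: label ratios are preserved per split.
from collections import Counter
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 80 + [1] * 20)
X = np.arange(200).reshape(100, 2)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, train_size=0.6, stratify=y, random_state=0)
print(Counter(y_tr), Counter(y_te))  # both splits keep roughly the 80/20 label ratio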
+tasks: + scale-min-max: + script: test_scale_min_max.py + scale-standard: + script: test_scale_standard.py + scale-with-lr: + script: test_scale_w_lr.py \ No newline at end of file diff --git a/examples/pipeline/feature_scale/test_scale_min_max.py b/examples/pipeline/feature_scale/test_scale_min_max.py new file mode 100644 index 0000000000..71f12abab5 --- /dev/null +++ b/examples/pipeline/feature_scale/test_scale_min_max.py @@ -0,0 +1,99 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, FeatureScale, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + feature_scale_0 = FeatureScale("feature_scale_0", + method="min_max", + feature_range={"x0": [-1, 1]}, + scale_col=["x0", "x1", "x3"], + train_data=psi_0.outputs["output_data"]) + + feature_scale_1 = FeatureScale("feature_scale_1", + test_data=psi_1.outputs["output_data"], + input_model=feature_scale_0.outputs["output_model"]) + + statistics_0 = Statistics("statistics_0", + metrics=["max", "min", "mean", "std"], + input_data=feature_scale_1.outputs["test_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(feature_scale_0) + pipeline.add_task(feature_scale_1) + pipeline.add_task(statistics_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + print(pipeline.get_task_info("statistics_0").get_output_model()) + + pipeline.deploy([psi_0, feature_scale_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + 
deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/feature_scale/test_scale_standard.py b/examples/pipeline/feature_scale/test_scale_standard.py new file mode 100644 index 0000000000..008e7c2a75 --- /dev/null +++ b/examples/pipeline/feature_scale/test_scale_standard.py @@ -0,0 +1,94 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, FeatureScale +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + feature_scale_0 = FeatureScale("feature_scale_0", + method="standard", + train_data=psi_0.outputs["output_data"]) + + feature_scale_1 = FeatureScale("feature_scale_1", + test_data=psi_1.outputs["output_data"], + input_model=feature_scale_0.outputs["output_model"]) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(feature_scale_0) + pipeline.add_task(feature_scale_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + print(pipeline.get_task_info("feature_scale_0").get_output_model()) + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, 
feature_scale_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/feature_scale/test_scale_w_lr.py b/examples/pipeline/feature_scale/test_scale_w_lr.py new file mode 100644 index 0000000000..fa12cb8935 --- /dev/null +++ b/examples/pipeline/feature_scale/test_scale_w_lr.py @@ -0,0 +1,102 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
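+# Example: standard feature scaling as preprocessing for CoordinatedLR training,
+# followed by binary evaluation on the guest; the fitted PSI, scaler and LR tasks
+# are then deployed and reused in a predict pipeline.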
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, FeatureScale, Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + feature_scale_0 = FeatureScale("feature_scale_0", + method="standard", + train_data=psi_0.outputs["output_data"]) + + lr_0 = CoordinatedLR("lr_0", + epochs=10, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.21}}, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=feature_scale_0.outputs["train_output_data"], + learning_rate_scheduler={"method": "linear", "scheduler_params": {"start_factor": 0.7, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(feature_scale_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, feature_scale_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_binning/binning_testsuite.yaml b/examples/pipeline/hetero_feature_binning/binning_testsuite.yaml new file mode 100644 index 0000000000..13d472dea6 --- /dev/null +++ b/examples/pipeline/hetero_feature_binning/binning_testsuite.yaml @@ -0,0 +1,60 @@ +data: + - file: 
examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_1 +tasks: + binning-bucket: + script: test_feature_binning_bucket.py + binning-quantile: + script: test_feature_binning_quantile.py + binning-asymmetric: + script: test_feature_binning_asymmetric.py + binning-multi-host: + script: test_feature_binning_multi_host.py \ No newline at end of file diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py new file mode 100644 index 0000000000..43f70ae59d --- /dev/null +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py @@ -0,0 +1,88 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
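+# Example: asymmetric binning, where guest and host use different settings: the
+# guest bins only x0 (local_only quantile binning, bin_idx transform), then re-bins
+# it as a categorical, WOE-transformed column in a second binning task.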
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + train_data=psi_0.outputs["output_data"], + local_only=True + ) + binning_0.guest.component_setting(bin_col=["x0"], transform_method="bin_idx") + + binning_1 = HeteroFeatureBinning("binning_1", + transform_method="bin_idx", + method="quantile", + train_data=binning_0.outputs["train_output_data"]) + binning_1.guest.component_setting(category_col=["x0"], transform_method="woe") + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(binning_1) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, binning_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py new file mode 100644 index 0000000000..1ff22c34e9 --- /dev/null +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py @@ -0,0 +1,97 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
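+# Example: bucket binning fitted on one PSI intersection (psi_0) and applied,
+# through the fitted binning model, to a second intersection (psi_1) as test data.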
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="bucket", + n_bins=10, + transform_method="bin_idx", + skip_metrics=True, + train_data=psi_0.outputs["output_data"] + ) + binning_1 = HeteroFeatureBinning("binning_1", + transform_method="bin_idx", + input_model=binning_0.outputs["output_model"], + test_data=psi_1.outputs["output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(binning_0) + pipeline.add_task(binning_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("binning_0").get_output_model()) + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, binning_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py new file mode 100644 index 0000000000..197dc9dcae --- /dev/null +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py @@ -0,0 +1,100 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_1.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + binning_0 = HeteroFeatureBinning("binning_0", + method="bucket", + n_bins=10, + skip_metrics=True, + transform_method="bin_idx", + train_data=psi_0.outputs["output_data"] + ) + binning_1 = HeteroFeatureBinning("binning_1", + transform_method="bin_idx", + input_model=binning_0.outputs["output_model"], + test_data=psi_1.outputs["output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(binning_0) + pipeline.add_task(binning_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("binning_0").get_output_model()) + + pipeline.deploy([psi_0, binning_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + 
parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py new file mode 100644 index 0000000000..b98031b4b2 --- /dev/null +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py @@ -0,0 +1,89 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + transform_method="bin_idx", + train_data=psi_0.outputs["output_data"] + ) + binning_0.hosts[0].component_setting(bin_idx=[1]) + binning_0.guest.component_setting(bin_col=["x0"]) + binning_1 = HeteroFeatureBinning("binning_1", + transform_method="bin_idx", + method="quantile", + train_data=binning_0.outputs["train_output_data"]) + binning_1.hosts[0].component_setting(category_idx=[1]) + binning_1.guest.component_setting(category_col=["x0"]) + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(binning_1) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, binning_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + 
help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_selection/selection_testsuite.yaml b/examples/pipeline/hetero_feature_selection/selection_testsuite.yaml new file mode 100644 index 0000000000..75a3031b66 --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/selection_testsuite.yaml @@ -0,0 +1,62 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_1 +tasks: + selection-binning: + script: test_feature_selection_binning.py + selection-manual: + script: test_feature_selection_manual.py + selection-statistics: + script: test_feature_selection_statistics.py + selection-multi-model: + script: test_feature_selection_multi_model.py + selection-multi-host: + script: test_feature_selection_multi_host.py \ No newline at end of file diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py new file mode 100644 index 0000000000..d639fc63eb --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py @@ -0,0 +1,87 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + transform_method="bin_idx", + train_data=psi_0.outputs["output_data"] + ) + selection_0 = HeteroFeatureSelection("selection_0", + method=["iv"], + train_data=psi_0.outputs["output_data"], + input_models=[binning_0.outputs["output_model"]], + iv_param={"metrics": "iv", "filter_type": "threshold", "threshold": 0.1}) + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(selection_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py new file mode 100644 index 0000000000..a278387dca --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py @@ -0,0 +1,80 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
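+# Example: manual feature selection; the guest explicitly keeps x0 and x1 while the
+# host filters out its own x0 and x1.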
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + selection_0 = HeteroFeatureSelection("selection_0", + method=["manual"], + train_data=psi_0.outputs["output_data"]) + selection_0.guest.component_setting(manual_param={"keep_col": ["x0", "x1"]}) + selection_0.hosts[0].component_setting(manual_param={"filter_out_col": ["x0", "x1"]}) + + pipeline.add_task(psi_0) + pipeline.add_task(selection_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py new file mode 100644 index 0000000000..eea2c62578 --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
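+# Example: feature selection with two hosts, combining IV (top percentile),
+# statistics (top-k on max and mean) and manual filters based on the binning and
+# statistics models computed from the PSI output.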
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + transform_method="bin_idx", + train_data=psi_0.outputs["output_data"] + ) + statistics_0 = Statistics("statistics_0", input_data=psi_0.outputs["output_data"]) + selection_0 = HeteroFeatureSelection("selection_0", + method=["iv", "statistics", "manual"], + train_data=psi_0.outputs["output_data"], + input_models=[binning_0.outputs["output_model"], + statistics_0.outputs["output_model"]], + iv_param={"metrics": "iv", "filter_type": "top_percentile", "threshold": 0.8}, + statistic_param={"metrics": ["max", "mean"], + "filter_type": "top_k", "threshold": 5}, + manual_param={"keep_col": ["x0", "x1"]} + ) + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(statistics_0) + pipeline.add_task(selection_0) + + pipeline.compile() + pipeline.fit() + + # print(pipeline.get_task_info("selection_0").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py new file mode 100644 index 0000000000..4a3dcb4050 --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py @@ -0,0 +1,94 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + transform_method="bin_idx", + train_data=psi_0.outputs["output_data"] + ) + statistics_0 = Statistics("statistics_0", input_data=psi_0.outputs["output_data"]) + selection_0 = HeteroFeatureSelection("selection_0", + method=["iv", "statistics", "manual"], + train_data=psi_0.outputs["output_data"], + input_models=[binning_0.outputs["output_model"], + statistics_0.outputs["output_model"]], + iv_param={"metrics": "iv", "filter_type": "top_k", "threshold": 6}, + statistic_param={"metrics": ["max", "mean"], + "filter_type": "top_k", "threshold": 5, "take_high": False}, + manual_param={"keep_col": ["x0", "x1"]} + ) + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(statistics_0) + pipeline.add_task(selection_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("selection_0").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git 
a/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py new file mode 100644 index 0000000000..a21d4470cc --- /dev/null +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py @@ -0,0 +1,84 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + statistics_0 = Statistics("statistics_0", input_data=psi_0.outputs["output_data"], + metrics=["min", "max", "25%", "mean", "median"]) + selection_0 = HeteroFeatureSelection("selection_0", + method=["statistics"], + train_data=psi_0.outputs["output_data"], + input_models=[statistics_0.outputs["output_model"]], + statistic_param={"metrics": ["max", "mean", "25%"], + "filter_type": "top_k", "threshold": 5}) + + pipeline.add_task(psi_0) + pipeline.add_task(statistics_0) + pipeline.add_task(selection_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git 
a/examples/pipeline/hetero_secureboost/hetero_secureboost_testsuite.yaml b/examples/pipeline/hetero_secureboost/hetero_secureboost_testsuite.yaml
new file mode 100644
index 0000000000..6cb3ad88b6
--- /dev/null
+++ b/examples/pipeline/hetero_secureboost/hetero_secureboost_testsuite.yaml
@@ -0,0 +1,74 @@
+data:
+  - file: examples/data/breast_hetero_guest.csv
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      label_type: int64
+      label_name: y
+      match_id_name: id
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    partitions: 4
+    head: true
+    extend_sid: true
+    table_name: breast_hetero_guest
+    namespace: experiment
+    role: guest_0
+  - file: examples/data/breast_hetero_host.csv
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      match_id_name: id
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    partitions: 4
+    head: true
+    extend_sid: true
+    table_name: breast_hetero_host
+    namespace: experiment
+    role: host_0
+  - file: "examples/data/student_hetero_guest.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      match_id_name: "id"
+      match_id_range: 0
+      label_type: float64
+      label_name: y
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: student_hetero_guest
+    namespace: experiment
+    role: guest_0
+  - file: "examples/data/student_hetero_host.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      match_id_name: "id"
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: student_hetero_host
+    namespace: experiment
+    role: host_0
+tasks:
+  hetero-sbt-binary:
+    script: test_hetero_sbt_binary.py
+  hetero-sbt-regression:
+    script: test_hetero_sbt_regression.py
diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py
new file mode 100644
index 0000000000..193e5ce3fc
--- /dev/null
+++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py
@@ -0,0 +1,47 @@
+import argparse
+from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation
+from fate_client.pipeline import FateFlowPipeline
+from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.utils import test_utils
+
+
+def main(config="../config.yaml", namespace=""):
+    if isinstance(config, str):
+        config = test_utils.load_job_config(config)
+    parties = config.parties
+    guest = parties.guest[0]
+    host = parties.host[0]
+    arbiter = parties.arbiter[0]
+
+    pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter)
+
+    psi_0 = PSI("psi_0")
+    psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
+                                                                  namespace="experiment"))
+    psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
+                                                                     namespace="experiment"))
+
+    hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3,
+                                     he_param={'kind': 'paillier', 'key_length': 1024},
+                                     train_data=psi_0.outputs['output_data'])
+    evaluation_0 = Evaluation(
+        'eval_0',
+        runtime_roles=['guest'],
+        metrics=['auc'],
+        input_data=[hetero_sbt_0.outputs['train_data_output']]
+    )
+
+    pipeline.add_task(psi_0)
+    pipeline.add_task(hetero_sbt_0)
+    pipeline.add_task(evaluation_0)
+    pipeline.compile()
+    pipeline.fit()
+
+
+if __name__ ==
"__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) \ No newline at end of file diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py new file mode 100644 index 0000000000..5dca31377a --- /dev/null +++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py @@ -0,0 +1,47 @@ +import argparse +from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="student_hetero_guest", + namespace="experiment")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="student_hetero_host", + namespace="experiment")) + + hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3, objective='regression:l2', + he_param={'kind': 'paillier', 'key_length': 1024}, train_data=psi_0.outputs['output_data'],) + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['rmse'], + input_data=[hetero_sbt_0.outputs['train_data_output']] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(hetero_sbt_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) \ No newline at end of file diff --git a/examples/pipeline/homo_lr/homo_lr_testsuite.yaml b/examples/pipeline/homo_lr/homo_lr_testsuite.yaml new file mode 100644 index 0000000000..38d66baaeb --- /dev/null +++ b/examples/pipeline/homo_lr/homo_lr_testsuite.yaml @@ -0,0 +1,40 @@ +data: + - file: "examples/data/breast_homo_guest.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_homo_guest + namespace: experiment + role: guest_0 + - file: "examples/data/breast_homo_host.csv" + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: "id" + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_homo_host + namespace: experiment + role: host_0 +tasks: + homo-lr-binary: + script: test_homo_lr_binary.py \ No newline at end of file diff --git 
a/examples/pipeline/homo_lr/test_homo_lr_binary.py b/examples/pipeline/homo_lr/test_homo_lr_binary.py
new file mode 100644
index 0000000000..6fa768335d
--- /dev/null
+++ b/examples/pipeline/homo_lr/test_homo_lr_binary.py
@@ -0,0 +1,46 @@
+import argparse
+from fate_client.pipeline.components.fate import HomoLR, Evaluation
+from fate_client.pipeline import FateFlowPipeline
+from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.utils import test_utils
+
+
+def main(config="../config.yaml", namespace=""):
+    if isinstance(config, str):
+        config = test_utils.load_job_config(config)
+    parties = config.parties
+    guest = parties.guest[0]
+    host = parties.host[0]
+    arbiter = parties.arbiter[0]
+    pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter)
+
+    homo_lr_0 = HomoLR(
+        "homo_lr_0",
+        epochs=10,
+        batch_size=16
+    )
+
+    homo_lr_0.guest.component_setting(train_data=DataWarehouseChannel(name="breast_homo_guest", namespace="experiment"))
+    homo_lr_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name="breast_homo_host", namespace="experiment"))
+    evaluation_0 = Evaluation(
+        'eval_0',
+        metrics=['auc'],
+        input_data=[homo_lr_0.outputs['train_output_data']]
+    )
+
+    pipeline.add_task(homo_lr_0)
+    pipeline.add_task(evaluation_0)
+    pipeline.compile()
+    pipeline.fit()
+    print(pipeline.get_task_info("homo_lr_0").get_output_data())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("PIPELINE DEMO")
+    parser.add_argument("--config", type=str, default="../config.yaml",
+                        help="config file")
+    parser.add_argument("--namespace", type=str, default="",
+                        help="namespace for data stored in FATE")
+    args = parser.parse_args()
+    main(config=args.config, namespace=args.namespace)
\ No newline at end of file
diff --git a/examples/pipeline/homo_nn/homo_nn_testsuite.yaml b/examples/pipeline/homo_nn/homo_nn_testsuite.yaml
new file mode 100644
index 0000000000..2f61429d96
--- /dev/null
+++ b/examples/pipeline/homo_nn/homo_nn_testsuite.yaml
@@ -0,0 +1,116 @@
+data:
+  - file: "examples/data/breast_homo_guest.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      label_type: int64
+      label_name: y
+      match_id_name: "id"
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    partitions: 4
+    head: true
+    extend_sid: true
+    table_name: breast_homo_guest
+    namespace: experiment
+    role: guest_0
+  - file: "examples/data/breast_homo_host.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      label_type: int64
+      label_name: y
+      match_id_name: "id"
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    partitions: 4
+    head: true
+    extend_sid: true
+    table_name: breast_homo_host
+    namespace: experiment
+    role: host_0
+  - file: "examples/data/vehicle_scale_homo_guest.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      match_id_name: "id"
+      match_id_range: 0
+      label_type: int64
+      label_name: y
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: vehicle_scale_homo_guest
+    namespace: experiment
+    role: guest_0
+  - file: "examples/data/vehicle_scale_homo_host.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      label_type: int64
+      label_name: y
+      match_id_name: "id"
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: vehicle_scale_homo_host
+    namespace: experiment
+    role: host_0
+  - file: "examples/data/student_homo_guest.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      match_id_name: "id"
+      match_id_range: 0
+      label_type: float64
+      label_name: y
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: student_homo_guest
+    namespace: experiment
+    role: guest_0
+  - file: "examples/data/student_homo_host.csv"
+    meta:
+      delimiter: ","
+      dtype: float64
+      input_format: dense
+      label_type: float64
+      label_name: y
+      match_id_name: "id"
+      match_id_range: 0
+      tag_value_delimiter: ":"
+      tag_with_value: false
+      weight_type: float64
+    head: true
+    partitions: 4
+    extend_sid: true
+    table_name: student_homo_host
+    namespace: experiment
+    role: host_0
+tasks:
+  homo-nn-binary:
+    script: test_nn_binary.py
+  homo-nn-regression:
+    script: test_nn_regression.py
+  homo-nn-multi:
+    script: test_nn_multi.py
\ No newline at end of file
diff --git a/examples/pipeline/homo_nn/test_nn_binary.py b/examples/pipeline/homo_nn/test_nn_binary.py
new file mode 100644
index 0000000000..8fc2b8b4f9
--- /dev/null
+++ b/examples/pipeline/homo_nn/test_nn_binary.py
@@ -0,0 +1,98 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + +def main(config="../../config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + epochs = 10 + batch_size = 64 + in_feat = 30 + out_feat = 16 + lr = 0.01 + + guest_train_data = {"name": "breast_homo_guest", "namespace": f"experiment{namespace}"} + host_train_data = {"name": "breast_homo_host", "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1), + nn.Sigmoid() + ), + loss=nn.BCELoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='binary' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['auc'], + input_data=[homo_nn_0.outputs['train_data_output']] + ) + + + pipeline.add_task(homo_nn_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + pipeline.fit() + + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) \ No newline at end of file diff --git a/examples/pipeline/homo_nn/test_nn_multi.py b/examples/pipeline/homo_nn/test_nn_multi.py new file mode 100644 index 0000000000..edd0d3a328 --- /dev/null +++ b/examples/pipeline/homo_nn/test_nn_multi.py @@ -0,0 +1,94 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.nn.loader import ModelLoader +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + +def main(config="../../config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + epochs = 10 + batch_size = 64 + in_feat = 18 + lr = 0.01 + class_num=4 + + guest_train_data = {"name": "vehicle_scale_homo_guest", "namespace": f"experiment{namespace}"} + host_train_data = {"name": "vehicle_scale_homo_host", "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=ModelLoader('multi_model', 'Multi', feat=in_feat, class_num=class_num), + loss=nn.CrossEntropyLoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='multi' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + default_eval_setting='multi', + input_data=[homo_nn_0.outputs['train_data_output']] + ) + + + pipeline.add_task(homo_nn_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + pipeline.fit() + + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) \ No newline at end of file diff --git a/examples/pipeline/homo_nn/test_nn_regression.py b/examples/pipeline/homo_nn/test_nn_regression.py new file mode 100644 index 0000000000..9f214c9ca2 --- 
/dev/null +++ b/examples/pipeline/homo_nn/test_nn_regression.py @@ -0,0 +1,97 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +from fate_test.utils import parse_summary_result +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils +from fate_client.pipeline.components.fate.evaluation import Evaluation +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate.nn.torch import nn, optim +from fate_client.pipeline.components.fate.nn.torch.base import Sequential +from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner +from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments + + +def main(config="../../config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + epochs = 10 + batch_size = 64 + in_feat = 13 + out_feat = 10 + lr = 0.01 + + guest_train_data = {"name": "student_homo_guest", "namespace": f"experiment{namespace}"} + host_train_data = {"name": "student_homo_host", "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + + conf = get_config_of_default_runner( + algo='fedavg', + model=Sequential( + nn.Linear(in_feat, out_feat), + nn.ReLU(), + nn.Linear(out_feat ,1) + ), + loss=nn.MSELoss(), + optimizer=optim.Adam(lr=lr), + training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514), + fed_args=FedAVGArguments(), + task_type='regression' + ) + + + homo_nn_0 = HomoNN( + 'nn_0', + runner_conf=conf + ) + + homo_nn_0.guest.component_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"])) + homo_nn_0.hosts[0].component_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"])) + + evaluation_0 = Evaluation( + 'eval_0', + runtime_roles=['guest'], + metrics=['rmse'], + input_data=[homo_nn_0.outputs['train_data_output']] + ) + + + pipeline.add_task(homo_nn_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + pipeline.fit() + + result_summary = parse_summary_result(pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) \ 
No newline at end of file diff --git a/examples/pipeline/multi_model/test_multi.py b/examples/pipeline/multi_model/test_multi.py new file mode 100644 index 0000000000..ef7ca37bd4 --- /dev/null +++ b/examples/pipeline/multi_model/test_multi.py @@ -0,0 +1,130 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, \ + FeatureScale, Union, DataSplit, CoordinatedLR, CoordinatedLinR, Statistics, Sample, Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + data_split_0 = DataSplit("data_split_0", input_data=psi_0.outputs["output_data"], + train_size=0.8, test_size=0.2, random_state=42) + union_0 = Union("union_0", input_data_list=[data_split_0.outputs["train_output_data"], + data_split_0.outputs["test_output_data"]]) + sample_0 = Sample("sample_0", input_data=data_split_0.outputs["train_output_data"], + n=800, replace=True, hetero_sync=True) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + train_data=sample_0.outputs["output_data"] + ) + statistics_0 = Statistics("statistics_0", + input_data=psi_0.outputs["output_data"]) + selection_0 = HeteroFeatureSelection("selection_0", + method=["iv", "statistics"], + train_data=sample_0.outputs["output_data"], + input_models=[binning_0.outputs["output_model"], + statistics_0.outputs["output_model"]], + iv_param={"metrics": "iv", "filter_type": "threshold", "value": 0.1}, + statistic_param={"metrics": ["max", "min"], "filter_type": "top_k", + "threshold": 5}) + + selection_1 = HeteroFeatureSelection("selection_1", + input_model=selection_0.outputs["train_output_model"], + test_data=data_split_0.outputs["test_output_data"]) + + scale_0 = FeatureScale("scale_0", method="min_max", + train_data=selection_0.outputs["train_output_data"], ) + + lr_0 = CoordinatedLR("lr_0", train_data=selection_0.outputs["train_output_data"], + validate_data=selection_1.outputs["test_output_data"], epochs=3) + linr_0 = CoordinatedLinR("linr_0", train_data=selection_0.outputs["train_output_data"], + 
validate_data=selection_1.outputs["test_output_data"], epochs=3) + + evaluation_0 = Evaluation("evaluation_0", input_data=lr_0.outputs["train_output_data"], + default_eval_setting="binary", + runtime_roles=["guest"]) + evaluation_1 = Evaluation("evaluation_1", input_data=linr_0.outputs["train_output_data"], + default_eval_setting="regression", + runtime_roles=["guest"]) + pipeline.add_task(psi_0) + pipeline.add_task(data_split_0) + pipeline.add_task(union_0) + pipeline.add_task(sample_0) + pipeline.add_task(binning_0) + pipeline.add_task(statistics_0) + pipeline.add_task(selection_0) + pipeline.add_task(scale_0) + pipeline.add_task(selection_1) + pipeline.add_task(lr_0) + pipeline.add_task(linr_0) + pipeline.add_task(evaluation_0) + pipeline.add_task(evaluation_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("feature_scale_1").get_output_model()) + + pipeline.deploy([psi_0, selection_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/multi_model/test_multi_preprocessing.py b/examples/pipeline/multi_model/test_multi_preprocessing.py new file mode 100644 index 0000000000..c7a9e77711 --- /dev/null +++ b/examples/pipeline/multi_model/test_multi_preprocessing.py @@ -0,0 +1,113 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import DataSplit, PSI, Sample, FeatureScale +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + data_split_0 = DataSplit("data_split_0", + train_size=0.6, + validate_size=0.0, + test_size=0.4, + stratified=True, + input_data=psi_0.outputs["output_data"]) + + data_split_1 = DataSplit("data_split_1", + train_size=200, + test_size=50, + stratified=True, + input_data=psi_0.outputs["output_data"] + ) + + sample_0 = Sample("sample_0", + frac={0: 0.5}, + replace=False, + hetero_sync=True, + input_data=psi_0.outputs["output_data"]) + + sample_1 = Sample("sample_1", + n=100, + replace=False, + hetero_sync=True, + input_data=psi_0.outputs["output_data"] + ) + feature_scale_0 = FeatureScale("feature_scale_0", + method="min_max", + feature_range={"x0": [-1, 1]}, + scale_col=["x0", "x1", "x3"], + train_data=psi_0.outputs["output_data"]) + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(data_split_0) + pipeline.add_task(data_split_1) + pipeline.add_task(sample_0) + pipeline.add_task(sample_1) + pipeline.add_task(feature_scale_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + # print(pipeline.get_task_info("data_split_0").get_output_data()) + """output_data = pipeline.get_task_info("data_split_0").get_output_data() + import pandas as pd + + print(f"data split 0 train size: {pd.DataFrame(output_data['train_output_data']).shape};" + f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}" + f"test size: {pd.DataFrame(output_data['test_output_data']).shape}") + output_data = pipeline.get_task_info("data_split_1").get_output_data() + print(f"data split 1train size: {pd.DataFrame(output_data['train_output_data']).shape};" + f"validate size: {pd.DataFrame(output_data['validate_output_data']).shape}" + f"test size: {pd.DataFrame(output_data['test_output_data']).shape}")""" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/multi_model/test_multi_w_predict.py 
b/examples/pipeline/multi_model/test_multi_w_predict.py new file mode 100644 index 0000000000..21f5e4460c --- /dev/null +++ b/examples/pipeline/multi_model/test_multi_w_predict.py @@ -0,0 +1,103 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, CoordinatedLR, Evaluation, \ + HeteroFeatureBinning, HeteroFeatureSelection +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + arbiter = parties.arbiter[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host, arbiter=arbiter) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + binning_0 = HeteroFeatureBinning("binning_0", + method="quantile", + n_bins=10, + train_data=psi_0.outputs["output_data"] + ) + selection_0 = HeteroFeatureSelection("selection_0", + method=["iv"], + train_data=psi_0.outputs["output_data"], + input_models=[binning_0.outputs["output_model"]], + iv_param={"metrics": "iv", "filter_type": "threshold", "threshold": 0.1}) + + lr_0 = CoordinatedLR("lr_0", + epochs=10, + batch_size=None, + optimizer={"method": "SGD", "optimizer_params": {"lr": 0.1}, "penalty": "l2", "alpha": 0.001}, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=selection_0.outputs["train_output_data"], + learning_rate_scheduler={"method": "linear", "scheduler_params": {"start_factor": 0.7, + "total_iters": 100}}) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(binning_0) + pipeline.add_task(selection_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, selection_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + 
predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sample/sample_testsuite.yaml b/examples/pipeline/sample/sample_testsuite.yaml new file mode 100644 index 0000000000..3c1da90fbc --- /dev/null +++ b/examples/pipeline/sample/sample_testsuite.yaml @@ -0,0 +1,58 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_1 +tasks: + sample: + script: test_sample.py + sample-unilateral: + script: test_sample_unilateral.py + sample-multi-host: + script: test_sample_multi_host.py diff --git a/examples/pipeline/sample/test_sample.py b/examples/pipeline/sample/test_sample.py new file mode 100644 index 0000000000..0cbea77bbe --- /dev/null +++ b/examples/pipeline/sample/test_sample.py @@ -0,0 +1,79 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Sample, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + sample_0 = Sample("sample_0", + frac={0: 0.5}, + replace=False, + hetero_sync=True, + input_data=psi_0.outputs["output_data"]) + + sample_1 = Sample("sample_1", + n=100, + replace=False, + hetero_sync=True, + input_data=psi_0.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(sample_0) + pipeline.add_task(sample_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sample/test_sample_multi_host.py b/examples/pipeline/sample/test_sample_multi_host.py new file mode 100644 index 0000000000..f0d4056761 --- /dev/null +++ b/examples/pipeline/sample/test_sample_multi_host.py @@ -0,0 +1,83 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Sample, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_0.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + psi_1.hosts[1].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + sample_0 = Sample("sample_0", + frac={0: 0.8, 1: 0.5}, + replace=False, + hetero_sync=True, + input_data=psi_0.outputs["output_data"]) + + sample_1 = Sample("sample_1", + n=800, + replace=True, + hetero_sync=True, + input_data=psi_0.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(sample_0) + pipeline.add_task(sample_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sample/test_sample_unilateral.py b/examples/pipeline/sample/test_sample_unilateral.py new file mode 100644 index 0000000000..643a14e60f --- /dev/null +++ b/examples/pipeline/sample/test_sample_unilateral.py @@ -0,0 +1,80 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Sample, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + sample_0 = Sample("sample_0", + runtime_roles=["guest"], + frac={0: 0.5}, + replace=False, + hetero_sync=False, + input_data=psi_0.outputs["output_data"]) + + sample_1 = Sample("sample_1", + runtime_roles=["host"], + n=1000, + replace=True, + hetero_sync=False, + input_data=psi_0.outputs["output_data"] + ) + + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(sample_0) + pipeline.add_task(sample_1) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/statistics/statistics_testsuite.yaml b/examples/pipeline/statistics/statistics_testsuite.yaml new file mode 100644 index 0000000000..99b5cea2d2 --- /dev/null +++ b/examples/pipeline/statistics/statistics_testsuite.yaml @@ -0,0 +1,40 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 +tasks: + statistics: + script: test_statistics.py + statistics-default: + script: test_statistics_default.py \ No newline at end of file diff --git a/examples/pipeline/statistics/test_statistics.py b/examples/pipeline/statistics/test_statistics.py new file mode 100644 index 0000000000..d89fcb1197 --- /dev/null +++ b/examples/pipeline/statistics/test_statistics.py @@ -0,0 +1,62 @@ +# +# 
Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + statistics_0 = Statistics("statistics_0", input_data=psi_0.outputs["output_data"], + metrics=["mean", "std", "0%", "25%", "median", "75%", "100%", + "missing_ratio"]) + + pipeline.add_task(psi_0) + pipeline.add_task(statistics_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + pipeline.fit() + # print(f"statistics_0 output model: {pipeline.get_task_info('statistics_0').get_output_model()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/statistics/test_statistics_default.py b/examples/pipeline/statistics/test_statistics_default.py new file mode 100644 index 0000000000..1add7f97aa --- /dev/null +++ b/examples/pipeline/statistics/test_statistics_default.py @@ -0,0 +1,61 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import PSI, Statistics +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config=".../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + statistics_0 = Statistics("statistics_0", + skip_col=["x0", "x3"], + input_data=psi_0.outputs["output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(statistics_0) + + pipeline.compile() + pipeline.fit() + print(f"statistics_0 output model: {pipeline.get_task_info('statistics_0').get_output_model()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/test_dag.py b/examples/pipeline/test_dag.py deleted file mode 100644 index fddd3994da..0000000000 --- a/examples/pipeline/test_dag.py +++ /dev/null @@ -1,116 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from fate_client.pipeline.components.fate import HeteroLR -from fate_client.pipeline.components.fate import Reader -from fate_client.pipeline.components.fate import FeatureScale -from fate_client.pipeline.components.fate import Intersection -from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline import StandalonePipeline - - -pipeline = StandalonePipeline().set_roles( - guest='9999', host='10000', arbiter='10001') -reader_0 = Reader(name="reader_0") -reader_0.guest.component_param(path="file://${abs_path_of_data_guest}", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - -reader_0.hosts[0].component_param(path="file://${abs_path_of_data_host}", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - -intersection_0 = Intersection(name="intersection_0", - method="raw", - input_data=reader_0.outputs["output_data"]) - -intersection_1 = Intersection(name="intersection_1", - method="raw", - input_data=reader_0.outputs["output_data"]) - -feature_scale_0 = FeatureScale(name="feature_scale_0", - method="standard", - train_data=intersection_0.outputs["output_data"]) - -feature_scale_1 = FeatureScale(name="feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"]) - -lr_0 = HeteroLR(name="lr_0", - train_data=feature_scale_0.outputs["train_output_data"], - validate_data=feature_scale_1.outputs["test_output_data"], - max_iter=100, - learning_rate=0.03, - batch_size=-1) - -evaluation_0 = Evaluation(name="evaluation_0", - runtime_roles="guest", - input_data=lr_0.outputs["train_output_data"]) - -pipeline.add_task(reader_0) -pipeline.add_task(feature_scale_0) -pipeline.add_task(feature_scale_1) -pipeline.add_task(intersection_0) -pipeline.add_task(intersection_1) -pipeline.add_task(lr_0) -pipeline.add_task(evaluation_0) - -pipeline.compile() -print(pipeline.get_dag()) -pipeline.fit() -print(pipeline.get_task_info("feature_scale_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_model()) -print(pipeline.get_task_info("lr_0").get_output_data()) -print(pipeline.get_task_info("evaluation_0").get_output_metrics()) -print(pipeline.deploy([intersection_0, feature_scale_0, lr_0])) - - -predict_pipeline = StandalonePipeline() -reader_1 = Reader(name="reader_1") -reader_1.guest.component_param(path="file://${abs_path_of_data_guest}", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - -reader_1.hosts[0].component_param(path="file://${abs_path_of_data_host}", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - - -deployed_pipeline = pipeline.get_deployed_pipeline() -deployed_pipeline.intersection_0.input_data = reader_1.outputs["output_data"] - -predict_pipeline.add_task(deployed_pipeline) -predict_pipeline.add_task(reader_1) - -print("\n\n\n") -print(predict_pipeline.compile().get_dag()) -predict_pipeline.predict() diff --git a/examples/pipeline/test_dag_flow.py b/examples/pipeline/test_dag_flow.py deleted file mode 100644 index 29d5c1d621..0000000000 --- a/examples/pipeline/test_dag_flow.py +++ /dev/null @@ 
-1,114 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate_client.pipeline.components.fate import HeteroLR -from fate_client.pipeline.components.fate import Reader -from fate_client.pipeline.components.fate import FeatureScale -from fate_client.pipeline.components.fate import Intersection -from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline import FateFlowPipeline - - -pipeline = FateFlowPipeline().set_roles( - guest='9999', host='9999', arbiter='9999') -reader_0 = Reader(name="reader_0") -reader_0.guest.component_param(path="file://${abs_path_of_data_guest}", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - -reader_0.hosts[0].component_param(path="file://${abs_path_of_data_host}", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - -intersection_0 = Intersection(name="intersection_0", - method="raw", - input_data=reader_0.outputs["output_data"]) - -intersection_1 = Intersection(name="intersection_1", - method="raw", - input_data=reader_0.outputs["output_data"]) - -feature_scale_0 = FeatureScale(name="feature_scale_0", - method="standard", - train_data=intersection_0.outputs["output_data"]) - -feature_scale_1 = FeatureScale(name="feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"]) - -lr_0 = HeteroLR(name="lr_0", - train_data=feature_scale_0.outputs["train_output_data"], - validate_data=feature_scale_1.outputs["test_output_data"], - max_iter=1, - learning_rate=0.01, - batch_size=100) - -evaluation_0 = Evaluation(name="evaluation_0", - runtime_roles="guest", - input_data=lr_0.outputs["train_output_data"]) - -pipeline.add_task(reader_0) -pipeline.add_task(feature_scale_0) -pipeline.add_task(feature_scale_1) -pipeline.add_task(intersection_0) -pipeline.add_task(intersection_1) -pipeline.add_task(lr_0) -pipeline.add_task(evaluation_0) - -pipeline.compile() -print(pipeline.get_dag()) -pipeline.fit() -print(pipeline.get_task_info("lr_0").get_output_model()) -print(pipeline.get_task_info("evaluation_0").get_output_metrics()) -print(pipeline.deploy([intersection_0, feature_scale_0, lr_0])) - - -predict_pipeline = FateFlowPipeline() -reader_1 = Reader(name="reader_1") - -reader_1.guest.component_param(path="file://${abs_path_of_data_guest}", - # path="file:///data/projects/fate/examples/data/breast_hetero_guest.csv", - format="csv", - id_name="id", - delimiter=",", - label_name="y", - label_type="float32", - dtype="float32") - -reader_1.hosts[0].component_param(path="file://${abs_path_of_data_host}", - # path="file:///data/projects/fate/examples/data/breast_hetero_host.csv", - format="csv", - id_name="id", - delimiter=",", - label_name=None, - dtype="float32") - 
-deployed_pipeline = pipeline.get_deployed_pipeline() -deployed_pipeline.intersection_0.input_data = reader_1.outputs["output_data"] - -predict_pipeline.add_task(deployed_pipeline) -predict_pipeline.add_task(reader_1) - -print("\n\n\n") -print(predict_pipeline.compile().get_dag()) -predict_pipeline.predict() diff --git a/examples/pipeline/test_dag_flow_eggroll.py b/examples/pipeline/test_dag_flow_eggroll.py deleted file mode 100644 index a7c3077746..0000000000 --- a/examples/pipeline/test_dag_flow_eggroll.py +++ /dev/null @@ -1,92 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate_client.pipeline.components.fate import HeteroLR -from fate_client.pipeline.components.fate import Reader -from fate_client.pipeline.components.fate import FeatureScale -from fate_client.pipeline.components.fate import Intersection -from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline import FateFlowPipeline - - -pipeline = FateFlowPipeline().set_roles( - guest='9999', host='9999', arbiter='9999') -reader_0 = Reader(name="reader_0") -reader_0.guest.component_param(path="eggroll:///${guest_table_namespace}/${guest_table_name}", - format="raw_table") - -reader_0.hosts[0].component_param(path="eggroll:///${host_table_namespace}/${host_table_name}", - format="raw_table") - -intersection_0 = Intersection(name="intersection_0", - method="raw", - input_data=reader_0.outputs["output_data"]) - -intersection_1 = Intersection(name="intersection_1", - method="raw", - input_data=reader_0.outputs["output_data"]) - -feature_scale_0 = FeatureScale(name="feature_scale_0", - method="standard", - train_data=intersection_0.outputs["output_data"]) - -feature_scale_1 = FeatureScale(name="feature_scale_1", - test_data=intersection_1.outputs["output_data"], - input_model=feature_scale_0.outputs["output_model"]) - -lr_0 = HeteroLR(name="lr_0", - train_data=feature_scale_0.outputs["train_output_data"], - validate_data=feature_scale_1.outputs["test_output_data"], - max_iter=1, - learning_rate=0.01, - batch_size=100) - -evaluation_0 = Evaluation(name="evaluation_0", - runtime_roles="guest", - input_data=lr_0.outputs["train_output_data"]) - -pipeline.add_task(reader_0) -pipeline.add_task(feature_scale_0) -pipeline.add_task(feature_scale_1) -pipeline.add_task(intersection_0) -pipeline.add_task(intersection_1) -pipeline.add_task(lr_0) -pipeline.add_task(evaluation_0) - -pipeline.compile() -print(pipeline.get_dag()) -pipeline.fit() -print(pipeline.get_task_info("lr_0").get_output_model()) -print(pipeline.get_task_info("evaluation_0").get_output_metrics()) -print(pipeline.deploy([intersection_0, feature_scale_0, lr_0])) - - -predict_pipeline = FateFlowPipeline() -reader_1 = Reader(name="reader_1") -reader_1.guest.component_param(path="eggroll:///${guest_table_namespace}/${guest_table_name}", - format="raw_table") - -reader_1.hosts[0].component_param(path="eggroll:///${host_table_namespace}/${host_table_name}", - format="raw_table") - - 
-deployed_pipeline = pipeline.get_deployed_pipeline() -deployed_pipeline.intersection_0.input_data = reader_1.outputs["output_data"] - -predict_pipeline.add_task(deployed_pipeline) -predict_pipeline.add_task(reader_1) - -print("\n\n\n") -print(predict_pipeline.compile().get_dag()) -predict_pipeline.predict() diff --git a/examples/pipeline/test_upload.py b/examples/pipeline/test_upload.py deleted file mode 100644 index 6f72f0afba..0000000000 --- a/examples/pipeline/test_upload.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate_client.pipeline import FateFlowPipeline - -pipeline = FateFlowPipeline() -pipeline.upload(file="${abs_path_of_data_guest}", - # file="/data/projects/fate/examples/data/breast_hetero_guest.csv", - head=1, - partitions=4, - namespace="experiment", - name="breast_hetero_guest", - meta={ - "label_name": "y", - "label_type": "float32", - "dtype": "float32" - }) - -pipeline = FateFlowPipeline() -pipeline.upload(file="${abs_path_of_data_host}", - # file="/data/projects/fate/examples/data/breast_hetero_host.csv", - head=1, - partitions=4, - namespace="experiment", - name="breast_hetero_host", - meta={ - "label_name": None, - "dtype": "float32" - }) \ No newline at end of file diff --git a/examples/pipeline/union/test_union.py b/examples/pipeline/union/test_union.py new file mode 100644 index 0000000000..a1138117e1 --- /dev/null +++ b/examples/pipeline/union/test_union.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import DataSplit, PSI, Union +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_roles(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + psi_1 = PSI("psi_1") + psi_1.guest.component_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_1.hosts[0].component_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + data_split_0 = DataSplit("data_split_0", + train_size=0.6, + validate_size=0.1, + input_data=psi_0.outputs["output_data"]) + + data_split_1 = DataSplit("data_split_1", + train_size=200, + test_size=50, + input_data=psi_0.outputs["output_data"] + ) + + union_0 = Union("union_0", input_data_list=[data_split_0.outputs["train_output_data"], + data_split_0.outputs["test_output_data"]]) + pipeline.add_task(psi_0) + pipeline.add_task(psi_1) + pipeline.add_task(data_split_0) + pipeline.add_task(data_split_1) + pipeline.add_task(union_0) + + # pipeline.add_task(hetero_feature_binning_0) + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/union/union_testsuite.yaml b/examples/pipeline/union/union_testsuite.yaml new file mode 100644 index 0000000000..b5eab53a5b --- /dev/null +++ b/examples/pipeline/union/union_testsuite.yaml @@ -0,0 +1,38 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 +tasks: + union: + script: test_union.py diff --git a/examples/pipeline/upload/test_upload.py b/examples/pipeline/upload/test_upload.py new file mode 100644 index 0000000000..bdf10c7ebb --- /dev/null +++ b/examples/pipeline/upload/test_upload.py @@ -0,0 +1,55 @@ +# +# Copyright 2019 The FATE Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from fate_client.pipeline import FateFlowPipeline + +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int32', + 'label_name': 'y', + 'match_id_name': 'id', + 'match_id_range': 0, + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} + +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_guest}", + meta=meta, head=True, extend_sid=True, + namespace="experiment", + name="breast_hetero_guest") + +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int', + 'match_id_name': 'id', + 'match_id_range': 0, + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} + +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") + +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_host}", + meta=meta, head=True, extend_sid=True, + namespace="experiment", + name="breast_hetero_host") diff --git a/examples/pipeline/upload/test_upload_sid.py b/examples/pipeline/upload/test_upload_sid.py new file mode 100644 index 0000000000..3eb51d1490 --- /dev/null +++ b/examples/pipeline/upload/test_upload_sid.py @@ -0,0 +1,57 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
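+#
+# Upload example for data that already carries a sample id column: the meta below sets
+# sample_id_name to "id" and transform_local_file_to_dataframe is called with
+# extend_sid=False, so the existing id column is used as the sample id instead of
+# generating a new one (contrast with examples/pipeline/upload/test_upload.py, which
+# uses extend_sid=True). The guest meta defines label column "y"; the host meta defines
+# no label_name. Fill in the commented-out file arguments with local CSV paths before
+# running.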
+from fate_client.pipeline import FateFlowPipeline + +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int32', + 'label_name': 'y', + 'match_id_name': 'id', + 'match_id_range': 0, + 'sample_id_name': 'id', + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} + +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_guest}", + meta=meta, head=True, extend_sid=False, + namespace="experiment", + name="breast_hetero_guest") + +meta = {'delimiter': ',', + 'dtype': 'float32', + 'input_format': 'dense', + 'label_type': 'int', + 'match_id_name': 'id', + 'match_id_range': 0, + 'sample_id_name': 'id', + 'tag_value_delimiter': ':', + 'tag_with_value': False, + 'weight_type': 'float32'} + +pipeline = FateFlowPipeline().set_roles( + local="0") +pipeline.set_site_role("local") +pipeline.set_site_party_id("0") + +pipeline.transform_local_file_to_dataframe( # file="${abs_path_of_data_host}", + meta=meta, head=True, extend_sid=False, + namespace="experiment", + name="breast_hetero_host") diff --git a/fate_client b/fate_client index 94185a8593..1af3d2271c 160000 --- a/fate_client +++ b/fate_client @@ -1 +1 @@ -Subproject commit 94185a859367229ac94be171f7ad0042a5b5c9bb +Subproject commit 1af3d2271c51609637635f2c33576d702f7365f9 diff --git a/fate_flow b/fate_flow index cb9d772f54..13fd3dbfb6 160000 --- a/fate_flow +++ b/fate_flow @@ -1 +1 @@ -Subproject commit cb9d772f544f8e69116dd8cc114d8d76a2a2ef96 +Subproject commit 13fd3dbfb67c4c23b25e38630899f34c8446c38f diff --git a/fate_test b/fate_test new file mode 160000 index 0000000000..80e737027f --- /dev/null +++ b/fate_test @@ -0,0 +1 @@ +Subproject commit 80e737027f93fdd214f78c60afda840e6ce09757 diff --git a/java/osx/README.md b/java/osx/README.md index 964392f623..055f994d4e 100644 --- a/java/osx/README.md +++ b/java/osx/README.md @@ -1,7 +1 @@ -# Release 1.0.0-alpha -## Major Features and Improvements -* Support grpc synchronous transmission and streaming transmission. Compatible with eggroll interface and can replace FATE1. x rollsite component -* Support asynchronous message transmission, which can replace rabbitmq and pulsar components in FATE1. x -* Support HTTP1. X protocol transmission -* Support cluster deployment and inter-site traffic control -* Support networking as an Exchange component \ No newline at end of file +OSX: Open Site Exchange \ No newline at end of file diff --git a/java/osx/RELEASE.md b/java/osx/RELEASE.md new file mode 100644 index 0000000000..c95e476954 --- /dev/null +++ b/java/osx/RELEASE.md @@ -0,0 +1,24 @@ +# Release 1.0.0-beta +## Major Features and Improvements +* Improve HTTP/1. X protocol support, support GRPC to HTTP1.X transmission +* Support TLS secure transmission protocol +* Add a routing table configuration interface +* Automatic connectivity check for routing tables +* Improve the transmission function in cluster mode +* Improve cluster mode flow control function +* Support for simple interface authentication + +# Release 1.0.0-alpha +## Major Features and Improvements +* Support grpc synchronous transmission and streaming transmission. Compatible with eggroll interface and can replace FATE1. x rollsite component +* Support asynchronous message transmission, which can replace rabbitmq and pulsar components in FATE1. x +* Support HTTP1. 
X protocol transmission +* Support cluster deployment and inter-site traffic control +* Support networking as an Exchange component + + + + + + + diff --git a/java/osx/bin/common.sh b/java/osx/bin/common.sh index 3101efaa39..23bc782274 100644 --- a/java/osx/bin/common.sh +++ b/java/osx/bin/common.sh @@ -14,9 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -error_exit () -{ +error_exit (){ echo "ERROR: $1 !!" exit 1 } @@ -74,6 +72,8 @@ choose_gc_options() JAVA_OPT="${JAVA_OPT} -XX:+UseG1GC -XX:G1HeapRegionSize=16m -XX:G1ReservePercent=25 -XX:InitiatingHeapOccupancyPercent=30 -XX:SoftRefLRUPolicyMSPerMB=0" JAVA_OPT="${JAVA_OPT} -Xlog:gc*:file=${GC_LOG_DIR}/rmq_srv_gc_%p_%t.log:time,tags:filecount=5,filesize=30M" fi + + JAVA_OPT="${JAVA_OPT} -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=${BASE_DIR}/oom/heapdump.hprof " } choose_gc_log_directory @@ -84,7 +84,6 @@ JAVA_OPT="${JAVA_OPT} -XX:-OmitStackTraceInFastThrow" JAVA_OPT="${JAVA_OPT} -XX:+AlwaysPreTouch" JAVA_OPT="${JAVA_OPT} -XX:MaxDirectMemorySize=15g" JAVA_OPT="${JAVA_OPT} -XX:-UseLargePages -XX:-UseBiasedLocking" -#JAVA_OPT="${JAVA_OPT} -Xdebug -Xrunjdwp:transport=dt_socket,address=9555,server=y,suspend=n" JAVA_OPT="${JAVA_OPT} ${JAVA_OPT_EXT}" set -e @@ -93,7 +92,7 @@ getpid() { pid=$(cat ./bin/broker.pid) fi if [[ -n ${pid} ]]; then - count=$(ps -ef | grep $pid | grep -v "grep" | wc -l) + count=$(ps -ef | grep $pid |grep 'org.fedai.osx' | grep -v "grep" | wc -l) if [[ ${count} -eq 0 ]]; then rm ./bin/broker.pid unset pid @@ -111,57 +110,54 @@ mklogsdir() { start() { echo "try to start $1" module=broker - main_class=com.osx.broker.Bootstrap + main_class=org.fedai.osx.broker.Bootstrap getpid $module if [[ ! -n ${pid} ]]; then JAVA_OPT="${JAVA_OPT} " mklogsdir -# if [[ -e "${module}.jar" ]]; then -# rm ${module}.jar -# fi -# ln -s ${module}-${module_version}.jar ${module}.jar - JAVA_OPT="${JAVA_OPT} -cp conf/broker/:lib/*" -# if [ ${module} = "transfer" ]; then -# echo "transfer" -# elif [ ${module} = "cluster-manager" ] || [ ${module} = "dashboard" ]; then -# JAVA_OPT="${JAVA_OPT} -Dspring.config.location=${configpath}/cluster-manager.properties" -# JAVA_OPT="${JAVA_OPT} -cp conf/:lib/*:${module}.jar" -# else -# echo "usage: ${module} {transfer|cluster-manager|dashboard}" -# fi - + JAVA_OPT="${JAVA_OPT} -cp conf/broker/:lib/*:extension/*:${BASE_DIR}/${project_name}-${module}-${module_version}.jar" JAVA_OPT="${JAVA_OPT} ${main_class}" + JAVA_OPT="${JAVA_OPT} -c ${configpath} " + echo $JAVA ${JAVA_OPT} + nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 & + inspect_pid 5 $! + if [[ "$exist" = 1 ]]; then + echo $! >./bin/${module}.pid + getpid ${module} + echo "service start sucessfully. pid: ${pid}" + else + echo "service start failed, " + fi + else + echo "service already started. 
pid: ${pid}" + fi +} - JAVA_OPT="${JAVA_OPT} -c ${configpath}/broker/broker.properties" -# if [ ${module} = "broker" -o ${module} = "cli" ]; then -# JAVA_OPT="${JAVA_OPT} -c ${configpath}/broker/broker.properties" -# elif [ ${module} = "cluster-manager" ]; then -# JAVA_OPT="${JAVA_OPT} -c ${configpath}/cluster-manager/cluster-manager.properties" -# -# elif [ ${module} = "dashboard" ]; then -# JAVA_OPT="-jar ${libpath}/dashboard-1.0.0.jar -spring.config.location=${configpath}/dashboard/application.properties" -# fi - - if [ ${module} = "cli" ]; then - java ${JAVA_OPT} - else - nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 & - #sleep 5 - #id=$(ps -p $! | awk '{print $1}' | sed -n '2p') - inspect_pid 5 $! - - if [[ "$exist" = 1 ]]; then - echo $! >./bin/${module}.pid - getpid ${module} - echo "service start sucessfully. pid: ${pid}" - else - echo "service start failed" - fi - fi +debug() { + echo "try to start $1" + module=broker + main_class=org.fedai.osx.broker.Bootstrap + getpid $module + if [[ ! -n ${pid} ]]; then JAVA_OPT="${JAVA_OPT} " + mklogsdir + JAVA_OPT="${JAVA_OPT} -Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=7007 -cp conf/broker/:lib/*:extension/*:${BASE_DIR}/${project_name}-${module}-${module_version}.jar" + JAVA_OPT="${JAVA_OPT} ${main_class}" + JAVA_OPT="${JAVA_OPT} -c ${configpath} " + echo $JAVA ${JAVA_OPT} + nohup $JAVA ${JAVA_OPT} >/dev/null 2>&1 & + inspect_pid 5 $! + if [[ "$exist" = 1 ]]; then + echo $! >./bin/${module}.pid + getpid ${module} + echo "service start sucessfully. pid: ${pid}" + else + echo "service start failed, " + fi else echo "service already started. pid: ${pid}" fi } + status() { getpid $1 if [[ -n ${pid} ]]; then @@ -195,11 +191,9 @@ stop() { fi } - inspect_pid() { total=0 exist=0 - #echo "inspect pid: $2,periods: $1" if [[ -n $2 ]]; then while [[ $total -le $1 ]] do @@ -214,4 +208,4 @@ inspect_pid() { fi done fi -} \ No newline at end of file +} diff --git a/java/osx/bin/service.sh b/java/osx/bin/service.sh index 07f15ba4b4..ed507da4ac 100644 --- a/java/osx/bin/service.sh +++ b/java/osx/bin/service.sh @@ -1,5 +1,4 @@ #!/bin/bash - # # Copyright 2019 The FATE Authors. All Rights Reserved. # @@ -15,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -set -e +set -e source ./bin/common.sh #export JAVA_HOME=/data/projects/fate/common/jdk/jdk-8u192 #export PATH=$PATH:$JAVA_HOME/bin @@ -25,8 +24,8 @@ configpath=$(cd $basepath/conf;pwd) libpath=$(cd $basepath/lib;pwd) #module=transfer #main_class=com.firework.transfer.Bootstrap -#module_version=1.0.0 - +module_version=1.0.0-beta +project_name=osx @@ -35,6 +34,10 @@ case "$1" in start $2 status $2 ;; + debug) + debug $2 + status $2 + ;; stop) stop $2 ;; @@ -47,7 +50,13 @@ case "$1" in start $2 status $2 ;; + rebudeg) + stop $2 + sleep 0.5 + debug $2 + status $2 + ;; *) echo "usage: $0 {start|stop|status|restart}" exit 1 -esac \ No newline at end of file +esac diff --git a/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java b/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java deleted file mode 100644 index 6c008efd23..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/Bootstrap.java +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker; -import com.google.common.collect.Lists; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StreamLimitMode; -import com.osx.core.jvm.JvmInfoCounter; -import com.osx.core.utils.JsonUtil; -import com.osx.core.utils.NetUtils; -import com.osx.core.utils.ServerUtil; -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.Option; -import org.apache.commons.cli.Options; -import org.apache.commons.cli.PosixParser; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.*; -import java.util.Properties; -public class Bootstrap { - static Logger logger = LoggerFactory.getLogger(Bootstrap.class); - static CommandLine commandLine; - public static void main(String[] args) { - try { - Options options = ServerUtil.buildCommandlineOptions(new Options()); - commandLine = ServerUtil.parseCmdLine("osx", args, buildCommandlineOptions(options), - new PosixParser()); - String filePath = commandLine.getOptionValue('c'); - logger.info("try to parse config file {}", filePath); - if (StringUtils.isEmpty(filePath)) { - System.err.println("config file is not set ,please use -c to set the config file path"); - System.exit(-1); - } - parseConfig(filePath); - Bootstrap bootstrap = new Bootstrap(); - bootstrap.start(args); - Thread shutDownThread = new Thread(() -> bootstrap.stop()); - Runtime.getRuntime().addShutdownHook(shutDownThread); - } catch (Exception ex) { - System.exit(1); - } - } - - private static Options buildCommandlineOptions(final Options options) { - Option opt = new Option("c", "configFile", true, "config properties file"); - opt.setRequired(false); - options.addOption(opt); - return options; - } - - public static void parseConfig(String configFilePath) { - try { - File file = new File(configFilePath); - Properties environment = new Properties(); - try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) { - environment.load(inputStream); - } catch (FileNotFoundException e) { - logger.error("profile broker.properties not found"); - throw e; - } catch (IOException e) { - logger.error("parse config error, {}", e.getMessage()); - throw e; - } - MetaInfo.PROPERTY_FATE_TECH_PROVIDER = environment.getProperty(Dict.PROPERTY_FATE_TECH_PROVIDER,"FATE"); - MetaInfo.PROPERTY_ROOT_PATH = new File("").getCanonicalPath(); - MetaInfo.PROPERTY_ROUTE_TABLE = environment.getProperty(Dict.PROPERTY_ROUTE_TABLE); - MetaInfo.PROPERTY_SERVER_CERTCHAIN_FILE = environment.getProperty(Dict.PROPERTY_SERVER_CERTCHAIN_FILE); - MetaInfo.PROPERTY_SERVER_PRIVATEKEY_FILE = environment.getProperty(Dict.PROPERTY_SERVER_PRIVATEKEY_FILE); - MetaInfo.PROPERTY_SERVER_CA_FILE = environment.getProperty(Dict.PROPERTY_SERVER_CA_FILE); - MetaInfo.PROPERTY_GRPC_TLS_PORT = Integer.valueOf(environment.getProperty(Dict.PROPERTY_GRPC_TLS_PORT, "9883")); - MetaInfo.PROPERTY_GRPC_PORT = Integer.valueOf(environment.getProperty(Dict.PROPERTY_GRPC_PORT, "9889")); - MetaInfo.PROPERTY_HTTP_PORT = 
Integer.valueOf(environment.getProperty(Dict.HTTP_PORT,"8762")); - MetaInfo.PROPERTY_PRINT_INPUT_DATA = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_PRINT_INPUT_DATA, "false")); - MetaInfo.PROPERTY_PRINT_OUTPUT_DATA = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_PRINT_OUTPUT_DATA, "false")); - MetaInfo.PROPERTY_USER_HOME = System.getProperty("user.home"); - MetaInfo.PROPERTY_NEGOTIATIONTYPE = environment.getProperty(Dict.PROPERTY_NEGOTIATIONTYPE, "PLAINTEXT"); - MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE = environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_PATH, MetaInfo.PROPERTY_USER_HOME + "/.fate/transfer_file"); - MetaInfo.PROPERTY_TRANSFER_FILE_CACHE_SIZE = environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_CACHE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_TRANSFER_FILE_CACHE_SIZE)) : 1 << 27; - MetaInfo.PROPERTY_USE_DIRECT_CACHE = Boolean.parseBoolean(environment.getProperty(Dict.PROPERTY_USE_DIRECT_CACHE, "false")); - MetaInfo.PROPERTY_MAX_TRANSFER_CACHE_SIZE = environment.getProperty(Dict.PROPERTY_MAX_TRANSFER_CACHE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_MAX_TRANSFER_CACHE_SIZE)) : 1 << 30; - MetaInfo.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT, "600")); - // MetaInfo.PROPERTY_USE_QUEUE_MODEL = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_USE_QUEUE_MODEL, "false")); - MetaInfo.PROPERTY_STREAM_LIMIT_MODE = environment.getProperty(Dict.PROPERTY_STREAM_LIMIT_MODE, StreamLimitMode.LOCAL.name()); - MetaInfo.PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = Integer.parseInt(environment.getProperty(Dict.PROPERTY_STREAM_LIMIT_MAX_TRY_TIME, "10")); - MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION, "1000")); - MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE)) : (2 << 30) - 1; - MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) != null ? Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE)) : 128 << 20; - MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = environment.getProperty(Dict.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW) != null ? 
Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW)) : 128 << 20; - MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, "7200")); - MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, "3600")); - MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, "120")); - MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = Boolean.parseBoolean(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED, "false")); - MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, "86400")); - MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, "86400")); - MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = Integer.parseInt(environment.getProperty(Dict.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, "86400")); - MetaInfo.TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED = Boolean.valueOf(environment.getProperty(Dict.TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED, "false")); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG = Boolean.valueOf(environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG, "false")); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_ROLE = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_ROLE, "guest"); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_URI = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_URI, "/cloud-manager/api/site/rollsite/checkPartyId"); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY, ""); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET, ""); - MetaInfo.TRANSFER_FATECLOUD_SECRET_INFO_URL = environment.getProperty(Dict.TRANSFER_FATECLOUD_SECRET_INFO_URL, "http://localhost:9091/fate-manager/api/site/secretinfo"); - MetaInfo.TRANSFER_FATECLOUD_AUTHENTICATION_URL = environment.getProperty(Dict.TRANSFER_FATECLOUD_AUTHENTICATION_URL, "http://localhost:8999/cloud-manager/api/site/rollsite/checkPartyId"); - MetaInfo.PROPERTY_SELF_PARTY.addAll(Lists.newArrayList(environment.getProperty(Dict.PROPERTY_SELF_PARTY, "").split(","))); - - MetaInfo.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT,"500")); - MetaInfo.HTTP_CLIENT_CONFIG_CONN_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_CONN_TIME_OUT,"2000")); - MetaInfo.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT,"3000")); - MetaInfo.HTTP_CLIENT_INIT_POOL_MAX_TOTAL = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_MAX_TOTAL,"500")); - MetaInfo.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE,"200")); - MetaInfo.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT,"10000")); - MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT = 
Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT,"10000")); - MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT,"10000")); - MetaInfo.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT,"60000")); - MetaInfo.HTTP_CLIENT_TRAN_CONN_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_CONN_TIME_OUT,"60000")); - MetaInfo.HTTP_CLIENT_TRAN_SOCK_TIME_OUT = Integer.valueOf(environment.getProperty(Dict.HTTP_CLIENT_TRAN_SOCK_TIME_OUT,"60000")); - - MetaInfo.PRPPERTY_QUEUE_MAX_FREE_TIME = Integer.parseInt(environment.getProperty(Dict.PRPPERTY_QUEUE_MAX_FREE_TIME, "60000000")); - MetaInfo.INSTANCE_ID = NetUtils.getLocalHost() + ":" + MetaInfo.PROPERTY_GRPC_PORT; - MetaInfo.PROPERTY_DEPLOY_MODE = environment.getProperty(Dict.PROPERTY_DEPLOY_MODE); - MetaInfo.PROPERTY_CLUSTER_MANAGER_ADDRESS = environment.getProperty(Dict.PROPERTY_CLUSTER_MANAGER_ADDRESS); - MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP = environment.getProperty(Dict.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP); - MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT = Integer.parseInt(environment.getProperty(Dict.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT)); - MetaInfo.PROPERTY_ZK_URL = environment.getProperty(Dict.PROPERTY_ZK_URL); - MetaInfo.PROPERTY_OPEN_HTTP_SERVER = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_OPEN_HTTP_SERVER, "false")); - MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER = Boolean.valueOf(environment.getProperty(Dict.PROPERTY_OPEN_GRPC_TLS_SERVER, "false")); -// public static Boolean PROPERTY_OPEN_HTTP_SERVER = false; -// public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false; - MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION = environment.getProperty(Dict.PROPERTY_DEFAULT_CLIENT_VERSION,"2.X.X"); - - } catch (Exception e) { - logger.error("init MetaInfo error", e); - System.exit(1); - } - logger.info("Meta Info {}", JsonUtil.formatJson(JsonUtil.object2Json(MetaInfo.toMap()))); - } - - public void start(String[] args) { - ServiceContainer.init(); - JvmInfoCounter.start(); - } - - public void stop() { - logger.info("try to shutdown server ..."); - if (ServiceContainer.transferQueueManager != null) { - ServiceContainer.transferQueueManager.destroyAll(); - } - } - -} \ No newline at end of file diff --git a/java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java b/java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java deleted file mode 100644 index 0ec60a5397..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ServiceContainer.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.osx.broker; - - -import com.osx.broker.consumer.ConsumerManager; -import com.osx.broker.grpc.PcpGrpcService; -import com.osx.broker.grpc.ProxyGrpcService; -import com.osx.broker.interceptor.RequestHandleInterceptor; -import com.osx.broker.interceptor.RouterInterceptor; -import com.osx.broker.message.AllocateMappedFileService; -import com.osx.broker.queue.TransferQueueManager; -import com.osx.broker.router.DefaultFateRouterServiceImpl; -import com.osx.broker.router.FateRouterService; -import com.osx.broker.server.OsxServer; -import com.osx.broker.service.PushService; -import com.osx.broker.service.TokenApplyService; -import com.osx.broker.service.UnaryCallService; -import com.osx.broker.store.MessageStore; -import com.osx.broker.token.DefaultTokenService; -import com.osx.broker.zk.CuratorZookeeperClient; -import com.osx.broker.zk.ZkConfig; -import com.osx.core.config.MetaInfo; -import com.osx.core.flow.ClusterFlowRuleManager; -import com.osx.core.flow.FlowCounterManager; -import com.osx.core.service.AbstractServiceAdaptor; -import com.osx.tech.provider.TechProviderRegister; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; - -public class ServiceContainer { - static public ConsumerManager consumerManager; - static public PcpGrpcService pcpGrpcService; - static public TransferQueueManager transferQueueManager; - static public AllocateMappedFileService allocateMappedFileService; - static public FlowCounterManager flowCounterManager; - static public OsxServer transferServer; - static public ProxyGrpcService proxyGrpcService; - static public FateRouterService fateRouterService; - static public Map serviceAdaptorMap = new HashMap(); - static public TokenApplyService tokenApplyService; - static public PushService pushService; - static public UnaryCallService unaryCallService; - static public RequestHandleInterceptor requestHandleInterceptor; - static public MessageStore messageStore; - static public RouterInterceptor routerInterceptor; - static public ClusterFlowRuleManager clusterFlowRuleManager; - static public DefaultTokenService defaultTokenService; - static public CuratorZookeeperClient zkClient; - static public TechProviderRegister techProviderRegister; - - static Logger logger = LoggerFactory.getLogger(ServiceContainer.class); - - public static void init() { - flowCounterManager = createFlowCounterManager(); - clusterFlowRuleManager = createClusterFlowRuleManager(); - allocateMappedFileService = createAllocateMappedFileService(); - messageStore = createMessageStore(allocateMappedFileService); - zkClient = createCuratorZookeeperClient(); - transferQueueManager = createTransferQueueManager(); - consumerManager = createTransferQueueConsumerManager(); - fateRouterService = createFateRouterService(); - tokenApplyService = createTokenApplyService(); - pushService = createPushService(); - requestHandleInterceptor = createDefaulRequestInterceptor(); - routerInterceptor = createDefaultRouterInterceptor(fateRouterService); - unaryCallService = createUnaryCallService(requestHandleInterceptor,routerInterceptor); - proxyGrpcService = new ProxyGrpcService(pushService, unaryCallService); - transferServer = new OsxServer(); - defaultTokenService = createDefaultTokenService(); - tokenApplyService = createTokenApplyService(); - - - pcpGrpcService = createPcpGrpcService(); - techProviderRegister = createTechProviderRegister(); - if (!transferServer.start()) { - System.exit(-1); - } else { - - } - ; - - - } - 
- public static TechProviderRegister createTechProviderRegister() { - TechProviderRegister techProviderRegister = new TechProviderRegister(); - techProviderRegister.init(); - return techProviderRegister; - } - - public static PcpGrpcService createPcpGrpcService() { - return new PcpGrpcService(); - } - - public static CuratorZookeeperClient createCuratorZookeeperClient() { - if (MetaInfo.isCluster()) { - ZkConfig zkConfig = new ZkConfig(MetaInfo.PROPERTY_ZK_URL, 5000); - return new CuratorZookeeperClient(zkConfig); - } - return null; - } - - public static TokenApplyService createTokenApplyService() { - TokenApplyService tokenApplyService = new TokenApplyService(); - tokenApplyService.start(); - return tokenApplyService; - } - - public static DefaultTokenService createDefaultTokenService() { - return new DefaultTokenService(); - } - - public static ClusterFlowRuleManager createClusterFlowRuleManager() { - return new ClusterFlowRuleManager(); - } - - public static MessageStore createMessageStore( - AllocateMappedFileService allocateMappedFileService) { - // TransferQueueManager transferQueueManager ,AllocateMappedFileService allocateMappedFileService,String path){ - MessageStore messageStore = new MessageStore(allocateMappedFileService - , MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID + File.separator + "message-store"); - messageStore.start(); - return messageStore; - - } - - - public static RequestHandleInterceptor createDefaulRequestInterceptor() { - RequestHandleInterceptor requestHandleInterceptor = new RequestHandleInterceptor(); - return requestHandleInterceptor; - } - public static RouterInterceptor createDefaultRouterInterceptor(FateRouterService fateRouterService){ - RouterInterceptor routerInterceptor = new RouterInterceptor(fateRouterService); - return routerInterceptor; - } - - - static FlowCounterManager createFlowCounterManager() { - FlowCounterManager flowCounterManager = new FlowCounterManager("transfer"); - flowCounterManager.startReport(); - return flowCounterManager; - } - - static UnaryCallService createUnaryCallService(RequestHandleInterceptor requestHandleInterceptor,RouterInterceptor routerInterceptor) { - UnaryCallService unaryCallService = new UnaryCallService(); - unaryCallService.addPreProcessor(requestHandleInterceptor); - unaryCallService.addPreProcessor(routerInterceptor); - return unaryCallService; - } - - static PushService createPushService() { - PushService pushService = new PushService(); - return pushService; - } - - static ConsumerManager createTransferQueueConsumerManager() { - ConsumerManager consumerManager = new ConsumerManager(); - return consumerManager; - } - - static FateRouterService createFateRouterService() { - DefaultFateRouterServiceImpl fateRouterService = new DefaultFateRouterServiceImpl(); - fateRouterService.start(); - return fateRouterService; - } - - static TransferQueueManager createTransferQueueManager() { - TransferQueueManager transferQueueManager = new TransferQueueManager(); - return transferQueueManager; - } - - static AllocateMappedFileService createAllocateMappedFileService() { - AllocateMappedFileService allocateMappedFileService = new AllocateMappedFileService(); - allocateMappedFileService.start(); - return allocateMappedFileService; - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java b/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java deleted file mode 100644 index a03c4510ae..0000000000 --- 
a/java/osx/broker/src/main/java/com/osx/broker/grpc/PushRequestDataWrap.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker.grpc; - -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import io.grpc.stub.StreamObserver; - -public class PushRequestDataWrap { - Proxy.Packet packet; - StreamObserver streamObserver; - - public Proxy.Packet getPacket() { - return packet; - } - - public void setPacket(Proxy.Packet packet) { - this.packet = packet; - } - - public StreamObserver getStreamObserver() { - return streamObserver; - } - - public void setStreamObserver(StreamObserver streamObserver) { - this.streamObserver = streamObserver; - } -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java b/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java deleted file mode 100644 index 907375f4a4..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RequestHandleInterceptor.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.osx.broker.interceptor; - -import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.broker.grpc.PushRequestDataWrap; -import com.osx.broker.router.FateRouterService; -import com.osx.core.context.Context; -import com.osx.core.exceptions.NoRouterInfoException; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import com.osx.core.service.Interceptor; -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import com.webank.eggroll.core.transfer.Transfer; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; - -public class RequestHandleInterceptor implements Interceptor { - Logger logger = LoggerFactory.getLogger(RequestHandleInterceptor.class); - - public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { - Object body = inboundPackage.getBody(); - - if (body instanceof Osx.Inbound) { - Osx.Inbound request = (Osx.Inbound) body; - Map metaDataMap = request.getMetadataMap(); - String version = metaDataMap.get(Osx.Header.Version.name()); - String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); - String traceId = metaDataMap.get(Osx.Header.TraceID.name()); - String token = metaDataMap.get(Osx.Header.Token.name()); - String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); - String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); - String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); - String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); - String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); - String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); - String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); - String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); - String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId; - String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId; - String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name()); - String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name()); - Long offset = StringUtils.isNotEmpty(offsetString) ? 
Long.parseLong(offsetString) : null; - context.setDesPartyId(targetPartyId); - context.setSrcPartyId(sourcePartyId); - context.setTopic(topic); - context.setRequestMsgIndex(offset); - context.setSessionId(sessionId); - context.setDesComponent(targetComponentName); - context.setSrcComponent(sourceComponentName); - return; - } - else if (body instanceof PushRequestDataWrap) { - PushRequestDataWrap pushRequestDataWrap = (PushRequestDataWrap) body; - Proxy.Packet packet = pushRequestDataWrap.getPacket(); - handleProxyPacket(context ,packet); - return ; - }else if (body instanceof Proxy.Packet) { - handleProxyPacket(context ,(Proxy.Packet) body); - } else { - throw new ParameterException("invalid inbound type"); - } - - } - - private void handleProxyPacket(Context context ,Proxy.Packet packet){ - Proxy.Metadata metadata = packet.getHeader(); - Transfer.RollSiteHeader rollSiteHeader = null; - try { - rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt()); - } catch (InvalidProtocolBufferException e) { - throw new ParameterException("invalid rollSiteHeader"); - } - String dstPartyId = rollSiteHeader.getDstPartyId(); - if (StringUtils.isEmpty(dstPartyId)) { - dstPartyId = metadata.getDst().getPartyId(); - } - - String desRole = metadata.getDst().getRole(); - String srcRole = metadata.getSrc().getRole(); - String srcPartyId = metadata.getSrc().getPartyId(); - context.setSrcPartyId(srcPartyId); - context.setDesPartyId(dstPartyId); - context.setSrcComponent(srcRole); - context.setDesComponent(desRole); - } - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java b/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java deleted file mode 100644 index 5303e63303..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageDecoder.java +++ /dev/null @@ -1,403 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.osx.broker.message; - -import com.osx.broker.constants.MessageFlag; -import com.osx.broker.util.MessageId; -import com.osx.broker.util.UtilAll; - -import java.net.*; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class MessageDecoder { -// public final static int MSG_ID_LENGTH = 8 + 8; - - public final static Charset CHARSET_UTF8 = Charset.forName("UTF-8"); - public final static int MESSAGE_MAGIC_CODE_POSTION = 4; - public final static int MESSAGE_FLAG_POSTION = 16; - public final static int MESSAGE_PHYSIC_OFFSET_POSTION = 28; - // public final static int MESSAGE_STORE_TIMESTAMP_POSTION = 56; - public final static int MESSAGE_MAGIC_CODE = -626843481; - public static final char NAME_VALUE_SEPARATOR = 1; - public static final char PROPERTY_SEPARATOR = 2; - public static final int PHY_POS_POSITION = 4 + 4 + 4 + 4 + 4 + 8; - public static final int QUEUE_OFFSET_POSITION = 4 + 4 + 4 + 4 + 4; - public static final int SYSFLAG_POSITION = 4 + 4 + 4 + 4 + 4 + 8 + 8; - - - public static String createMessageId(final ByteBuffer input, final ByteBuffer addr, final long offset) { - input.flip(); - int msgIDLength = addr.limit() == 8 ? 16 : 28; - input.limit(msgIDLength); - - input.put(addr); - input.putLong(offset); - - return UtilAll.bytes2string(input.array()); - } - - public static MessageExtBrokerInner buildMessageExtBrokerInner(String topic, byte[] body, - int queueId, MessageFlag flag, String srcPartyId, String desPartyId) { - MessageExtBrokerInner messageExtBrokerInner = new MessageExtBrokerInner(); - messageExtBrokerInner.setQueueId(queueId); - messageExtBrokerInner.setBody(body); - messageExtBrokerInner.setTopic(topic); - messageExtBrokerInner.setFlag(flag.getFlag()); - messageExtBrokerInner.setBornTimestamp(System.currentTimeMillis()); - messageExtBrokerInner.setDesPartyId(srcPartyId); - messageExtBrokerInner.setSrcPartyId(desPartyId); - return messageExtBrokerInner; - } - - public static String createMessageId(SocketAddress socketAddress, long transactionIdhashCode) { - InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; - int msgIDLength = inetSocketAddress.getAddress() instanceof Inet4Address ? 16 : 28; - ByteBuffer byteBuffer = ByteBuffer.allocate(msgIDLength); - byteBuffer.put(inetSocketAddress.getAddress().getAddress()); - byteBuffer.putInt(inetSocketAddress.getPort()); - byteBuffer.putLong(transactionIdhashCode); - byteBuffer.flip(); - return UtilAll.bytes2string(byteBuffer.array()); - } - - public static MessageId decodeMessageId(final String msgId) throws UnknownHostException { - SocketAddress address; - long offset; - int ipLength = msgId.length() == 32 ? 4 * 2 : 16 * 2; - - byte[] ip = UtilAll.string2bytes(msgId.substring(0, ipLength)); - byte[] port = UtilAll.string2bytes(msgId.substring(ipLength, ipLength + 8)); - ByteBuffer bb = ByteBuffer.wrap(port); - int portInt = bb.getInt(0); - address = new InetSocketAddress(InetAddress.getByAddress(ip), portInt); - - // offset - byte[] data = UtilAll.string2bytes(msgId.substring(ipLength + 8, ipLength + 8 + 16)); - bb = ByteBuffer.wrap(data); - offset = bb.getLong(0); - - return new MessageId(address, offset); - } - - /** - * Just decode properties from msg buffer. - * - * @param byteBuffer msg commit log buffer. 
- */ - public static Map decodeProperties(ByteBuffer byteBuffer) { - int sysFlag = byteBuffer.getInt(SYSFLAG_POSITION); - int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; - int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 8 : 20; - int bodySizePosition = 4 // 1 TOTALSIZE - + 4 // 2 MAGICCODE - + 4 // 3 BODYCRC - + 4 // 4 QUEUEID - + 4 // 5 FLAG - + 8 // 6 QUEUEOFFSET - + 8 // 7 PHYSICALOFFSET - + 4 // 8 SYSFLAG - + 8 // 9 BORNTIMESTAMP - + bornhostLength // 10 BORNHOST - + 8 // 11 STORETIMESTAMP - + storehostAddressLength // 12 STOREHOSTADDRESS - + 4 // 13 RECONSUMETIMES - + 8; // 14 Prepared Transaction Offset - int topicLengthPosition = bodySizePosition + 4 + byteBuffer.getInt(bodySizePosition); - - byte topicLength = byteBuffer.get(topicLengthPosition); - - short propertiesLength = byteBuffer.getShort(topicLengthPosition + 1 + topicLength); - - byteBuffer.position(topicLengthPosition + 1 + topicLength + 2); - - if (propertiesLength > 0) { - byte[] properties = new byte[propertiesLength]; - byteBuffer.get(properties); - String propertiesString = new String(properties, CHARSET_UTF8); - Map map = string2messageProperties(propertiesString); - return map; - } - return null; - } - - public static MessageExt decode(ByteBuffer byteBuffer) { - return decode(byteBuffer, true, true, false); - } - - public static MessageExt clientDecode(ByteBuffer byteBuffer, final boolean readBody) { - return decode(byteBuffer, readBody, true, true); - } - - public static MessageExt decode(ByteBuffer byteBuffer, final boolean readBody) { - return decode(byteBuffer, readBody, true, false); - } - - public static MessageExt decode( - ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody) { - return decode(byteBuffer, readBody, deCompressBody, false); - } - - public static MessageExt decode( - ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody, final boolean isClient) { - try { - - MessageExt msgExt= new MessageExt(); - // 1 TOTALSIZE - int storeSize = byteBuffer.getInt(); - msgExt.setStoreSize(storeSize); - - // 2 MAGICCODE - byteBuffer.getInt(); - - // 3 BODYCRC - int bodyCRC = byteBuffer.getInt(); - msgExt.setBodyCRC(bodyCRC); - - // 4 QUEUEID - int queueId = byteBuffer.getInt(); - msgExt.setQueueId(queueId); - - // 5 FLAG - int flag = byteBuffer.getInt(); - msgExt.setFlag(flag); - - // 6 QUEUEOFFSET - int srcPartyIdLength = byteBuffer.get(); - if (srcPartyIdLength > 0) { - byte[] srcPartyBytes = new byte[srcPartyIdLength]; - byteBuffer.get(srcPartyBytes); - String srcPartyId = new String(srcPartyBytes); - msgExt.setSrcPartyId(srcPartyId); - } - -// long queueOffset = byteBuffer.getLong(); -// msgExt.setQueueOffset(queueOffset); - - // 7 PHYSICALOFFSET -// long physicOffset = byteBuffer.getLong(); -// msgExt.setCommitLogOffset(physicOffset); - - - int desPartyIdLength = byteBuffer.get(); - if (desPartyIdLength > 0) { - byte[] desPartyIdBytes = new byte[desPartyIdLength]; - byteBuffer.get(desPartyIdBytes); - String desPartyId = new String(desPartyIdBytes); - msgExt.setDesPartyId(desPartyId); - } - - - // 8 SYSFLAG - int sysFlag = byteBuffer.getInt(); - msgExt.setSysFlag(sysFlag); - - // 9 BORNTIMESTAMP - long bornTimeStamp = byteBuffer.getLong(); - msgExt.setBornTimestamp(bornTimeStamp); - - - // 15 BODY - int bodyLen = byteBuffer.getInt(); - if (bodyLen > 0) { - if (readBody) { - byte[] body = new byte[bodyLen]; - byteBuffer.get(body); - msgExt.setBody(body); - } else { - 
byteBuffer.position(byteBuffer.position() + bodyLen); - } - } - - // 16 TOPIC - short topicLen = byteBuffer.getShort(); - byte[] topic = new byte[(int) topicLen]; - byteBuffer.get(topic); - msgExt.setTopic(new String(topic, CHARSET_UTF8)); - - // 17 properties - short propertiesLength = byteBuffer.getShort(); - if (propertiesLength > 0) { - byte[] properties = new byte[propertiesLength]; - byteBuffer.get(properties); - String propertiesString = new String(properties, CHARSET_UTF8); - Map map = string2messageProperties(propertiesString); - msgExt.setProperties(map); - } - - return msgExt; - } catch (Exception e) { - e.printStackTrace(); - byteBuffer.position(byteBuffer.limit()); - } - - return null; - } - - public static List decodes(ByteBuffer byteBuffer) { - return decodes(byteBuffer, true); - } - - public static List decodes(ByteBuffer byteBuffer, final boolean readBody) { - List msgExts = new ArrayList(); - while (byteBuffer.hasRemaining()) { - MessageExt msgExt = clientDecode(byteBuffer, readBody); - if (null != msgExt) { - msgExts.add(msgExt); - } else { - break; - } - } - return msgExts; - } - - public static String messageProperties2String(Map properties) { - StringBuilder sb = new StringBuilder(); - if (properties != null) { - for (final Map.Entry entry : properties.entrySet()) { - final String name = entry.getKey(); - final String value = entry.getValue(); - - if (value == null) { - continue; - } - sb.append(name); - sb.append(NAME_VALUE_SEPARATOR); - sb.append(value); - sb.append(PROPERTY_SEPARATOR); - } - } - return sb.toString(); - } - - public static Map string2messageProperties(final String properties) { - Map map = new HashMap(); - if (properties != null) { - String[] items = properties.split(String.valueOf(PROPERTY_SEPARATOR)); - for (String i : items) { - String[] nv = i.split(String.valueOf(NAME_VALUE_SEPARATOR)); - if (2 == nv.length) { - map.put(nv[0], nv[1]); - } - } - } - - return map; - } - - public static byte[] encodeMessage(Message message) { - //only need flag, body, properties - byte[] body = message.getBody(); - int bodyLen = body.length; - String properties = messageProperties2String(message.getProperties()); - byte[] propertiesBytes = properties.getBytes(CHARSET_UTF8); - //note properties length must not more than Short.MAX - short propertiesLength = (short) propertiesBytes.length; - int sysFlag = message.getFlag(); - int storeSize = 4 // 1 TOTALSIZE - + 4 // 2 MAGICCOD - + 4 // 3 BODYCRC - + 4 // 4 FLAG - + 4 + bodyLen // 4 BODY - + 2 + propertiesLength; - ByteBuffer byteBuffer = ByteBuffer.allocate(storeSize); - // 1 TOTALSIZE - byteBuffer.putInt(storeSize); - - // 2 MAGICCODE - byteBuffer.putInt(0); - - // 3 BODYCRC - byteBuffer.putInt(0); - - // 4 FLAG - int flag = message.getFlag(); - byteBuffer.putInt(flag); - - // 5 BODY - byteBuffer.putInt(bodyLen); - byteBuffer.put(body); - - // 6 properties - byteBuffer.putShort(propertiesLength); - byteBuffer.put(propertiesBytes); - - return byteBuffer.array(); - } - - public static Message decodeMessage(ByteBuffer byteBuffer) throws Exception { - Message message = new Message(); - - // 1 TOTALSIZE - byteBuffer.getInt(); - - // 2 MAGICCODE - byteBuffer.getInt(); - - // 3 BODYCRC - byteBuffer.getInt(); - - // 4 FLAG - int flag = byteBuffer.getInt(); - message.setFlag(flag); - - // 5 BODY - int bodyLen = byteBuffer.getInt(); - byte[] body = new byte[bodyLen]; - byteBuffer.get(body); - message.setBody(body); - - // 6 properties - short propertiesLen = byteBuffer.getShort(); - byte[] propertiesBytes = new 
byte[propertiesLen]; - byteBuffer.get(propertiesBytes); - message.setProperties(string2messageProperties(new String(propertiesBytes, CHARSET_UTF8))); - - return message; - } - - public static byte[] encodeMessages(List messages) { - //TO DO refactor, accumulate in one buffer, avoid copies - List encodedMessages = new ArrayList(messages.size()); - int allSize = 0; - for (Message message : messages) { - byte[] tmp = encodeMessage(message); - encodedMessages.add(tmp); - allSize += tmp.length; - } - byte[] allBytes = new byte[allSize]; - int pos = 0; - for (byte[] bytes : encodedMessages) { - System.arraycopy(bytes, 0, allBytes, pos, bytes.length); - pos += bytes.length; - } - return allBytes; - } - - public static List decodeMessages(ByteBuffer byteBuffer) throws Exception { - //TO DO add a callback for processing, avoid creating lists - List msgs = new ArrayList(); - while (byteBuffer.hasRemaining()) { - Message msg = decodeMessage(byteBuffer); - msgs.add(msg); - } - return msgs; - } -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java deleted file mode 100644 index 72f3704d11..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTokenApplyService.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.osx.broker.ptp; - -import com.google.protobuf.ByteString; -import com.osx.broker.ServiceContainer; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.flow.*; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import com.osx.core.token.TokenRequest; -import com.osx.core.token.TokenResult; -import com.osx.core.token.TokenResultStatus; -import com.osx.core.utils.JsonUtil; -import io.grpc.ManagedChannel; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.nio.charset.StandardCharsets; - -public class PtpClusterTokenApplyService extends AbstractPtpServiceAdaptor { - - Logger logger = LoggerFactory.getLogger(PtpClusterTokenApplyService.class); - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - context.setActionType(ActionType.CLUSTER_TOKEN_APPLY.getAlias()); - Osx.Inbound inbound = data.getBody(); - byte[] temp = inbound.getPayload().toByteArray(); - TokenRequest tokenRequest = JsonUtil.json2Object(temp, TokenRequest.class); - TokenResult tokenResult = ServiceContainer.defaultTokenService.requestToken(tokenRequest.getResource(),tokenRequest.getAcquireCount(),tokenRequest.isPrioritized()); - Osx.Outbound.Builder resultBuilder = Osx.Outbound.newBuilder(); - resultBuilder.setPayload(ByteString.copyFrom(JsonUtil.object2Json(tokenResult).getBytes(StandardCharsets.UTF_8))); - return resultBuilder.build(); - } -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java deleted file mode 100644 index 4a3183b05a..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpClusterTopicApplyService.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker.ptp; - -import com.osx.broker.ServiceContainer; -import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.service.InboundPackage; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; - - -public class PtpClusterTopicApplyService extends AbstractPtpServiceAdaptor { - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - context.setActionType(ActionType.TOPIC_APPLY.getAlias()); - Osx.Inbound inbound = data.getBody(); - String topic = inbound.getMetadataMap().get(Osx.Metadata.MessageTopic.name()); - String instanceId = inbound.getMetadataMap().get(Osx.Metadata.InstanceId.name()); - String sessionId = inbound.getMetadataMap().get(Osx.Header.SessionID.name()); - if(StringUtils.isEmpty(topic)) - { - throw new ParameterException("topic is null"); - } - if(StringUtils.isEmpty(instanceId)) - { - throw new ParameterException("instanceId is null"); - } - if(StringUtils.isEmpty(sessionId)) - { - throw new ParameterException("sessionId is null"); - } - context.setTopic(topic); - context.setSessionId(sessionId); - Osx.Outbound outbound = ServiceContainer.transferQueueManager.applyFromMaster( topic,sessionId,instanceId); - return outbound; - } - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java deleted file mode 100644 index 7f67d66035..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpConsumeService.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.osx.broker.ptp; - -import com.google.common.base.Preconditions; -import com.osx.broker.ServiceContainer; -import com.osx.broker.consumer.UnaryConsumer; -import com.osx.broker.queue.CreateQueueResult; -import com.osx.broker.queue.TransferQueue; -import com.osx.broker.queue.TransferQueueApplyInfo; -import com.osx.broker.util.TransferUtil; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.exceptions.TransferQueueNotExistException; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import io.grpc.ManagedChannel; -import io.grpc.stub.StreamObserver; -import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class PtpConsumeService extends AbstractPtpServiceAdaptor { - - Logger logger = LoggerFactory.getLogger(PtpConsumeService.class); - - public PtpConsumeService() { - this.setServiceName("consume-unary"); - } - - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - context.setActionType(ActionType.DEFUALT_CONSUME.getAlias()); - Osx.Inbound inbound = data.getBody(); - String topic = context.getTopic(); - TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); - if (transferQueue == null) { - - if(MetaInfo.isCluster()) { - TransferQueueApplyInfo transferQueueApplyInfo = ServiceContainer.transferQueueManager.queryGlobleQueue(topic); - if (transferQueueApplyInfo == null) { - throw new TransferQueueNotExistException(); - } else { - String[] args = transferQueueApplyInfo.getInstanceId().split(":"); - String ip = args[0]; - int port = Integer.parseInt(args[1]); - RouterInfo routerInfo = new RouterInfo(); - routerInfo.setHost(ip); - routerInfo.setPort(port); - context.setRouterInfo(routerInfo); - return redirect(context, inbound); - } - }else{ - /** - * 单机版直接创建队列 - */ - logger.warn("create topic {} by consume request ",topic); - CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue( topic, context.getSessionId(), true); - if(createQueueResult.getTransferQueue()==null){ - throw new TransferQueueNotExistException(); - } - } - } - StreamObserver streamObserver = (StreamObserver) context.getData(Dict.RESPONSE_STREAM_OBSERVER); - Long offset = context.getRequestMsgIndex(); - Preconditions.checkArgument(offset != null); - if(offset==null){ - throw new ParameterException("offset is null"); - } - if (offset > 0) { - context.setActionType(ActionType.CUSTOMER_CONSUME.getAlias()); - } - UnaryConsumer consumer = ServiceContainer.consumerManager.getOrCreateUnaryConsumer(topic); - TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(context, offset); - context.setReturnCode(transferQueueConsumeResult.getCode()); - if (transferQueueConsumeResult.getCode().equals(StatusCode.CONSUME_NO_MESSAGE)) { - /* - * 由其他扫描线程应答 - */ - if (offset < 0) { - UnaryConsumer.LongPullingHold longPullingHold = new UnaryConsumer.LongPullingHold(); - longPullingHold.setNeedOffset(offset); - longPullingHold.setStreamObserver(streamObserver); - longPullingHold.setContext(context.subContext()); - consumer.addLongPullingQueue(longPullingHold); - return null; - } - } - Osx.Outbound consumeResponse = 
TransferUtil.buildResponse(transferQueueConsumeResult.getCode(), "", transferQueueConsumeResult); - return consumeResponse; - - } - - private Osx.Outbound redirect(Context context, Osx.Inbound inbound) { - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true); - context.setActionType(ActionType.REDIRECT_CONSUME.getAlias()); - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - return stub.invoke(inbound); - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java deleted file mode 100644 index d832176a2c..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpProduceService.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker.ptp; - -import com.osx.broker.ServiceContainer; -import com.osx.broker.constants.MessageFlag; -import com.osx.broker.message.MessageDecoder; -import com.osx.broker.message.MessageExtBrokerInner; -import com.osx.broker.queue.CreateQueueResult; -import com.osx.broker.queue.PutMessageResult; -import com.osx.broker.queue.PutMessageStatus; -import com.osx.broker.queue.TransferQueue; -import com.osx.broker.util.TransferUtil; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.DeployMode; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.exceptions.*; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import static com.osx.broker.util.TransferUtil.redirect; - -public class PtpProduceService extends AbstractPtpServiceAdaptor { - - Logger logger = LoggerFactory.getLogger(PtpProduceService.class); - - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - - String topic = context.getTopic(); - boolean isDst = false; - RouterInfo routerInfo = context.getRouterInfo(); - String srcPartyId = context.getSrcPartyId(); - String sessionId = context.getSessionId(); - Osx.Inbound produceRequest = data.getBody(); - if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) { - isDst = true; - } - if (!isDst) { - /** - * 向外转发 - */ - return redirect(context, produceRequest, routerInfo, false); - } else { - /** - * 本地处理 - */ - if (StringUtils.isEmpty(topic)) { - throw new ParameterException(StatusCode.PARAM_ERROR, "topic is null"); - } - if (StringUtils.isEmpty(sessionId)) { - throw new ParameterException(StatusCode.PARAM_ERROR, "sessionId is null"); - } - context.setActionType(ActionType.MSG_DOWNLOAD.getAlias()); - context.setRouterInfo(null); - TransferQueue 
transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); - CreateQueueResult createQueueResult = null; - if( transferQueue==null) { - createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, sessionId, false); - if (createQueueResult == null) { - throw new CreateTopicErrorException("create topic " + topic + " error"); - } - transferQueue = createQueueResult.getTransferQueue(); - } - String resource = TransferUtil.buildResource(produceRequest); - int dataSize = produceRequest.getSerializedSize(); - ServiceContainer.tokenApplyService.applyToken(context,resource,dataSize); - ServiceContainer.flowCounterManager.pass(resource,dataSize); - if (transferQueue != null) { - byte[] msgBytes = produceRequest.getPayload().toByteArray(); - MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(topic, msgBytes, 0, MessageFlag.MSG, context.getSrcPartyId(), - context.getDesPartyId()); - PutMessageResult putMessageResult = transferQueue.putMessage(messageExtBrokerInner); - if (putMessageResult.getPutMessageStatus() != PutMessageStatus.PUT_OK) { - throw new PutMessageException("put status " + putMessageResult.getPutMessageStatus()); - } - long logicOffset = putMessageResult.getMsgLogicOffset(); - context.setCurrentMsgIndex(logicOffset); - Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); - outBoundBuilder.setCode(StatusCode.SUCCESS); - outBoundBuilder.setMessage(Dict.SUCCESS); - return outBoundBuilder.build(); - } else { - /** - * 集群内转发 - */ - if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name())) { - RouterInfo redirectRouterInfo = new RouterInfo(); - String redirectIp = createQueueResult.getRedirectIp(); - int redirectPort = createQueueResult.getPort(); - if (StringUtils.isEmpty(redirectIp) || redirectPort == 0) { - logger.error("invalid redirect info {}:{}", redirectIp, redirectPort); - throw new InvalidRedirectInfoException(); - } - redirectRouterInfo.setHost(redirectIp); - redirectRouterInfo.setPort(redirectPort); - context.setRouterInfo(redirectRouterInfo); - context.setActionType(ActionType.INNER_REDIRECT.getAlias()); - return redirect(context, produceRequest, redirectRouterInfo, true); - } else { - logger.error("create topic {} error", topic); - throw new ProduceMsgExcption(); - } - } - } - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java b/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java deleted file mode 100644 index 60136ab9c3..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpUnaryCallService.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
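Editor's note on the PtpProduceService removed above: its core branch is a local-versus-forward decision, writing to a local transfer queue when the destination party id belongs to this broker and redirecting along the routed connection otherwise. A reduced, self-contained sketch of that branch; SELF_PARTY_IDS, LOCAL_QUEUE and produce() are hypothetical stand-ins for MetaInfo.PROPERTY_SELF_PARTY, TransferQueue and TransferUtil.redirect, and the token/flow-counter accounting is omitted.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Set;

public final class ProduceDispatchSketch {

    static final Set<String> SELF_PARTY_IDS = Set.of("10000");
    static final Deque<byte[]> LOCAL_QUEUE = new ArrayDeque<>();

    /** Returns "local" or "forwarded" so the example is easy to trace. */
    static String produce(String desPartyId, byte[] payload) {
        if (SELF_PARTY_IDS.contains(desPartyId)) {
            LOCAL_QUEUE.addLast(payload);   // original: transferQueue.putMessage(...)
            return "local";
        }
        // original: redirect(context, produceRequest, routerInfo, false)
        return "forwarded";
    }

    public static void main(String[] args) {
        System.out.println(produce("10000", "hello".getBytes())); // local
        System.out.println(produce("9999", "hello".getBytes()));  // forwarded
    }
}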
- */ -package com.osx.broker.ptp; - -import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import io.grpc.ManagedChannel; -import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -public class PtpUnaryCallService extends AbstractPtpServiceAdaptor { - - Logger logger = LoggerFactory.getLogger(PtpUnaryCallService.class); - @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { - context.setActionType(ActionType.UNARY_CALL_NEW.getAlias()); - RouterInfo routerInfo = context.getRouterInfo(); - Osx.Inbound inbound = data.getBody(); - String host = routerInfo.getHost(); - Integer port = routerInfo.getPort(); - ManagedChannel managedChannel=GrpcConnectionFactory.createManagedChannel(routerInfo,true); - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - Osx.Outbound outbound= null; - try { - outbound = blockingStub.invoke(inbound); - }catch(io.grpc.StatusRuntimeException e){ - logger.error("remote rpc error :router info {}",routerInfo); - throw new RemoteRpcException("remote rpc error"); - } - return outbound; - } - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java b/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java deleted file mode 100644 index e09f9ea1ec..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/router/DefaultFateRouterServiceImpl.java +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
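Editor's note on the PtpUnaryCallService removed above: it forwards the inbound request to the routed host with a blocking stub and translates transport failures into a broker-level exception. The self-contained sketch below assumes the generated org.ppc.ptp classes from this repository's proto are on the classpath; unlike the original, which reuses pooled channels from GrpcConnectionFactory and negotiates TLS per route entry, it opens a plaintext channel per call to stay minimal.

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import org.ppc.ptp.Osx;
import org.ppc.ptp.PrivateTransferProtocolGrpc;

public final class UnaryForwardSketch {

    static Osx.Outbound forward(String host, int port, Osx.Inbound inbound) {
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port)
                .usePlaintext()
                .build();
        try {
            return PrivateTransferProtocolGrpc.newBlockingStub(channel).invoke(inbound);
        } catch (StatusRuntimeException e) {
            throw new RuntimeException("remote rpc error " + host + ":" + port, e);
        } finally {
            channel.shutdownNow();
        }
    }

    public static void main(String[] args) {
        Osx.Inbound inbound = Osx.Inbound.newBuilder().build();
        // forward("localhost", 9370, inbound) would require a running broker; omitted here
        System.out.println(inbound.getSerializedSize());
    }
}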
- */ -package com.osx.broker.router; -import com.google.common.base.Preconditions; -import com.google.common.collect.Maps; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.core.constant.NegotiationType; -import com.osx.core.datasource.FileRefreshableDataSource; -import com.osx.core.flow.PropertyListener; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.JsonUtil; -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import com.webank.eggroll.core.transfer.Transfer; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.FileNotFoundException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -public class DefaultFateRouterServiceImpl implements FateRouterService { - - private static final String IP = "ip"; - private static final String PORT = "port"; - private static final String URL = "url"; - private static final String USE_SSL = "useSSL"; - private static final String HOSTNAME = "hostname"; - private static final String negotiationType = "negotiationType"; - private static final String certChainFile = "certChainFile"; - private static final String privateKeyFile = "privateKeyFile"; - private static final String caFile = "caFile"; - private static final String DEFAULT = "default"; - private static final String VERSION = "version"; - Logger logger = LoggerFactory.getLogger(DefaultFateRouterServiceImpl.class); - Map> routerInfoMap = new ConcurrentHashMap>(); - Map>> endPointMap = new ConcurrentHashMap<>(); - FileRefreshableDataSource fileRefreshableDataSource; - - @Override - public RouterInfo route(Proxy.Packet packet) { - Preconditions.checkArgument(packet != null); - RouterInfo routerInfo = null; - Proxy.Metadata metadata = packet.getHeader(); - Transfer.RollSiteHeader rollSiteHeader = null; - try { - rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt()); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - } - String dstPartyId = rollSiteHeader.getDstPartyId(); - - if (StringUtils.isEmpty(dstPartyId)) { - dstPartyId = metadata.getDst().getPartyId(); - } - dstPartyId = metadata.getDst().getPartyId(); - String desRole = metadata.getDst().getRole(); - String srcRole = metadata.getSrc().getRole(); - String srcPartyId = metadata.getSrc().getPartyId(); - routerInfo = this.route(srcPartyId, srcRole, dstPartyId, desRole); - //logger.info("query router info {} to {} {} return {}", srcPartyId, dstPartyId, desRole, routerInfo); - return routerInfo; - } - - - public RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole) { - RouterInfo routerInfo = null; - Map> partyIdMap = this.endPointMap.get(dstPartyId); - if (partyIdMap != null) { - - if (StringUtils.isNotEmpty(desRole)&&partyIdMap.get(desRole) != null) { - List ips = partyIdMap.getOrDefault(desRole, null); - if (ips != null && ips.size() > 0) { - Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); - routerInfo = new RouterInfo(); - routerInfo.setHost(endpoint.get(IP).toString()); - routerInfo.setPort(((Number) endpoint.get(PORT)).intValue()); - routerInfo.setDesPartyId(dstPartyId); - routerInfo.setSourcePartyId(srcPartyId); - routerInfo.setVersion(endpoint.get(VERSION) != null ? 
endpoint.get(VERSION).toString() : null); - routerInfo.setNegotiationType(endpoint.get(negotiationType)!=null?endpoint.get(negotiationType).toString():""); - } - } else { - - List ips = partyIdMap.getOrDefault(DEFAULT, null); - if (ips != null && ips.size() > 0) { - Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); - routerInfo = new RouterInfo(); - routerInfo.setHost(endpoint.get(IP).toString()); - routerInfo.setPort(((Number) endpoint.get(PORT)).intValue()); - routerInfo.setDesPartyId(dstPartyId); - routerInfo.setSourcePartyId(srcPartyId); - routerInfo.setVersion(endpoint.get(VERSION) != null ? endpoint.get(VERSION).toString() : null); - routerInfo.setNegotiationType(endpoint.get(negotiationType)!=null?endpoint.get(negotiationType).toString():""); - } - if(StringUtils.isNotEmpty(desRole)){ - logger.warn("role {} is not found,return default router info ",desRole); - } - } - } - return routerInfo; - } - - - - Map>> initRouteTable(Map confJson) { - // BasicMeta.Endpoint.Builder endpointBuilder = BasicMeta.Endpoint.newBuilder(); - Map>> newRouteTable = new ConcurrentHashMap<>(); - // loop through coordinator - - confJson.forEach((k,v)->{ - String coordinatorKey = k.toString(); - Map coordinatorValue = (Map)v; - - Map> serviceTable = newRouteTable.get(coordinatorKey); - if (serviceTable == null) { - serviceTable = new ConcurrentHashMap<>(4); - newRouteTable.put(coordinatorKey, serviceTable); - } - // loop through role in coordinator - for (Object roleEntryObject : coordinatorValue.entrySet()) { - Map.Entry roleEntry = (Map.Entry)roleEntryObject; - String roleKey = roleEntry.getKey().toString(); - if (roleKey.equals("createTime") || roleKey.equals("updateTime")) { - continue; - } - List roleValue = (List)roleEntry.getValue(); - - List endpoints = serviceTable.get(roleKey); - if (endpoints == null) { - endpoints = new ArrayList<>(); - serviceTable.put(roleKey, endpoints); - } - - // loop through endpoints - for (Object endpointElement : roleValue) { - - Map element = Maps.newHashMap(); - - Map endpointJson = (Map)endpointElement; - - if (endpointJson.get(IP)!=null) { - String targetIp = endpointJson.get(IP).toString(); - element.put(IP, targetIp); - } - - if (endpointJson.get(PORT)!=null) { - int targetPort = Integer.parseInt(endpointJson.get(PORT).toString()); - element.put(PORT, targetPort); - } -// if(endpointJson.has(URL)){ -// String url = endpointJson.get(URL).getAsString(); -// endpointBuilder.setUrl(url); -// } - - if (endpointJson.get(USE_SSL)!=null) { - boolean targetUseSSL = Boolean.getBoolean(endpointJson.get(USE_SSL).toString()); - element.put(USE_SSL, targetUseSSL); - } - - if (endpointJson.get(HOSTNAME)!=null) { - String targetHostname = endpointJson.get(HOSTNAME).toString(); - element.put(HOSTNAME, targetHostname); - } - - if (endpointJson.get(negotiationType)!=null) { - String targetNegotiationType = endpointJson.get(negotiationType).toString(); - element.put(negotiationType, targetNegotiationType); - }else{ - element.put(negotiationType, NegotiationType.PLAINTEXT); - } - - if (endpointJson.get(certChainFile)!=null) { - String targetCertChainFile = endpointJson.get(certChainFile).toString(); - element.put(certChainFile, targetCertChainFile); - } - - if (endpointJson.get(privateKeyFile)!=null) { - String targetPrivateKeyFile = endpointJson.get(privateKeyFile).toString(); - element.put(privateKeyFile, targetPrivateKeyFile); - } - - if (endpointJson.get(caFile)!=null) { - String targetCaFile = endpointJson.get(caFile).toString(); - element.put(caFile, 
targetCaFile); - } - if (endpointJson.get(VERSION)!=null) { - String targetVersion = endpointJson.get(VERSION).toString(); - element.put(VERSION, targetVersion); - } - - //BasicMeta.Endpoint endpoint = endpointBuilder.build(); - endpoints.add(element); - } - } - - }); - - return newRouteTable; - } - - public void start() { - String currentPath = Thread.currentThread().getContextClassLoader().getResource("route_table.json").getPath(); - logger.info("load router file {}", currentPath); - File confFile = new File(currentPath); - FileRefreshableDataSource fileRefreshableDataSource = null; - try { - fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { - logger.info("read route_table {}", source); - return source; - }); - fileRefreshableDataSource.getProperty().addListener(new RouterTableListener()); - - } catch (FileNotFoundException e) { - logger.error("router file {} is not found", currentPath); - } - } - - private class RouterTableListener implements PropertyListener { - - @Override - public void configUpdate(String value) { - // logger.info("fire router table update {}",value); - Map confJson = JsonUtil.json2Object(value,Map.class); - // JsonObject confJson = JsonParser.parseString(value).getAsJsonObject(); - Map content =(Map) confJson.get("route_table"); - endPointMap = initRouteTable(content); - } - - @Override - public void configLoad(String value) { - Map confJson = JsonUtil.json2Object(value,Map.class); - Map content =(Map) confJson.get("route_table"); - endPointMap = initRouteTable(content); - logger.info("load router config {}", JsonUtil.formatJson(JsonUtil.object2Json(endPointMap))); - } - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java b/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java deleted file mode 100644 index b78a399b6f..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/server/OsxServer.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
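Editor's note on the DefaultFateRouterServiceImpl removed above: the route table maps destination party id to role to a list of endpoints, a role-specific list is preferred, the "default" list is the fallback, and an endpoint is picked with a time-based index, which amounts to coarse round-robin across the list. A self-contained sketch of that lookup (Java 17); the Endpoint record and the table constants are illustrative, not values from the original route_table.json.

import java.util.List;
import java.util.Map;

public final class RouteTableSketch {

    record Endpoint(String ip, int port) {}

    static final Map<String, Map<String, List<Endpoint>>> TABLE = Map.of(
            "10000", Map.of(
                    "fateflow", List.of(new Endpoint("192.168.0.1", 9360)),
                    "default", List.of(new Endpoint("192.168.0.1", 9370),
                                       new Endpoint("192.168.0.2", 9370))));

    static Endpoint route(String dstPartyId, String desRole) {
        Map<String, List<Endpoint>> roles = TABLE.get(dstPartyId);
        if (roles == null) {
            return null;                                    // unknown party
        }
        List<Endpoint> candidates = roles.get(desRole);
        if (candidates == null || candidates.isEmpty()) {
            candidates = roles.get("default");              // fall back to the default role
        }
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }
        return candidates.get((int) (System.currentTimeMillis() % candidates.size()));
    }

    public static void main(String[] args) {
        System.out.println(route("10000", "fateflow"));     // role-specific endpoint
        System.out.println(route("10000", "unknown-role")); // one of the default endpoints
    }
}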
- */ -package com.osx.broker.server; -import com.osx.broker.ServiceContainer; -import com.osx.broker.grpc.ContextPrepareInterceptor; -import com.osx.broker.grpc.ServiceExceptionHandler; -import com.osx.broker.http.DispatchServlet; -import com.osx.core.config.MetaInfo; -import io.grpc.ServerBuilder; -import io.grpc.ServerInterceptors; -import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; -import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; -import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; -import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.server.*; -import org.eclipse.jetty.servlet.ServletContextHandler; -import org.eclipse.jetty.servlet.ServletHolder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.net.ssl.SSLException; -import java.io.File; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; - -import static com.osx.core.config.MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER; - -/** - * http1.X + grpc - */ -public class OsxServer { - - - Logger logger = LoggerFactory.getLogger(OsxServer.class); - io.grpc.Server server; - io.grpc.Server tlsServer; - org.eclipse.jetty.server.Server httpServer; - - private void init() { - server = buildServer(); - if(MetaInfo.PROPERTY_OPEN_HTTP_SERVER) { - httpServer = buildHttpServer(); - } - - // tlsServer = buildTlsServer(); - } - - public Server buildHttpServer(){ - Server server = new Server(); - try { - int acceptors = 1; - int selectors = 1; - ServerConnector connector = new ServerConnector(server, acceptors, selectors, new HttpConnectionFactory()); - // logger.info("http server try to start listen port {}", MetaInfo.PROPERTY_HTTP_PORT); - connector.setPort(MetaInfo.PROPERTY_HTTP_PORT); - connector.setHost("127.0.0.1"); - connector.setAcceptQueueSize(128); - server.addConnector(connector); - server.setHandler(buildServlet()); - return server; - } catch (Exception e) { - logger.error("build http server error",e); - } - return null; - } - - ServletContextHandler buildServlet(){ - ServletContextHandler context = new ServletContextHandler(); - context.setContextPath(MetaInfo.PROPERTY_HTTP_CONTEXT_PATH); - ServletHolder servletHolder = context.addServlet(DispatchServlet.class, MetaInfo.PROPERTY_HTTP_SERVLET_PATH);//"/*" - return context; - } - - - - public boolean start() { - - init(); - try { - server.start(); - logger.info("listen grpc port {} success", MetaInfo.PROPERTY_GRPC_PORT); - } catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_PORT); - } - return false; - } - try{ - if(httpServer!=null){ - - httpServer.start(); - logger.info("listen http port {} success", MetaInfo.PROPERTY_HTTP_PORT); - } - } - catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_PORT); - } - return false; - } - try{ - if (tlsServer != null) { - logger.info("grpc tls server try to start, listen port {}", MetaInfo.PROPERTY_GRPC_TLS_PORT); - tlsServer.start(); - logger.info("listen grpc tls port {} success", MetaInfo.PROPERTY_GRPC_TLS_PORT); - } - } catch (Exception e) { - if (e instanceof java.net.BindException||e.getCause() 
instanceof java.net.BindException) { - logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_TLS_PORT); - } - return false; - } - return true; - } - - private io.grpc.Server buildTlsServer(){ - String certChainFilePath = MetaInfo.PROPERTY_SERVER_CERTCHAIN_FILE; - String privateKeyFilePath = MetaInfo.PROPERTY_SERVER_PRIVATEKEY_FILE; - String trustCertCollectionFilePath = MetaInfo.PROPERTY_SERVER_CA_FILE; - - if(PROPERTY_OPEN_GRPC_TLS_SERVER && StringUtils.isNotBlank(certChainFilePath) - && StringUtils.isNotBlank(privateKeyFilePath) && StringUtils.isNotBlank(trustCertCollectionFilePath)) { - try { - int port = MetaInfo.PROPERTY_GRPC_TLS_PORT; - NettyServerBuilder serverBuilder = (NettyServerBuilder) ServerBuilder.forPort(port); - SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(new File(certChainFilePath), new File(privateKeyFilePath)) - .trustManager(new File(trustCertCollectionFilePath)) - .clientAuth(ClientAuth.REQUIRE) - .sessionTimeout(3600 << 4) - .sessionCacheSize(65536); - GrpcSslContexts.configure(sslContextBuilder, SslProvider.OPENSSL); - serverBuilder.sslContext(sslContextBuilder.build()); - logger.info("running in secure mode. server crt path: {}, server key path: {}, ca crt path: {}.", - certChainFilePath, privateKeyFilePath, trustCertCollectionFilePath); - //serverBuilder.executor(executor); - serverBuilder.addService(ServerInterceptors.intercept(ServiceContainer.proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - serverBuilder.addService(ServerInterceptors.intercept(ServiceContainer.pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - return serverBuilder.build(); - } catch (SSLException e) { - throw new SecurityException(e); - } - - - } - return null; - } - - - private io.grpc.Server buildServer() { - NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) ServerBuilder.forPort(MetaInfo.PROPERTY_GRPC_PORT); - nettyServerBuilder.addService(ServerInterceptors.intercept(ServiceContainer.proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - nettyServerBuilder.addService(ServerInterceptors.intercept(ServiceContainer.pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - nettyServerBuilder - .executor(Executors.newCachedThreadPool()) - .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) - .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) - .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) - .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); - - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) - nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) - nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { - - nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - } - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) - nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) - 
nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) - nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) - nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); - return nettyServerBuilder.build(); - } -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java b/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java deleted file mode 100644 index 3c650bd18b..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/service/UnaryCallService.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker.service; -import com.osx.core.constant.ActionType; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.service.AbstractServiceAdaptor; -import com.osx.core.service.InboundPackage; -import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import io.grpc.Deadline; -import io.grpc.ManagedChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * 用于兼容旧版FATE - */ -public class UnaryCallService extends AbstractServiceAdaptor { - - Logger logger = LoggerFactory.getLogger(UnaryCallService.class); - - - public UnaryCallService() { - - } - - @Override - protected Proxy.Packet doService(Context context, InboundPackage data) { - context.setActionType(ActionType.UNARY_CALL.getAlias()); - Proxy.Packet req = (Proxy.Packet) data.getBody(); - Proxy.Packet resp = unaryCall(context, req); - //logger.info("uncary req {} resp {}", req, resp); - return resp; - } - - - protected Proxy.Packet transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { - return null; - } - - /** - * 非流式传输 - * - * @param context - * @param - */ - public Proxy.Packet unaryCall(Context context, Proxy.Packet req) { - Deadline endDeadline = null; - boolean isPolling = false; - - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true); - DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel); - Proxy.Packet result = null; - result = stub.unaryCall(req); - return result; - } - - -} diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java b/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java deleted file mode 100644 index c757c40f64..0000000000 --- a/java/osx/broker/src/main/java/com/osx/broker/util/TransferUtil.java +++ /dev/null @@ -1,295 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. 
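Editor's note on the OsxServer removed above: buildServer() sets message-size and flow-control limits unconditionally and applies keep-alive and connection-age settings only when the corresponding property is positive. The sketch below reproduces that conditional-tuning pattern with a couple of representative settings; the literal values mirror the MetaInfo defaults, and the proxy/PCP service registration via ServerInterceptors.intercept(...) is omitted so the example stays self-contained.

import io.grpc.Server;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public final class GrpcServerTuningSketch {

    static Server build(int port, int keepAliveTimeSec, int maxConnectionIdleSec) {
        NettyServerBuilder builder = NettyServerBuilder.forPort(port)
                .executor(Executors.newCachedThreadPool())
                .maxConcurrentCallsPerConnection(1000)
                .maxInboundMessageSize((2 << 30) - 1)   // evaluates to Integer.MAX_VALUE
                .maxInboundMetadataSize(128 << 20)
                .flowControlWindow(128 << 20);
        if (keepAliveTimeSec > 0) {
            builder.keepAliveTime(keepAliveTimeSec, TimeUnit.SECONDS);
        }
        if (maxConnectionIdleSec > 0) {
            builder.maxConnectionIdle(maxConnectionIdleSec, TimeUnit.SECONDS);
        }
        return builder.build();
    }

    public static void main(String[] args) throws Exception {
        Server server = build(9370, 7200, 86400);
        server.start();
        System.out.println("listening on " + server.getPort());
        server.shutdownNow();
    }
}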
All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.broker.util; - - -import com.google.common.collect.Maps; -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.broker.eggroll.ErRollSiteHeader; -import com.osx.broker.http.HttpClientPool; -import com.osx.broker.http.PtpHttpResponse; -import com.osx.broker.queue.TransferQueue; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.Protocol; -import com.osx.core.constant.PtpHttpHeader; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ConfigErrorException; -import com.osx.core.exceptions.NoRouterInfoException; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.router.RouterInfo; -import com.webank.ai.eggroll.api.networking.proxy.Proxy; -import com.webank.eggroll.core.transfer.Transfer; -import io.grpc.ManagedChannel; -import io.grpc.StatusRuntimeException; -import org.apache.commons.lang3.StringUtils; -import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; - -import javax.servlet.http.HttpServletRequest; -import java.util.Map; - -public class TransferUtil { - - /** - * 2.0之前版本 - * - * @param version - * @return - */ - public static boolean isOldVersionFate(String version) { - - try{ - if (StringUtils.isEmpty(version)) - version= MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION; - String firstVersion = version.substring(0,1); - if (Integer.parseInt(firstVersion) >= 2) { - return false; - } else { - return true; - } - }catch(NumberFormatException e){ - throw new ConfigErrorException("remote version config error : "+version); - } - - } - - - public static String buildResource(Osx.Inbound inbound){ - String sourceNodeId = inbound.getMetadataMap().get(Osx.Header.SourceNodeID.name()); - String targetNodeId = inbound.getMetadataMap().get(Osx.Header.TargetNodeID.name()); - String sourceInstId = inbound.getMetadataMap().get(Osx.Header.SourceInstID.name()); - if(sourceInstId==null){ - sourceInstId=""; - } - String targetInstId = inbound.getMetadataMap().get(Osx.Header.TargetInstID.name()); - if(targetInstId==null){ - targetInstId=""; - } - StringBuffer sb = new StringBuffer(); - sb.append(sourceInstId).append(sourceNodeId).append("_").append(targetInstId).append(targetNodeId); - return sb.toString(); - } - - public static Proxy.Metadata buildProxyMetadataFromOutbound(Osx.Outbound outbound) { - try { - return Proxy.Metadata.parseFrom(outbound.getPayload()); - } catch (InvalidProtocolBufferException e) { - - } - return null; - } - public static Osx.Outbound buildOutboundFromProxyMetadata(Proxy.Metadata metadata) { - return Osx.Outbound.newBuilder().setPayload(metadata.toByteString()).build(); - - } - - public static Proxy.Packet parsePacketFromInbound(Osx.Inbound inbound){ - try { - return Proxy.Packet.parseFrom(inbound.getPayload()); - } catch 
(InvalidProtocolBufferException e) { - return null; - } - } - - public static Osx.Inbound buildInboundFromPushingPacket(Proxy.Packet packet, String targetMethod) { - Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); - Proxy.Topic srcTopic = packet.getHeader().getSrc(); - String srcPartyId = srcTopic.getPartyId(); - Proxy.Metadata metadata = packet.getHeader(); - ByteString encodedRollSiteHeader = metadata.getExt(); - ErRollSiteHeader rsHeader = null; - try { - rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader)); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - } - - String sessionId = ""; - if (rsHeader != null) { - sessionId = String.join("_", rsHeader.getRollSiteSessionId(), rsHeader.getDstRole(), rsHeader.getDstPartyId()); - } - Proxy.Topic desTopic = packet.getHeader().getDst(); - String desPartyId = desTopic.getPartyId(); - String desRole = desTopic.getRole(); - inboundBuilder.setPayload(packet.toByteString()); - inboundBuilder.putMetadata(Osx.Header.Version.name(), Long.toString(MetaInfo.CURRENT_VERSION)); - inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); - inboundBuilder.putMetadata(Osx.Header.Token.name(), ""); - inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), srcPartyId); - inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), desPartyId); - inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionId); - inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod); - inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), desRole); - inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); - return inboundBuilder.build(); - }; - - static public void buildHttpFromPb(Osx.Inbound inbound){ - - - - - } - - - static public Osx.Inbound.Builder buildPbFromHttpRequest(HttpServletRequest request){ - - Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); - String Version = request.getHeader(PtpHttpHeader.Version); - String TechProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode); - String TraceID = request.getHeader(PtpHttpHeader.TraceID); - String Token = request.getHeader(PtpHttpHeader.Token); - String SourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID); - String TargetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID); - String SourceInstID = request.getHeader(PtpHttpHeader.SourceInstID); - String TargetInstID = request.getHeader(PtpHttpHeader.TargetInstID); - String SessionID = request.getHeader(PtpHttpHeader.SessionID); - String MessageTopic = request.getHeader(PtpHttpHeader.MessageTopic); - String MessageCode = request.getHeader(PtpHttpHeader.MessageCode); - String SourceComponentName = request.getHeader(PtpHttpHeader.SourceComponentName); - String TargetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); - String TargetMethod = request.getHeader(PtpHttpHeader.TargetMethod); - String MessageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); - String InstanceId = request.getHeader(PtpHttpHeader.InstanceId); - String Timestamp = request.getHeader(PtpHttpHeader.Timestamp); - - inboundBuilder.putMetadata(Osx.Header.Version.name(), Version != null ? Version : ""); - inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), TechProviderCode != null ? 
TechProviderCode : ""); - inboundBuilder.putMetadata(Osx.Header.Token.name(), Token != null ? Token : ""); - inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), SourceNodeID != null ? SourceNodeID : ""); - inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), TargetNodeID != null ? TargetNodeID : ""); - inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), SourceInstID != null ? SourceInstID : ""); - inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), TargetInstID != null ? TargetInstID : ""); - inboundBuilder.putMetadata(Osx.Header.SessionID.name(), SessionID != null ? SessionID : ""); - inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod != null ? TargetMethod : ""); - inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), TargetComponentName != null ? TargetComponentName : ""); - inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), SourceComponentName != null ? SourceComponentName : ""); - inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), MessageTopic != null ? MessageTopic : ""); - inboundBuilder.putMetadata(Osx.Metadata.MessageOffSet.name(), MessageOffSet != null ? MessageOffSet : ""); - inboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), InstanceId != null ? InstanceId : ""); - inboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), Timestamp != null ? Timestamp : ""); - return inboundBuilder; - - - } - - - - static public Osx.Outbound redirect(Context context, Osx.Inbound - produceRequest, RouterInfo routerInfo, boolean forceSend) { - Osx.Outbound result = null; - // context.setActionType("redirect"); - // 目的端协议为grpc - if (routerInfo == null) { - throw new NoRouterInfoException("can not find router info"); - } - if (routerInfo.getProtocol() == null || routerInfo.getProtocol().equals(Protocol.GRPC)) { - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - try { - result = stub.invoke(produceRequest); - } catch (StatusRuntimeException e) { - throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error"); - } - // ServiceContainer.tokenApplyService.applyToken(context,routerInfo.getResource(),produceRequest.getSerializedSize()); - }else{ - if(routerInfo.getProtocol().equals(Protocol.HTTP)){ - String url = routerInfo.getUrl(); - - Map metaDataMap = produceRequest.getMetadataMap(); - - String version = metaDataMap.get(Osx.Header.Version.name()); - String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); - String traceId = metaDataMap.get(Osx.Header.TraceID.name()); - String token = metaDataMap.get(Osx.Header.Token.name()); - String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); - String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); - String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); - String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); - String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); - String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); - String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); - String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); - String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." 
+ sourceNodeId; - String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId; - String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name()); - String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name()); - String InstanceId = metaDataMap.get(Osx.Metadata.InstanceId.name()); - String timestamp = metaDataMap.get(Osx.Metadata.Timestamp.name()); - String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name()); - Map header = Maps.newHashMap(); - header.put(PtpHttpHeader.Version,version!=null?version:""); - header.put(PtpHttpHeader.TechProviderCode,techProviderCode!=null?techProviderCode:""); - header.put(PtpHttpHeader.TraceID,traceId!=null?traceId:""); - header.put(PtpHttpHeader.Token,token!=null?token:""); - header.put(PtpHttpHeader.SourceNodeID,sourceNodeId!=null?sourceNodeId:""); - header.put(PtpHttpHeader.TargetNodeID,targetNodeId!=null?targetNodeId:""); - header.put(PtpHttpHeader.SourceInstID,sourceInstId!=null?sourceInstId:""); - header.put(PtpHttpHeader.TargetInstID,targetInstId!=null?targetInstId:""); - header.put(PtpHttpHeader.SessionID,sessionId!=null?sessionId:""); - header.put(PtpHttpHeader.MessageTopic,topic!=null?topic:""); - header.put(PtpHttpHeader.MessageCode,messageCode); - header.put(PtpHttpHeader.SourceComponentName,sourceComponentName!=null?sourceComponentName:""); - header.put(PtpHttpHeader.TargetComponentName,targetComponentName!=null?targetComponentName:""); - header.put(PtpHttpHeader.TargetMethod,targetMethod!=null?targetMethod:""); - header.put(PtpHttpHeader.MessageOffSet,offsetString!=null?offsetString:""); - header.put(PtpHttpHeader.InstanceId,InstanceId!=null?InstanceId:""); - header.put(PtpHttpHeader.Timestamp,timestamp!=null?timestamp:""); - result = HttpClientPool.sendPtpPost(url,produceRequest.getPayload().toByteArray(),header); - } - } - - return result; - - } - - - public static Osx.Outbound buildResponse(String code, String msgReturn, TransferQueue.TransferQueueConsumeResult messageWraper) { - // FireworkTransfer.ConsumeResponse.Builder consumeResponseBuilder = FireworkTransfer.ConsumeResponse.newBuilder(); - Osx.Outbound.Builder builder = Osx.Outbound.newBuilder(); - - builder.setCode(code); - builder.setMessage(msgReturn); - if (messageWraper != null) { - Osx.Message message = null; - try { - message = Osx.Message.parseFrom(messageWraper.getMessage().getBody()); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); - } - builder.setPayload(message.toByteString()); - builder.putMetadata(Osx.Metadata.MessageOffSet.name(), Long.toString(messageWraper.getRequestIndex())); -// FireworkTransfer.Message msg = produceRequest.getMessage(); -// consumeResponseBuilder.setTransferId(produceRequest.getTransferId()); -// consumeResponseBuilder.setMessage(msg); -// consumeResponseBuilder.setStartOffset(messageWraper.getRequestIndex()); -// consumeResponseBuilder.setTotalOffset(messageWraper.getLogicIndexTotal()); - } - - return builder.build(); - } - - - public static void main(String[] args){ - System.err.println(isOldVersionFate(null)); - } -} diff --git a/java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java b/java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java deleted file mode 100644 index d63286586f..0000000000 --- a/java/osx/broker/src/main/java/com/osx/tech/provider/FateTechProvider.java +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. 
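Editor's note on the TransferUtil removed above: two of its small helpers carry most of the protocol conventions, the major-version check that decides whether a peer is a pre-2.0 FATE and the resource key "<srcInst><srcNode>_<dstInst><dstNode>" used for token application and flow counting. A plain-Java sketch of both; the class and parameter names are illustrative, and the default version is passed in rather than read from MetaInfo.

public final class TransferHelpersSketch {

    /** True when the peer's major version digit is below 2; a missing version falls back to the default. */
    static boolean isOldVersionFate(String version, String defaultVersion) {
        String effective = (version == null || version.isEmpty()) ? defaultVersion : version;
        try {
            return Integer.parseInt(effective.substring(0, 1)) < 2;
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("remote version config error : " + effective, e);
        }
    }

    /** Resource key used for token application and flow counting. */
    static String buildResource(String srcInstId, String srcNodeId,
                                String dstInstId, String dstNodeId) {
        return (srcInstId == null ? "" : srcInstId) + srcNodeId
                + "_" + (dstInstId == null ? "" : dstInstId) + dstNodeId;
    }

    public static void main(String[] args) {
        System.out.println(isOldVersionFate("1.11.4", "2.X.X")); // true
        System.out.println(isOldVersionFate(null, "2.X.X"));     // false
        System.out.println(buildResource(null, "9999", null, "10000")); // 9999_10000
    }
}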
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.tech.provider; - - -import com.google.common.base.Preconditions; -import com.google.common.collect.Sets; -import com.google.protobuf.ByteString; -import com.osx.broker.ServiceContainer; -import com.osx.broker.util.ContextUtil; -import com.osx.broker.interceptor.RequestHandleInterceptor; -import com.osx.broker.interceptor.RouterInterceptor; -import com.osx.broker.ptp.*; -import com.osx.broker.util.TransferUtil; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.PtpHttpHeader; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ErrorMessageUtil; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.frame.Lifecycle; -import com.osx.core.provider.TechProvider; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.service.InboundPackage; -import com.osx.core.service.OutboundPackage; -import com.osx.core.service.ServiceAdaptor; -import io.grpc.stub.StreamObserver; -import org.apache.commons.io.IOUtils; -import org.ppc.ptp.Osx; - - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.*; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - -/** - * FATE 相关实现 - */ - -public class FateTechProvider implements TechProvider, Lifecycle { - - ConcurrentMap serviceAdaptorConcurrentMap = new ConcurrentHashMap<>(); - - RequestHandleInterceptor requestHandleInterceptor; - RouterInterceptor routerInterceptor; - - private Set httpAllowedMethod= Sets.newHashSet(TargetMethod.PRODUCE_MSG.name(),TargetMethod.UNARY_CALL.name()); - - private void checkHttpAllowedMethod(String targetMethod){ - - if(!httpAllowedMethod.contains(targetMethod)){ - throw new ParameterException("target method :"+targetMethod+"is not allowed"); - } - - } - - @Override - public void processHttpInvoke(HttpServletRequest request, HttpServletResponse response) { - Context context = ContextUtil.buildContext(); - Osx.Inbound.Builder inboundBuilder ; - ServiceAdaptor serviceAdaptor=null; - try { - String Version = request.getHeader(PtpHttpHeader.Version); - String TechProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode); - String TraceID = request.getHeader(PtpHttpHeader.TraceID); - String Token = request.getHeader(PtpHttpHeader.Token); - String SourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID); - String TargetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID); - String SourceInstID = request.getHeader(PtpHttpHeader.SourceInstID); - String TargetInstID = request.getHeader(PtpHttpHeader.TargetInstID); - String SessionID = request.getHeader(PtpHttpHeader.SessionID); - String MessageTopic = request.getHeader(PtpHttpHeader.MessageTopic); - String MessageCode = request.getHeader(PtpHttpHeader.MessageCode); - String SourceComponentName = 
request.getHeader(PtpHttpHeader.SourceComponentName); - String TargetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); - String TargetMethod = request.getHeader(PtpHttpHeader.TargetMethod); - String MessageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); - String InstanceId = request.getHeader(PtpHttpHeader.InstanceId); - String Timestamp = request.getHeader(PtpHttpHeader.Timestamp); - context.setSrcPartyId(SourceNodeID); - context.setDesPartyId(TargetNodeID); - context.setSessionId(SessionID); - context.setTopic(MessageTopic); - context.setActionType(TargetMethod); - inboundBuilder = TransferUtil.buildPbFromHttpRequest(request); - String targetMethod = inboundBuilder.getMetadataMap().get(Osx.Metadata.TargetMethod.name()); - checkHttpAllowedMethod(TargetMethod); - serviceAdaptor = this.getServiceAdaptor(TargetMethod); - byte[] buffer = new byte[MetaInfo.PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE]; - int length = IOUtils.read(request.getInputStream(), buffer); - byte[] data = new byte[length]; - System.arraycopy(buffer, 0, data, 0, length); - inboundBuilder.setPayload(ByteString.copyFrom(data)); - }catch(Exception e){ - ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context,e); - this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null); - context.setReturnCode(exceptionInfo.getCode()); - context.setReturnMsg(exceptionInfo.getMessage()); - context.printFlowLog(); - return ; - } - InboundPackage inboundPackage = new InboundPackage(); - inboundPackage.setBody(inboundBuilder.build()); - OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); - Osx.Outbound outbound = outboundPackage.getData(); - response.setContentType(Dict.CONTENT_TYPE_JSON_UTF8); - this.writeHttpRespose(response,outbound.getCode(),outbound.getMessage(),outbound.getPayload().toByteArray() ); - } - - private void writeHttpRespose(HttpServletResponse response,String code, - String msg, - byte[] content){ - try { - response.setHeader(PtpHttpHeader.ReturnCode,code); - response.setHeader(PtpHttpHeader.MessageCode,msg); - OutputStream outputStream = response.getOutputStream(); - if(content!=null) { - outputStream.write(content); - } - outputStream.flush(); - } catch (IOException e) { - e.printStackTrace(); - } - } - - - @Override - public void processGrpcInvoke(Osx.Inbound request, StreamObserver responseObserver) { - Context context = ContextUtil.buildContext(); - context.putData(Dict.RESPONSE_STREAM_OBSERVER,responseObserver); - Osx.Outbound result = null; - try { - Map metaDataMap = request.getMetadataMap(); - String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); - ServiceAdaptor serviceAdaptor = this.getServiceAdaptor(targetMethod); - if (serviceAdaptor == null) { - throw new ParameterException("invalid target method " + targetMethod); - } - InboundPackage inboundPackage = new InboundPackage(); - inboundPackage.setBody(request); - OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); - if (outboundPackage.getData() != null) { - result = outboundPackage.getData(); - } - }catch (Exception e){ - ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context,e); - //this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null); - context.setReturnCode(exceptionInfo.getCode()); - context.setReturnMsg(exceptionInfo.getMessage()); - context.printFlowLog(); - result = 
Osx.Outbound.newBuilder().setCode(exceptionInfo.getCode()).setMessage(exceptionInfo.getMessage()).build(); - } - if(result!=null) { - responseObserver.onNext(result); - responseObserver.onCompleted(); - } - - } - - @Override - public String getProviderId() { - return MetaInfo.PROPERTY_FATE_TECH_PROVIDER; - } - - - @Override - public StreamObserver processGrpcTransport(Osx.Inbound fristPackage, StreamObserver responseObserver) { - Map metaDataMap = fristPackage.getMetadataMap(); - String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); - ServiceAdaptor serviceAdaptor = this.getServiceAdaptor(targetMethod); - if(serviceAdaptor==null){ - throw new ParameterException("invalid target method "+targetMethod); - } - Context context = ContextUtil.buildContext(); - InboundPackage inboundPackage = new InboundPackage(); - inboundPackage.setBody(responseObserver); - OutboundPackage> outboundPackage = serviceAdaptor.service( context, inboundPackage); - if(outboundPackage!=null&&outboundPackage.getData()!=null){ - return (StreamObserver)outboundPackage.getData(); - }else{ - return null; - } - - - } - - @Override - public void init() { - Preconditions.checkArgument(ServiceContainer.fateRouterService != null); - requestHandleInterceptor = new RequestHandleInterceptor(); - routerInterceptor =ServiceContainer.routerInterceptor; - registerServiceAdaptor(); - } - - @Override - public void start() { - - } - - @Override - public void destroy() { - - } - public ServiceAdaptor getServiceAdaptor(String name) { - return this.serviceAdaptorConcurrentMap.get(name); - } - private void registerServiceAdaptor() { - this.serviceAdaptorConcurrentMap.put(TargetMethod.UNARY_CALL.name(), new PtpUnaryCallService().addPreProcessor(requestHandleInterceptor) - .addPreProcessor(routerInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.PRODUCE_MSG.name(), new PtpProduceService().addPreProcessor(requestHandleInterceptor) - .addPreProcessor(routerInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.ACK_MSG.name(), new PtpAckService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.CONSUME_MSG.name(), new PtpConsumeService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.QUERY_TOPIC.name(), new PtpQueryTransferQueueService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.CANCEL_TOPIC.name(), new PtpCancelTransferService().addPreProcessor(requestHandleInterceptor)); - this.serviceAdaptorConcurrentMap.put(TargetMethod.PUSH.name(), new PtpPushService()); - this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOKEN.name(), new PtpClusterTokenApplyService()); - this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOPIC.name(),new PtpClusterTopicApplyService()); - } -} diff --git a/java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java b/java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java deleted file mode 100644 index bd124e277d..0000000000 --- a/java/osx/broker/src/main/java/com/osx/tech/provider/TechProviderRegister.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
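Editor's note on the FateTechProvider removed above: both the HTTP and gRPC entry points resolve a handler from a concurrent map keyed by the TargetMethod metadata value and reject unknown methods before doing any work. A reduced, self-contained sketch of that dispatch; the Handler interface and the lambda bodies are hypothetical stand-ins for the ServiceAdaptor implementations registered in registerServiceAdaptor().

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.UnaryOperator;

public final class TargetMethodDispatchSketch {

    interface Handler extends UnaryOperator<Map<String, String>> {}

    static final ConcurrentMap<String, Handler> REGISTRY = new ConcurrentHashMap<>();

    static {
        REGISTRY.put("UNARY_CALL", metadata -> Map.of("code", "0", "handler", "unary-call"));
        REGISTRY.put("PRODUCE_MSG", metadata -> Map.of("code", "0", "handler", "produce"));
        REGISTRY.put("CONSUME_MSG", metadata -> Map.of("code", "0", "handler", "consume"));
    }

    static Map<String, String> invoke(Map<String, String> requestMetadata) {
        String targetMethod = requestMetadata.get("TargetMethod");
        Handler handler = REGISTRY.get(targetMethod);
        if (handler == null) {
            throw new IllegalArgumentException("invalid target method " + targetMethod);
        }
        return handler.apply(requestMetadata);
    }

    public static void main(String[] args) {
        System.out.println(invoke(Map.of("TargetMethod", "UNARY_CALL")));
        // invoke(Map.of("TargetMethod", "PUSH")) would throw in this reduced registry
    }
}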
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.osx.tech.provider; - -import com.google.common.base.Preconditions; -import com.osx.core.frame.Lifecycle; -import com.osx.core.provider.TechProvider; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - -/** - * 厂商选择 - */ -public class TechProviderRegister implements Lifecycle { - - ConcurrentMap registerMap = new ConcurrentHashMap<>(); - - public TechProvider select(String techProviderCode ) { - Preconditions.checkArgument(techProviderCode != null); - return this.registerMap.get(techProviderCode); - } - public void init() { - FateTechProvider fateTechProvider = new FateTechProvider(); - fateTechProvider.init(); - this.registerMap.put(fateTechProvider.getProviderId(), fateTechProvider); - } - @Override - public void start() { - } - @Override - public void destroy() { - } - -} - - - diff --git a/java/osx/broker/src/main/resources/broker.properties b/java/osx/broker/src/main/resources/broker.properties deleted file mode 100644 index c288094306..0000000000 --- a/java/osx/broker/src/main/resources/broker.properties +++ /dev/null @@ -1,23 +0,0 @@ -#grpc?? -grpc.port= 9370 -#????http server -open.http.server=false -# http?? -http.port=8080 -# ????grpc+TLS?? -open.grpc.tls.server=false -#grpc+TLS???????? -grpc.tls.port=9883 -#??partyId,???????????? -self.party=10000 -#???? standalone/cluster?standalone?????? cluster?????? -deploy.model=standalone -#?????????zookeeper,???zookeeper?? -zk.url=localhost:2181 -#????eggroll???????????eggroll cluster-manager???ip??? 
-eggroll.cluster.manager.ip = localhost -eggroll.cluster.manager.port = 4670 - - - - diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java b/java/osx/broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java deleted file mode 100644 index 7f5254d3a1..0000000000 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/EggrollTest.java +++ /dev/null @@ -1,4 +0,0 @@ -package com.osx.broker.test.grpc; - -public class EggrollTest { -} diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java b/java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java deleted file mode 100644 index 9209eed62e..0000000000 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/NewFateTest.java +++ /dev/null @@ -1,83 +0,0 @@ -package com.osx.broker.test.grpc; - -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.ptp.TargetMethod; -import io.grpc.ManagedChannel; -import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; -import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; -import org.junit.Before; -import org.junit.Test; -import org.ppc.ptp.Osx; -import org.ppc.ptp.PrivateTransferProtocolGrpc; - -import java.io.File; -import java.util.concurrent.TimeUnit; - -public class NewFateTest { - - String ip = "localhost"; - //int port = 8250;//nginx - int port = 9370;//nginx - String desPartyId = "10000"; - String desRole = ""; - String srcPartyId = "9999"; - String srcRole = ""; - String transferId = "testTransferId"; - String sessionId = "testSessionId"; - PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; - - @Before - public void init() { - ManagedChannel managedChannel = createManagedChannel(ip, port); - // stub = PrivateTransferProtocolGrpc.newBlockingStub(); - // ManagedChannel managedChannel2 = createManagedChannel(ip, port); - blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); - } - - public static ManagedChannel createManagedChannel(String ip, int port) { - try { - NettyChannelBuilder channelBuilder = NettyChannelBuilder - .forAddress(ip, port) - .keepAliveTime(60, TimeUnit.SECONDS) - .keepAliveTimeout(60, TimeUnit.SECONDS) - .keepAliveWithoutCalls(true) - .idleTimeout(60, TimeUnit.SECONDS) - .perRpcBufferLimit(128 << 20) - .flowControlWindow(32 << 20) - .maxInboundMessageSize(32 << 20) - .enableRetry() - .retryBufferSize(16 << 20) - .maxRetryAttempts(20); - channelBuilder.usePlaintext(); - - return channelBuilder.build(); - } catch (Exception e) { - e.printStackTrace(); - // logger.error("create channel error : " ,e); - //e.printStackTrace(); - } - return null; - } - - @Test - public void testUnaryCall(){ - Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); - inboundBuilder.putMetadata(Osx.Header.Version.name(), "123"); - inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); - inboundBuilder.putMetadata(Osx.Header.Token.name(), "testToken"); - inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), "9999"); - inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), "10000"); - inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); - inboundBuilder.putMetadata(Osx.Header.SessionID.name(), "testSessionID"); - inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.UNARY_CALL.name()); - 
inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), "fateflow"); - inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); - // inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId); - Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build()); - System.err.println("response : "+outbound); - } - -} diff --git a/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java b/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java deleted file mode 100644 index 1ff1d88779..0000000000 --- a/java/osx/core/src/main/java/com/osx/core/config/MetaInfo.java +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.osx.core.config; - -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import com.osx.core.constant.DeployMode; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StreamLimitMode; - -import java.lang.reflect.Field; -import java.util.Map; -import java.util.Set; - -public class MetaInfo { - public static final long CURRENT_VERSION = 100; - public static String PROPERTY_FATE_TECH_PROVIDER = "FATE"; - public static String PROPERTY_DEFAULT_CLIENT_VERSION="2.X.X"; - public static volatile MasterInfo masterInfo; - public static int PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; - public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = 128 << 20; - public static int PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; - public static int PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = 128 << 20; - public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = 7200; - public static int PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = 3600; - public static int PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = 10; - public static boolean PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = 86400; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = 86400; - public static int PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = 86400; - public static int PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = 600; - - - - - public static int PROPERTY_GRPC_CLIENT_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; - public static int PROPERTY_GRPC_CLIENT_MAX_INBOUND_METADATA_SIZE = 128 << 20; - public static int PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; - public static int PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW = 128 << 20; - public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC = 7200; - public static int PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC = 3600; - public static int PROPERTY_GRPC_CLIENT_PERMIT_KEEPALIVE_TIME_SEC = 10; - public static boolean PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; - public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC = 86400; - public static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_AGE_SEC = 86400; - public 
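The PROPERTY_GRPC_SERVER_* defaults above (the MetaInfo class itself is re-homed under org.fedai.osx.core.config by this patch) are the knobs the broker's gRPC listener is built from; the server bootstrap (OsxServer) is not part of this hunk. As orientation only, a sketch of how values like these typically plug into grpc-java's NettyServerBuilder. This is an assumed wiring for illustration, not the OsxServer code.

```java
import java.util.concurrent.TimeUnit;
import com.osx.core.config.MetaInfo;  // relocated to org.fedai.osx.core.config.MetaInfo by this patch
import io.grpc.BindableService;
import io.grpc.Server;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;

public class GrpcServerDefaultsSketch {
    // Illustration only: shows where MetaInfo's gRPC server defaults would plug in.
    static Server build(BindableService pcpService) {
        return NettyServerBuilder.forPort(MetaInfo.PROPERTY_GRPC_PORT)
                .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION)
                .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE)
                .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE)
                .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW)
                .keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS)
                .keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS)
                .permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS)
                .permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED)
                .maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS)
                .maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS)
                .maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS)
                .addService(pcpService)
                .build();
    }
}
```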
static int PROPERTY_GRPC_CLIENT_MAX_CONNECTION_AGE_GRACE_SEC = 86400; - public static int PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT=86400; - - public static int PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE = 86400; - - - - public static boolean PROPERTY_USE_DIRECT_CACHE = false; - public static int PROPERTY_TRANSFER_FILE_CACHE_SIZE = 1 << 27; - public static int PROPERTY_TRANSFER_RETRY_COUNT = 1; - public static int MAP_FILE_SIZE = 1 << 25; - public static int PROPERTY_INDEX_MAP_FILE_SIZE = 1 << 21; - public static Boolean TRANSFER_FATECLOUD_AHTHENTICATION_ENABLED; - public static Boolean TRANSFER_FATECLOUD_AUTHENTICATION_USE_CONFIG; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_URI; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_APPKEY; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_APPSERCRET; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_ROLE; - public static String TRANSFER_FATECLOUD_SECRET_INFO_URL; - public static String TRANSFER_FATECLOUD_AUTHENTICATION_URL; - public static String PROPERTY_SERVER_CERTCHAIN_FILE; - public static String PROPERTY_SERVER_PRIVATEKEY_FILE; - public static String PROPERTY_SERVER_CA_FILE; - public static int ROLLSITE_PARTY_ID; -// public static Integer PROPERTY_PORT; - public static Integer PROPERTY_GRPC_PORT; - public static Integer PROPERTY_HTTP_PORT; - public static Boolean PROPERTY_OPEN_HTTP_SERVER = false; - public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false; - public static int PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE=4096; - public static String PROPERTY_HTTP_CONTEXT_PATH="/osx"; - public static String PROPERTY_HTTP_SERVLET_PATH="/*"; - public static Integer PROPERTY_GRPC_TLS_PORT; - public static String PROPERTY_ZK_URL; - public static Boolean PROPERTY_USE_DISRUPTOR = true; - public static int PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = 3; - - public static String PROPERTY_USER_HOME = ""; - - public static Integer PROPERTY_SAMPLE_COUNT = 10; - public static Integer PROPERTY_INTERVAL_MS = 1000; - //public static Boolean PROPERTY_USE_QUEUE_MODEL = false; - public static String PROPERTY_STREAM_LIMIT_MODE = StreamLimitMode.NOLIMIT.name(); - - public static Integer PROPERTY_CONSUMER_TIMEOUT = 30000; - public static Integer PROPERTY_QUEUE_MAX_FREE_TIME; - public static Integer PROPERTY_MAPPED_FILE_EXPIRE_TIME = 3600 * 1000 * 36; - public static Integer PROPERTY_MAX_CONSUME_EMPTY_TRY_COUNT = 30; - - public static Integer PROPERTY_MAX_TRANSFER_CACHE_SIZE = 1 << 30; - public static String PROPERTY_TRANSFER_FILE_PATH_PRE; - public static String PROPERTY_DEPLOY_MODE = "standalone"; - public static String PROPERTY_TRANSFER_APPLY_CACHE = "/tmp/cachetest"; - - public static Set PROPERTY_SELF_PARTY = Sets.newHashSet();// - - public static Integer PROPERTY_APPLY_EXPIRE_TIME = 3000; - public static Integer PROPERTY_COORDINATOR; - public static Integer PROPERTY_SERVER_PORT; - public static String PROPERTY_INFERENCE_SERVICE_NAME; - public static String PROPERTY_ROUTE_TYPE; - public static String PROPERTY_ROUTE_TABLE; - - public static String PROPERTY_FLOW_RULE_TABLE; - public static String PROPERTY_AUTH_FILE; - public static Boolean PROPERTY_ACL_ENABLE = false; - public static String PROPERTY_ACL_USERNAME; - public static String PROPERTY_ACL_PASSWORD; - public static String PROPERTY_ROOT_PATH; - public static Boolean PROPERTY_PRINT_INPUT_DATA; - public static Boolean PROPERTY_PRINT_OUTPUT_DATA; - - public static Boolean PROPERTY_AUTH_OPEN; - public static String PROPERTY_NEGOTIATIONTYPE; - public static String 
PROPERTY_PROXY_GRPC_INTER_CA_FILE; - public static String PROPERTY_PROXY_GRPC_INTER_CLIENT_CERTCHAIN_FILE; - public static String PROPERTY_PROXY_GRPC_INTER_CLIENT_PRIVATEKEY_FILE; - public static String PROPERTY_PROXY_GRPC_INTER_SERVER_CERTCHAIN_FILE; - public static String PROPERTY_PROXY_GRPC_INTER_SERVER_PRIVATEKEY_FILE; - public static Integer PROPERTY_ADMIN_HEALTH_CHECK_TIME; - public static Integer PRPPERTY_QUEUE_MAX_FREE_TIME; - public static String ROLLSITE_ROUTE_TABLE_KEY; - public static String ROLLSITE_ROUTE_TABLE_WHITE_LIST; - public static String ROLLSITE_ROUTE_TABLE_PARTY_ID; - public static String INSTANCE_ID; - - public static String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP; - public static Integer PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT; - - - public static Integer PROPERTY_CONSUME_SPIN_TIME = 500; - - public static String PROPERTY_CLUSTER_MANAGER_ADDRESS; - public static Integer PROPERTY_NETTY_CLIENT_TIMEOUT = 3000; - - public static Integer PROPERTY_HEARTBEAT_INTERVAL = 10000; - - public static String PROPERTY_CLUSTER_MANAGER_HOST; - public static Integer PROPERTY_CLUSTER_MANAGER_PORT; - - public static Boolean PROPERTY_USE_ZOOKEEPER = true; - - /** - * 从连接池中申请连接的超时时间 - */ - public static Integer HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT; - /** - * 建立连接的超时时间 - */ - public static Integer HTTP_CLIENT_CONFIG_CONN_TIME_OUT; - /** - * 等待数据 - */ - public static Integer HTTP_CLIENT_CONFIG_SOCK_TIME_OUT; - public static Integer HTTP_CLIENT_INIT_POOL_MAX_TOTAL; - public static Integer HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE; - public static Integer HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT; - public static Integer HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT; - public static Integer HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT; - public static Integer HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT; - public static Integer HTTP_CLIENT_TRAN_CONN_TIME_OUT; - public static Integer HTTP_CLIENT_TRAN_SOCK_TIME_OUT; - - - - - - - public static String getClusterManagerHost() { - if (PROPERTY_CLUSTER_MANAGER_HOST != null) { - return PROPERTY_CLUSTER_MANAGER_HOST; - } else { - PROPERTY_CLUSTER_MANAGER_HOST = PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[0]; - PROPERTY_CLUSTER_MANAGER_PORT = Integer.parseInt(PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[1]); - return PROPERTY_CLUSTER_MANAGER_HOST; - } - } - - public static Integer getClusterManagerPort() { - if (PROPERTY_CLUSTER_MANAGER_PORT != null) { - return PROPERTY_CLUSTER_MANAGER_PORT; - } else { - PROPERTY_CLUSTER_MANAGER_HOST = PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[0]; - PROPERTY_CLUSTER_MANAGER_PORT = Integer.parseInt(PROPERTY_CLUSTER_MANAGER_ADDRESS.split(":")[1]); - return PROPERTY_CLUSTER_MANAGER_PORT; - } - } - - - public static boolean isCluster() { - return PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name()); - } - - public static Map toMap() { - Map result = Maps.newHashMap(); - Field[] fields = MetaInfo.class.getFields(); - - for (Field field : fields) { - try { - if (field.get(MetaInfo.class) != null) { - String key = Dict.class.getField(field.getName()) != null ? 
String.valueOf(Dict.class.getField(field.getName()).get(Dict.class)) : field.getName(); - result.put(key, field.get(MetaInfo.class)); - } - } catch (IllegalAccessException | NoSuchFieldException e) { - - } - } - return result; - } - -} diff --git a/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java b/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java deleted file mode 100644 index 0ca0a0b1e3..0000000000 --- a/java/osx/core/src/main/java/com/osx/core/router/RouterInfo.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2019 The FATE Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.osx.core.router; - -import com.osx.core.constant.Protocol; -import com.osx.core.utils.JsonUtil; -import lombok.Data; - -@Data -public class RouterInfo { - private Protocol protocol; - private String sourcePartyId; - private String desPartyId; - private String desMode; - private String url; - private String host; - private Integer port; - private boolean useSSL = false; - private String negotiationType; - private String certChainFile; - private String privateKeyFile; - private String trustCertCollectionFile; - private String caFile; - private String version; - - public String toKey() { - StringBuffer sb = new StringBuffer(); - sb.append(host).append("_").append(port); - if(negotiationType!=null) - sb.append("_").append(negotiationType); - return sb.toString(); - } - - @Override - public String toString() { - return JsonUtil.object2Json(this); - } - - public String getResource() { - StringBuilder sb = new StringBuilder(); - sb.append(sourcePartyId).append("-").append(desPartyId); - return sb.toString(); - } - - -} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java b/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java deleted file mode 100644 index 241b85544a..0000000000 --- a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogUtil.java +++ /dev/null @@ -1,54 +0,0 @@ -package com.osx.core.utils; - -import com.osx.core.context.Context; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class FlowLogUtil { - static Logger logger = LoggerFactory.getLogger("flow"); - static final String SPLIT= "|"; - public static void printFlowLog(Context context) { - StringBuffer stringBuffer = new StringBuffer(); - if(context.getActionType()!=null){ - stringBuffer.append(context.getActionType()).append(SPLIT); - } - if(context.getSessionId()!=null){ - stringBuffer.append("session:").append(context.getSessionId()).append(SPLIT); - } - if(context.getTopic()!=null){ - stringBuffer.append("topic:").append(context.getTopic()).append(SPLIT); - } - if(context.getRequestMsgIndex()!=null){ - stringBuffer.append("req-offset:").append(context.getRequestMsgIndex()).append(SPLIT); - } - if(context.getCurrentMsgIndex()!=null){ - stringBuffer.append("offset-in-queue:").append(context.getCurrentMsgIndex()).append(SPLIT); - } - if(context.getSrcPartyId()!=null){ - 
stringBuffer.append("src:").append(context.getSrcPartyId()).append(SPLIT); - } - if(context.getDesPartyId()!=null){ - stringBuffer.append("des:").append(context.getDesPartyId()).append(SPLIT); - } - if(context.getReturnCode()!=null){ - stringBuffer.append("code:").append(context.getReturnCode()).append(SPLIT); - } - stringBuffer.append("cost:").append(System.currentTimeMillis() - context.getTimeStamp()).append(SPLIT); - if(context.getRouterInfo()!=null){ - stringBuffer.append("router_info:").append(context.getRouterInfo().getHost() + ":" + context.getRouterInfo().getPort()).append(SPLIT); - } - if(context.getDataSize()!=null){ - stringBuffer.append("size:").append(context.getDataSize()).append(SPLIT); - } - if(context.getReturnMsg()!=null){ - stringBuffer.append("msg:").append(context.getReturnMsg()); - } - logger.info(stringBuffer.toString()); - - } - - - - - -} diff --git a/java/osx/deploy/auto-package.sh b/java/osx/deploy/auto-package.sh index 02dc1fc4ad..17bc921574 100755 --- a/java/osx/deploy/auto-package.sh +++ b/java/osx/deploy/auto-package.sh @@ -9,21 +9,21 @@ fi mkdir osx/bin mkdir osx/lib mkdir osx/conf +mkdir osx/extension mkdir osx/conf/broker -#mkdir osx/conf/cluster-manager +mkdir osx/conf/components cd .. mvn clean package -DskipTests - if [[ ! -d "lib" ]]; then mkdir lib fi - -cp -r broker/target/*.jar deploy/osx/lib -cp -r broker/target/lib/* deploy/osx/lib -cp broker/src/main/resources/* deploy/osx/conf/broker +cp -r osx-broker/target/*.jar deploy/osx/lib +cp -r osx-broker/target/lib/* deploy/osx/lib +cp osx-broker/src/main/resources/broker/* deploy/osx/conf/broker +cp -r osx-broker/src/main/resources/components/* deploy/osx/conf/components cp bin/service.sh deploy/osx/ cp bin/common.sh deploy/osx/bin cd deploy diff --git a/java/osx/osx-api/pom.xml b/java/osx/osx-api/pom.xml new file mode 100644 index 0000000000..4dfb72cd89 --- /dev/null +++ b/java/osx/osx-api/pom.xml @@ -0,0 +1,59 @@ + + + + osx + osx + ${osx.version} + + 4.0.0 + + osx-api + + + 8 + 8 + + + + io.grpc + grpc-netty-shaded + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + org.eclipse.jetty + jetty-server + + + + org.eclipse.jetty + jetty-servlet + + + + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + + + + + + + + + + \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/constant/Protocol.java b/java/osx/osx-api/src/main/java/org/fedai/osx/api/constants/Protocol.java similarity index 92% rename from java/osx/core/src/main/java/com/osx/core/constant/Protocol.java rename to java/osx/osx-api/src/main/java/org/fedai/osx/api/constants/Protocol.java index 189b447630..a203d728d1 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/Protocol.java +++ b/java/osx/osx-api/src/main/java/org/fedai/osx/api/constants/Protocol.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.constant; +package org.fedai.osx.api.constants; public enum Protocol { - GRPC, HTTP + grpc, + http } diff --git a/java/osx/osx-api/src/main/java/org/fedai/osx/api/context/Context.java b/java/osx/osx-api/src/main/java/org/fedai/osx/api/context/Context.java new file mode 100644 index 0000000000..055ed4cfae --- /dev/null +++ b/java/osx/osx-api/src/main/java/org/fedai/osx/api/context/Context.java @@ -0,0 +1,45 @@ +package org.fedai.osx.api.context; + +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.router.RouterInfo; + +public interface Context { + public String getTechProviderCode() ; + public void setTechProviderCode(String techProviderCode) ; + public String getTraceId() ; + public void setTraceId(String traceId); + public void setJobId(String jobId); + public String getToken() ; + public void setToken(String token) ; + public String getTopic(); + public void setTopic(String topic); + public Protocol getProtocol(); + public void setProtocol(Protocol protocol); + public String getSessionId() ; + public void setSessionId(String sessionId); + public Object getData(Object key); + public void putData(Object key, Object data); + public String getSrcPartyId() ; + public void setSrcPartyId(String guestAppId) ; + public String getDesPartyId() ; + public void setDesPartyId(String hostAppid) ; + public void setSrcComponent(String srcComponent); + public String getSrcComponent(); + public void setDesComponent(String desComponent); + public String getDesComponent(); + public String getReturnCode() ; + public void setReturnCode(String returnCode); + public String getReturnMsg() ; + public void setReturnMsg(String returnMsg); + public String getServiceName(); + public void setServiceName(String serviceName) ; + public String getSelfPartyId(); + public void setSelfPartyId(String selfPartyId); + public void setActionType(String actionType); + public String getActionType(); + public void setRouterInfo(RouterInfo routerInfo); + public RouterInfo getRouterInfo(); + public Context subContext(); + + +} diff --git a/java/osx/osx-api/src/main/java/org/fedai/osx/api/router/RouterInfo.java b/java/osx/osx-api/src/main/java/org/fedai/osx/api/router/RouterInfo.java new file mode 100644 index 0000000000..882e3ef5aa --- /dev/null +++ b/java/osx/osx-api/src/main/java/org/fedai/osx/api/router/RouterInfo.java @@ -0,0 +1,192 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
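The Context interface above standardizes the request-scoped state that was previously passed around ad hoc: party ids, session, topic, protocol, routing result, and return codes. A sketch of how a transport entry point is expected to populate it before routing; ctx would be the concrete implementation (FateContext in the broker module, seen later in this patch), and only methods declared by the interface are used.

```java
import org.fedai.osx.api.constants.Protocol;
import org.fedai.osx.api.context.Context;

public class ContextUsageSketch {
    // Fill in the fields a router or TechProvider implementation typically reads.
    static void describeInbound(Context ctx) {
        ctx.setTechProviderCode("FATE");   // which vendor implementation handles the request
        ctx.setProtocol(Protocol.grpc);
        ctx.setSessionId("session-1");
        ctx.setTopic("queue-1");
        ctx.setSrcPartyId("9999");
        ctx.setDesPartyId("10000");
        ctx.setActionType("unary-produce");
    }
}
```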
+ */ + +package org.fedai.osx.api.router; +import org.fedai.osx.api.constants.Protocol; + + +public class RouterInfo { + private Protocol protocol; + private String sourcePartyId; + private String desPartyId; + private String desRole; + private String sourceRole; + private String url; + private String host; + private Integer port; + private boolean useSSL = false; + private String negotiationType; + private String certChainFile; + private String privateKeyFile; + private String trustCertCollectionFile; + private String caFile; + private String version; + + public Protocol getProtocol() { + return protocol; + } + + public void setProtocol(Protocol protocol) { + this.protocol = protocol; + } + + public String getSourcePartyId() { + return sourcePartyId; + } + + public void setSourcePartyId(String sourcePartyId) { + this.sourcePartyId = sourcePartyId; + } + + public String getDesPartyId() { + return desPartyId; + } + + public void setDesPartyId(String desPartyId) { + this.desPartyId = desPartyId; + } + + public String getDesRole() { + return desRole; + } + + public void setDesRole(String desRole) { + this.desRole = desRole; + } + + public String getSourceRole() { + return sourceRole; + } + + public void setSourceRole(String sourceRole) { + this.sourceRole = sourceRole; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(Integer port) { + this.port = port; + } + + public boolean isUseSSL() { + return useSSL; + } + + public void setUseSSL(boolean useSSL) { + this.useSSL = useSSL; + } + + public String getNegotiationType() { + return negotiationType; + } + + public void setNegotiationType(String negotiationType) { + this.negotiationType = negotiationType; + } + + public String getCertChainFile() { + return certChainFile; + } + + public void setCertChainFile(String certChainFile) { + this.certChainFile = certChainFile; + } + + public String getPrivateKeyFile() { + return privateKeyFile; + } + + public void setPrivateKeyFile(String privateKeyFile) { + this.privateKeyFile = privateKeyFile; + } + + public String getTrustCertCollectionFile() { + return trustCertCollectionFile; + } + + public void setTrustCertCollectionFile(String trustCertCollectionFile) { + this.trustCertCollectionFile = trustCertCollectionFile; + } + + public String getCaFile() { + return caFile; + } + + public void setCaFile(String caFile) { + this.caFile = caFile; + } + + public String getVersion() { + return version; + } + + public void setVersion(String version) { + this.version = version; + } + + public boolean isCycle() { + return isCycle; + } + + public void setCycle(boolean cycle) { + isCycle = cycle; + } + + private boolean isCycle; + + public String toKey() { + StringBuffer sb = new StringBuffer(); + if(Protocol.grpc.equals(protocol)) { + sb.append(host).append("_").append(port); + if (negotiationType != null) + sb.append("_").append(negotiationType); + }else { + sb.append(url); + } + return sb.toString(); + } + + @Override + public String toString() { + return toKey(); + } + + public String getResource() { + StringBuilder sb = new StringBuilder(); + sb.append(sourcePartyId).append("-").append(desPartyId); + return sb.toString(); + } + + +} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java 
b/java/osx/osx-api/src/main/java/org/fedai/osx/api/tech/provider/TechProvider.java similarity index 84% rename from java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java rename to java/osx/osx-api/src/main/java/org/fedai/osx/api/tech/provider/TechProvider.java index 20ad62b8c4..fe16a19acb 100644 --- a/java/osx/core/src/main/java/com/osx/core/provider/TechProvider.java +++ b/java/osx/osx-api/src/main/java/org/fedai/osx/api/tech/provider/TechProvider.java @@ -13,24 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.provider; - +package org.fedai.osx.api.tech.provider; import io.grpc.stub.StreamObserver; import org.ppc.ptp.Osx; - - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; public interface TechProvider { - + //用于处理http1.X请求 void processHttpInvoke(HttpServletRequest httpServletRequest,HttpServletResponse httpServletResponse); - + //用于处理grpc非流式请求 void processGrpcInvoke(Osx.Inbound request, io.grpc.stub.StreamObserver responseObserver); - - String getProviderId(); - - public StreamObserver processGrpcTransport(Osx.Inbound inbound, io.grpc.stub.StreamObserver responseObserver); + //用于处理grpc流式请求 + public StreamObserver processGrpcTransport(Osx.Inbound inbound, StreamObserver responseObserver); } diff --git a/java/osx/osx-api/src/main/java/org/fedai/osx/api/translator/Translator.java b/java/osx/osx-api/src/main/java/org/fedai/osx/api/translator/Translator.java new file mode 100644 index 0000000000..aa40e3fe16 --- /dev/null +++ b/java/osx/osx-api/src/main/java/org/fedai/osx/api/translator/Translator.java @@ -0,0 +1,16 @@ +package org.fedai.osx.api.translator; + + +import org.fedai.osx.api.context.Context; +import org.ppc.ptp.Osx; +//用于转换不同厂商通信时的接收和发总数据, +public interface Translator { + //服务方转化接收的数据 + Osx.Inbound translateReceiveInbound(Context context, Osx.Inbound inbound); + //请求方转化接受到的返回数据 + Osx.Outbound translateReceiveOutbound(Context context,Osx.Outbound outbound); + //请求方转化发送的数据 + Osx.Inbound translateSendInbound(Context context,Osx.Inbound inbound); + //服务方转化准备返回的数据 + Osx.Outbound translateSendOutbound(Context context,Osx.Outbound outbound); +} diff --git a/java/osx/broker/package.xml b/java/osx/osx-broker/package.xml similarity index 73% rename from java/osx/broker/package.xml rename to java/osx/osx-broker/package.xml index 35387a9715..f05abac5cb 100644 --- a/java/osx/broker/package.xml +++ b/java/osx/osx-broker/package.xml @@ -27,7 +27,7 @@ - /lib + /osx target *.jar @@ -37,7 +37,7 @@ - /lib + /osx/lib target/lib *.jar @@ -47,7 +47,7 @@ - / + /osx/ bin service.sh @@ -56,7 +56,7 @@ unix - /bin + /osx/bin bin transfer.sh @@ -65,7 +65,7 @@ unix - /bin + /osx/bin ../bin *.sh @@ -75,22 +75,28 @@ - /conf - src/main/resources + /osx/conf + ../build - transfer.properties - route_table.json + * - /conf - src/main/resources + /osx/conf/broker + ../build/broker - log4j2.xml + *.* + + /osx/conf/components + ../build/components + + *.* + + \ No newline at end of file diff --git a/java/osx/broker/pom.xml b/java/osx/osx-broker/pom.xml similarity index 84% rename from java/osx/broker/pom.xml rename to java/osx/osx-broker/pom.xml index 19e507635d..9092bbf33e 100644 --- a/java/osx/broker/pom.xml +++ b/java/osx/osx-broker/pom.xml @@ -9,27 +9,27 @@ 4.0.0 - broker + osx-broker osx - core + osx-core ${osx.version} - org.eclipse.jetty - jetty-server + org.eclipse.jetty + jetty-server com.google.guava guava - - - - + + com.lmax + disruptor + org.apache.commons 
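The Translator interface above is the hook for converting payloads when the two sides speak different vendor dialects of the PTP Osx messages; each direction (send/receive, inbound/outbound) gets its own method. For a provider that already uses the PTP format natively, a pass-through implementation suffices. A minimal sketch, with an illustrative class name:

```java
import org.fedai.osx.api.context.Context;
import org.fedai.osx.api.translator.Translator;
import org.ppc.ptp.Osx;

// Pass-through Translator: no conversion is needed in either direction.
public class PassthroughTranslator implements Translator {
    @Override
    public Osx.Inbound translateReceiveInbound(Context context, Osx.Inbound inbound) {
        return inbound;
    }
    @Override
    public Osx.Outbound translateReceiveOutbound(Context context, Osx.Outbound outbound) {
        return outbound;
    }
    @Override
    public Osx.Inbound translateSendInbound(Context context, Osx.Inbound inbound) {
        return inbound;
    }
    @Override
    public Osx.Outbound translateSendOutbound(Context context, Osx.Outbound outbound) {
        return outbound;
    }
}
```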
commons-lang3 @@ -58,7 +58,10 @@ io.grpc grpc-stub - + + commons-net + commons-net + org.apache.curator curator-recipes @@ -106,18 +109,6 @@ - - org.junit.platform - junit-platform-launcher - 1.0.1 - test - - - org.junit.jupiter - junit-jupiter-engine - 5.0.1 - test - org.junit.vintage junit-vintage-engine @@ -134,13 +125,13 @@ - net.java.dev.jna - jna + net.java.dev.jna + jna - commons-validator - commons-validator + commons-validator + commons-validator @@ -171,5 +162,4 @@ - \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java new file mode 100644 index 0000000000..0849e2e641 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java @@ -0,0 +1,93 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.PosixParser; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.jvm.JvmInfoCounter; +import org.fedai.osx.core.utils.PropertiesUtil; +import org.fedai.osx.core.utils.ServerUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; +public class Bootstrap { + static Logger logger = LoggerFactory.getLogger(Bootstrap.class); + static CommandLine commandLine; + static Object lockObject= new Object(); + public static void main(String[] args) { + try { + Options options = ServerUtil.buildCommandlineOptions(new Options()); + commandLine = ServerUtil.parseCmdLine("osx", args, buildCommandlineOptions(options), + new PosixParser()); + String configDir = commandLine.getOptionValue('c'); + logger.info("try to parse config dir {}", configDir); + if (StringUtils.isEmpty(configDir)) { + System.err.println("config file is not set ,please use -c to set the config file dir path"); + System.exit(-1); + } + parseConfig(configDir); + Bootstrap bootstrap = new Bootstrap(); + bootstrap.start(args); + Thread shutDownThread = new Thread(bootstrap::stop); + Runtime.getRuntime().addShutdownHook(shutDownThread); + synchronized (lockObject){ + lockObject.wait(); + } + + } catch (Exception ex) { + logger.error("broker start failed ",ex); + ex.printStackTrace(); + System.exit(1); + } + } + + private static Options buildCommandlineOptions(final Options options) { + Option opt = new Option("c", "configFile", true, "config properties file"); + opt.setRequired(false); + options.addOption(opt); + return options; + } + + public static void parseConfig(String configDir) { + try { + MetaInfo.PROPERTY_CONFIG_DIR = configDir; + String configFilePath = configDir+ "/broker/broker.properties"; + Properties environment = PropertiesUtil.getProperties(configFilePath); + MetaInfo.init(environment); + } catch (Exception e) { 
+ logger.error("init MetaInfo error", e); + System.exit(1); + } + } + + public void start(String[] args) { + ServiceContainer.init(); + JvmInfoCounter.start(); + } + + public void stop() { + logger.info("try to shutdown server ..."); + if (ServiceContainer.transferQueueManager != null) { + ServiceContainer.transferQueueManager.destroyAll(); + } + } + +} \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ServiceContainer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ServiceContainer.java new file mode 100644 index 0000000000..8aeaa6546b --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ServiceContainer.java @@ -0,0 +1,189 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker; + + +import org.fedai.osx.broker.consumer.ConsumerManager; +import org.fedai.osx.broker.eggroll.EventDriverMsgManager; +import org.fedai.osx.broker.grpc.PcpGrpcService; +import org.fedai.osx.broker.http.HttpClientPool; +import org.fedai.osx.broker.queue.TransferQueueManager; +import org.fedai.osx.broker.router.DefaultFateRouterServiceImpl; +import org.fedai.osx.broker.router.FateRouterService; +import org.fedai.osx.broker.router.RouterRegister; +import org.fedai.osx.broker.security.TokenGeneratorRegister; +import org.fedai.osx.broker.security.TokenValidatorRegister; +import org.fedai.osx.broker.server.OsxServer; +import org.fedai.osx.broker.service.TokenApplyService; +import org.fedai.osx.broker.token.DefaultTokenService; +import org.fedai.osx.broker.zk.CuratorZookeeperClient; +import org.fedai.osx.broker.zk.ZkConfig; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.flow.ClusterFlowRuleManager; +import org.fedai.osx.core.flow.FlowCounterManager; +import org.fedai.osx.core.service.AbstractServiceAdaptor; +import org.fedai.osx.tech.provider.TechProviderRegister; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +public class ServiceContainer { + static public ConsumerManager consumerManager; + static public TransferQueueManager transferQueueManager; + static public FlowCounterManager flowCounterManager; + static public OsxServer transferServer; + static public Map serviceAdaptorMap = new HashMap(); + static public TokenApplyService tokenApplyService; + static public ClusterFlowRuleManager clusterFlowRuleManager; + static public DefaultTokenService defaultTokenService; + static public CuratorZookeeperClient zkClient; + //厂商注册 + static public TechProviderRegister techProviderRegister; + static public EventDriverMsgManager eventDriverMsgManager; + //Token校验器,用于双方token校验 + static public TokenValidatorRegister tokenValidatorRegister; + //Token生成器注册,用于双方token校验 + static public TokenGeneratorRegister tokenGeneratorRegister; + + static public RouterRegister routerRegister; + + + static Logger logger = LoggerFactory.getLogger(ServiceContainer.class); + + public 
static void init() { + flowCounterManager = createFlowCounterManager(); + clusterFlowRuleManager = createClusterFlowRuleManager(); + zkClient = createCuratorZookeeperClient(); + transferQueueManager = createTransferQueueManager(); + consumerManager = createTransferQueueConsumerManager(); + tokenApplyService = createTokenApplyService(); + transferServer = new OsxServer(); + defaultTokenService = createDefaultTokenService(); + tokenApplyService = createTokenApplyService(); + eventDriverMsgManager = createEventDriverMsgManager( consumerManager, transferQueueManager); + techProviderRegister = createTechProviderRegister(); + tokenValidatorRegister = createTokenValidatorRegister(); + tokenGeneratorRegister = createTokenGeneratorRegister(); + routerRegister = createRouterRegister(); + HttpClientPool.initPool(); + if (!transferServer.start()) { + logger.error("server start failed"); + System.err.println("server start failed"); + System.exit(-1); + } else { + + }; + + + } + + + private static RouterRegister createRouterRegister(){ + RouterRegister routerRegister = new RouterRegister(); + routerRegister.init(); + routerRegister.start(); + return routerRegister; + } + + private static TokenValidatorRegister createTokenValidatorRegister(){ + TokenValidatorRegister tokenValidatorRegister = new TokenValidatorRegister(); + tokenValidatorRegister.init(); + tokenValidatorRegister.start(); + return tokenValidatorRegister; + } + + private static TokenGeneratorRegister createTokenGeneratorRegister(){ + TokenGeneratorRegister tokenGeneratorRegister = new TokenGeneratorRegister(); + tokenGeneratorRegister.init(); + tokenGeneratorRegister.start(); + return tokenGeneratorRegister; + } + + + private static EventDriverMsgManager createEventDriverMsgManager(ConsumerManager consumerManager,TransferQueueManager transferQueueManager){ + EventDriverMsgManager eventDriverMsgManager = new EventDriverMsgManager(consumerManager,transferQueueManager); + eventDriverMsgManager.init(); + eventDriverMsgManager.start(); + return eventDriverMsgManager; + } + + + public static TechProviderRegister createTechProviderRegister() { + try { + TechProviderRegister techProviderRegister = new TechProviderRegister(); + techProviderRegister.start(); + return techProviderRegister; + }catch(Exception e){ + logger.error("tech provider create error",e); + } + return null; + + } + + public static PcpGrpcService createPcpGrpcService() { + return new PcpGrpcService(); + } + + public static CuratorZookeeperClient createCuratorZookeeperClient() { + if (MetaInfo.isCluster()) { + ZkConfig zkConfig = new ZkConfig(MetaInfo.PROPERTY_ZK_URL, 5000); + return new CuratorZookeeperClient(zkConfig); + } + return null; + } + + public static TokenApplyService createTokenApplyService() { + TokenApplyService tokenApplyService = new TokenApplyService(); + tokenApplyService.start(); + return tokenApplyService; + } + + public static DefaultTokenService createDefaultTokenService() { + return new DefaultTokenService(); + } + + public static ClusterFlowRuleManager createClusterFlowRuleManager() { + return new ClusterFlowRuleManager(); + } + + static FlowCounterManager createFlowCounterManager() { + FlowCounterManager flowCounterManager = new FlowCounterManager("transfer"); + flowCounterManager.startReport(); + return flowCounterManager; + } + + static ConsumerManager createTransferQueueConsumerManager() { + ConsumerManager consumerManager = new ConsumerManager(); + return consumerManager; + } + + static FateRouterService createFateRouterService() { + 
DefaultFateRouterServiceImpl fateRouterService = new DefaultFateRouterServiceImpl(); + fateRouterService.start(); + return fateRouterService; + } + + static TransferQueueManager createTransferQueueManager() { + TransferQueueManager transferQueueManager = new TransferQueueManager(); + return transferQueueManager; + } + + + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/BufferStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/BufferStatus.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/BufferStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/BufferStatus.java index 2920501f57..baba6ece95 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/BufferStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/BufferStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; public enum BufferStatus { FREE, diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadResult.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/ReadResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadResult.java index 71571cd93b..b867f80d3b 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadResult.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; public class ReadResult { ReadStatus status; diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadStatus.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/ReadStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadStatus.java index 50962086f2..80442b6306 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/ReadStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/ReadStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; public enum ReadStatus { OK, ERROR, DISCARD; diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/TransferBufferUtil.java similarity index 96% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/TransferBufferUtil.java index 3c0e7ca4fc..62cccd8a2c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/TransferBufferUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/TransferBufferUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; import java.nio.ByteBuffer; diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteResult.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/WriteResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteResult.java index 5534a4dc36..d3c6748c7e 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteResult.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; public class WriteResult { WriteStatus status; diff --git a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteStatus.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/buffer/WriteStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteStatus.java index f1d66faf71..1512006794 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/buffer/WriteStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/buffer/WriteStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.buffer; +package org.fedai.osx.broker.buffer; public enum WriteStatus { OK, FULL, ERROR diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/CompleteCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CompleteCallback.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/callback/CompleteCallback.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CompleteCallback.java index d72e5ee4de..2d56caddfb 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/callback/CompleteCallback.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CompleteCallback.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.callback; +package org.fedai.osx.broker.callback; @FunctionalInterface public interface CompleteCallback { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CreateUserCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CreateUserCallback.java new file mode 100644 index 0000000000..efab9d9f8b --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/CreateUserCallback.java @@ -0,0 +1,35 @@ +package org.fedai.osx.broker.callback; + +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.consumer.GrpcEventHandler; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.queue.TransferQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CreateUserCallback implements MsgEventCallback{ + + Logger logger = LoggerFactory.getLogger(CreateUserCallback.class); + public CreateUserCallback(Class eventHandlerClass){ + this.grpcEventHandlerClass = eventHandlerClass; + + } + Class grpcEventHandlerClass ; + + @Override + public void callback(TransferQueue queue , MessageExt message) { + String topic = queue.getTransferId(); + if(ServiceContainer.consumerManager.getEventDrivenConsumer(topic)==null){ + GrpcEventHandler grpcEventHandler = null; + try { + grpcEventHandler = (GrpcEventHandler)grpcEventHandlerClass.newInstance(); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + ServiceContainer.consumerManager.createEventDrivenConsumer(topic,grpcEventHandler); + }; + } + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/DestoryCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/DestoryCallback.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/callback/DestoryCallback.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/DestoryCallback.java index 8b2fde4874..66b73b6930 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/callback/DestoryCallback.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/DestoryCallback.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.callback; +package org.fedai.osx.broker.callback; @FunctionalInterface public interface DestoryCallback { diff --git a/java/osx/broker/src/main/java/com/osx/broker/callback/ErrorCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/ErrorCallback.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/callback/ErrorCallback.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/ErrorCallback.java index 504667eee4..de50ac727e 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/callback/ErrorCallback.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/ErrorCallback.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.callback; +package org.fedai.osx.broker.callback; public interface ErrorCallback { public void callback(Throwable e); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MockDesGrpcEventHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MockDesGrpcEventHandler.java new file mode 100644 index 0000000000..a5c2ca6784 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MockDesGrpcEventHandler.java @@ -0,0 +1,78 @@ +//package org.fedai.osx.broker.callback; +// +//import com.google.protobuf.ByteString; +//import com.google.protobuf.InvalidProtocolBufferException; +//import org.fedai.osx.broker.ServiceContainer; +//import org.fedai.osx.broker.constants.MessageFlag; +//import org.fedai.osx.broker.consumer.GrpcEventHandler; +//import org.fedai.osx.broker.consumer.MessageEvent; +//import org.fedai.osx.broker.message.MessageExt; +//import org.fedai.osx.broker.util.TransferUtil; +//import org.fedai.osx.core.constant.Dict; +//import org.fedai.osx.core.constant.TransferStatus; +//import org.fedai.osx.core.frame.GrpcConnectionFactory; +//import org.fedai.osx.core.ptp.TargetMethod; +//import org.fedai.osx.core.router.RouterInfo; +//import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; +//import io.grpc.ManagedChannel; +//import org.ppc.ptp.Osx; +//import org.ppc.ptp.PrivateTransferProtocolGrpc; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +// +//import java.nio.charset.StandardCharsets; +// +//public class MockDesGrpcEventHandler extends GrpcEventHandler { +// +// +// +// Logger logger = LoggerFactory.getLogger(MockDesGrpcEventHandler.class); +// +// PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; +// @Override +// protected void handleMessage(MessageExt message) { +// +// String topic = message.getTopic(); +// String srcPartyId = message.getSrcPartyId(); +// String desPartyId = message .getDesPartyId(); +// try { +// Osx.Inbound inbound = Osx.Inbound.parseFrom(message.getBody()); +// logger.info("receive message topic {} srcPartyId {} desPartyId {} msg {}",topic,srcPartyId,desPartyId,new String(inbound.getPayload().toByteArray())); +// } catch (InvalidProtocolBufferException e) { +// e.printStackTrace(); +// } +// +// } +// +// @Override +// protected void handleError(MessageExt message) { +// logger.info("handle error : {}",new String(message.getBody())); +// } +// +// @Override +// protected void handleComplete(MessageExt message) { +// logger.info("receive complete"); +// +// } +// +// @Override +// protected void handleInit(MessageEvent event) { +// +// logger.info("init================= {} {} {} {} {}",topic, backTopic,srcPartyId,desPartyId,sessionId); +// new Thread(new Runnable() { +// @Override +// public void run() { +// for(int i=0;i<10;i++){ +// +// Osx.Outbound outBound = Osx.Outbound.newBuilder().setPayload(ByteString.copyFrom("my name is god".getBytes(StandardCharsets.UTF_8))).build(); +// sendBackMsg(outBound.toByteArray()); +// if(i==9){ +// sendBackCompleted(); +// } +// } +// } +// }).start(); +// } +// +// +//} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventCallback.java new file mode 100644 index 0000000000..7ef5595bdc --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventCallback.java @@ -0,0 +1,9 @@ +package org.fedai.osx.broker.callback; + 
+import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.queue.TransferQueue; + +@FunctionalInterface +public interface MsgEventCallback { + void callback(TransferQueue transferQueue , MessageExt message) throws Exception; +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventDispatchCallback.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventDispatchCallback.java new file mode 100644 index 0000000000..56e189710d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/callback/MsgEventDispatchCallback.java @@ -0,0 +1,35 @@ +package org.fedai.osx.broker.callback; + +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.consumer.EventDrivenConsumer; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.queue.TransferQueue; + +public class MsgEventDispatchCallback implements MsgEventCallback{ + + + + @Override + public void callback(TransferQueue transferQueue, MessageExt message) throws Exception { + + String topic = transferQueue.getTransferId(); + EventDrivenConsumer eventDrivenConsumer = ServiceContainer.consumerManager.getEventDrivenConsumer(topic); + if(eventDrivenConsumer!=null){ + if(!transferQueue.isHasEventMsgDestoryCallback()) { + transferQueue.registerDestoryCallback(() -> { + ServiceContainer.consumerManager.onComplete(topic); + }); + transferQueue.setHasEventMsgDestoryCallback(true); + } +// MessageEvent messageEvent = new MessageEvent(); +// messageEvent.setTopic(topic); +// +// messageEvent.setDesComponent(message.getProperty(Dict.DES_COMPONENT)); +// messageEvent.setSrcComponent(message.getProperty(Dict.SOURCE_COMPONENT)); +// messageEvent.setSrcPartyId(message.getSrcPartyId()); +// messageEvent.setDesPartyId(message.getDesPartyId()); +// messageEvent.setSessionId(message.getProperty(Dict.SESSION_ID)); + eventDrivenConsumer.fireEvent(message); + } + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/constants/Direction.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/Direction.java similarity index 73% rename from java/osx/broker/src/main/java/com/osx/broker/constants/Direction.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/Direction.java index 3b0e0bb7f7..6062936c1b 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/constants/Direction.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/Direction.java @@ -1,4 +1,4 @@ -package com.osx.broker.constants; +package org.fedai.osx.broker.constants; public enum Direction { // RECEIVE, diff --git a/java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/MessageFlag.java similarity index 85% rename from java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/MessageFlag.java index ccd8c7fc52..3bbb807969 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/constants/MessageFlag.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/constants/MessageFlag.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
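MsgEventDispatchCallback above is the built-in queue hook that forwards each appended message to the topic's EventDrivenConsumer. Because MsgEventCallback is a @FunctionalInterface, additional hooks can be written as lambdas; a trivial sketch follows (where such hooks get registered on the queue is outside this hunk).

```java
import org.fedai.osx.broker.callback.MsgEventCallback;
import org.fedai.osx.broker.message.MessageExt;
import org.fedai.osx.broker.queue.TransferQueue;

public class LoggingHookSketch {
    // A trivial hook: log every message appended to a transfer queue.
    static final MsgEventCallback LOGGING_HOOK = (TransferQueue queue, MessageExt message) ->
            System.out.println("queue " + queue.getTransferId() + " got message on topic " + message.getTopic());
}
```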
*/ -package com.osx.broker.constants; +package org.fedai.osx.broker.constants; public enum MessageFlag { - MSG(0), ERROR(1), COMPELETED(2); + SENDMSG(0), ERROR(1), COMPELETED(2),BACKMSG(3); private int flag; @@ -28,11 +28,13 @@ private MessageFlag(int flag) { static public MessageFlag getMessageFlag(int flag) { switch (flag) { case 0: - return MSG; + return SENDMSG; case 1: return ERROR; case 2: return COMPELETED; + case 3: + return BACKMSG; default: return null; } diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/ConsumerManager.java similarity index 62% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/ConsumerManager.java index 9c8c58231e..d471553ea4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/ConsumerManager.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/ConsumerManager.java @@ -13,34 +13,48 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.consumer; +package org.fedai.osx.broker.consumer; + import com.google.common.collect.Maps; -import com.osx.core.frame.ServiceThread; +import org.fedai.osx.core.frame.Lifecycle; +import org.fedai.osx.core.frame.ServiceThread; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -public class ConsumerManager { +public class ConsumerManager implements Lifecycle { Logger logger = LoggerFactory.getLogger(ConsumerManager.class); - ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1); ConcurrentHashMap unaryConsumerMap = new ConcurrentHashMap<>(); - ConcurrentHashMap streamConsumerMap = new ConcurrentHashMap<>(); - ConcurrentHashMap redirectConsumerMap = new ConcurrentHashMap<>(); + ConcurrentHashMap eventDrivenConsumerMap = new ConcurrentHashMap<>(); AtomicLong consumerIdIndex = new AtomicLong(0); + ServiceThread monitorThread = new ServiceThread() { + @Override + public String getServiceName() { + return "monitor"; + } + + @Override + public void run() { + while (true) { + try { + report(); + } catch (Exception igore) { + } + this.waitForRunning(60000); + } + } + }; ServiceThread longPullingThread = new ServiceThread() { @Override public String getServiceName() { return "longPullingThread"; } - @Override public void run() { int interval = 200; @@ -51,13 +65,21 @@ public void run() { longPullingWaitingSize.set(0); answerCount.set(0); unaryConsumerMap.forEach((transferId, unaryConsumer) -> { + try { + //TODO 当transferId 对应的grpc连接断开之后从unaryConsumerMap中移除该transferId +// if(context.getGprcContext().isCancelled()){ +// unaryConsumerMap.remove(transferId); +// return; +// } + answerCount.addAndGet(unaryConsumer.answerLongPulling()); longPullingWaitingSize.addAndGet(unaryConsumer.getLongPullingQueueSize()); } catch (Exception igore) { - + igore.printStackTrace(); } }); + if (longPullingWaitingSize.get() > 0) { interval = 500; } else { @@ -66,15 +88,16 @@ public void run() { } catch (Exception igore) { } + this.waitForRunning(interval); } } }; - public ConsumerManager() { longPullingThread.start(); + 
monitorThread.start(); } public Map getUnaryConsumerMap() { @@ -95,6 +118,24 @@ public UnaryConsumer getUnaryConsumer(String transferId) { return unaryConsumerMap.get(transferId); } + public EventDrivenConsumer getEventDrivenConsumer(String topic){ + + return this.eventDrivenConsumerMap.get(topic); + + } + + public EventDrivenConsumer createEventDrivenConsumer(String topic, GrpcEventHandler eventHandler){ + logger.info("create event driven consumer , {}",topic); + if (eventDrivenConsumerMap.get(topic) == null) { + EventDrivenConsumer eventDrivenConsumer = + new EventDrivenConsumer(consumerIdIndex.get(), topic,eventHandler); + eventDrivenConsumerMap.putIfAbsent(topic, eventDrivenConsumer); + return eventDrivenConsumerMap.get(topic); + } else { + return eventDrivenConsumerMap.get(topic); + } + } + public UnaryConsumer getOrCreateUnaryConsumer(String transferId) { if (unaryConsumerMap.get(transferId) == null) { UnaryConsumer unaryConsumer = @@ -106,50 +147,39 @@ public UnaryConsumer getOrCreateUnaryConsumer(String transferId) { } } - public StreamConsumer getOrCreateStreamConsumer(String transferId) { - - if (streamConsumerMap.get(transferId) == null) { - StreamConsumer streamConsumer = new StreamConsumer(consumerIdIndex.get(), transferId); - streamConsumerMap.putIfAbsent(transferId, streamConsumer); - return streamConsumerMap.get(transferId); - } else { - return streamConsumerMap.get(transferId); + public void onComplete(String transferId) { + if(this.unaryConsumerMap.contains(transferId)) { + this.unaryConsumerMap.get(transferId).destroy(); + this.unaryConsumerMap.remove(transferId); + } + if(this.eventDrivenConsumerMap.contains(transferId)){ + this.eventDrivenConsumerMap.get(transferId).destroy(); + // this.eventDrivenConsumerMap.remove(transferId); } + + logger.info("remove consumer {}", transferId); } - public synchronized RedirectConsumer getOrCreateRedirectConsumer(String resource) { - logger.info("getOrCreateRedirectConsumer {}", resource); - if (unaryConsumerMap.get(resource) == null) { - RedirectConsumer redirectConsumer = - new RedirectConsumer(consumerIdIndex.get(), resource); - unaryConsumerMap.putIfAbsent(resource, redirectConsumer); - return (RedirectConsumer) unaryConsumerMap.get(resource); - } else { - return (RedirectConsumer) unaryConsumerMap.get(resource); - } + private void checkAndClean() { + } + @Override + public void init() { + + -// public synchronized PushConsumer getOrCreatePushConsumer(String transferId){ -// if (pushConsumerMap.get(transferId) == null) { -// PushConsumer pushConsumer = -// new PushConsumer(consumerIdIndex.get(), transferId); -// pushConsumerMap.putIfAbsent(transferId,pushConsumer); -// return pushConsumerMap.get(transferId); -// } else { -// return pushConsumerMap.get(transferId); -// } -// } - public void onComplete(String transferId) { - this.unaryConsumerMap.remove(transferId); - logger.info("remove consumer {}", transferId); } - /** - * - */ - private void checkAndClean() { + @Override + public void start() { + + } + + @Override + public void destroy() { + } public static class ReportData { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDrivenConsumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDrivenConsumer.java new file mode 100644 index 0000000000..cf93b0e87a --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDrivenConsumer.java @@ -0,0 +1,58 @@ +package org.fedai.osx.broker.consumer; + +import org.fedai.osx.broker.message.MessageExt; 
+import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class EventDrivenConsumer extends LocalQueueConsumer { + + Logger logger = LoggerFactory.getLogger(EventDrivenConsumer.class); + GrpcEventHandler eventHandler; + // Disruptor disruptor; + + public EventDrivenConsumer(long consumerId, String topic,GrpcEventHandler eventHandler){ + + super(consumerId,topic); + this.eventHandler = eventHandler; +// disruptor = new Disruptor(() -> new MessageEvent(), +// 16, DaemonThreadFactory.INSTANCE, +// ProducerType.SINGLE, new BlockingWaitStrategy()); +// disruptor.handleEventsWith(eventHandler); +// disruptor.start(); + + logger.info("new EventDrivenConsumer {}",topic); + + } +// public static final EventTranslatorOneArg TRANSLATOR = +// (event, sequence, arg) -> { +// event.setTopic(arg.getTopic()); +// event.setDesPartyId(arg.getDesPartyId()); +// event.setSrcComponent(arg.getSrcComponent()); +// event.setSrcPartyId(arg.getSrcPartyId()); +// event.setDesComponent(arg.getDesComponent()); +// event.setSessionId(arg.getSessionId()); +// }; + + public void fireEvent(MessageExt msg) throws Exception { + //disruptor.publishEvent((EventTranslatorOneArg) TRANSLATOR,event); + eventHandler.onEvent(msg); + } + + + @Override + public void destroy() { + + + // this.disruptor.shutdown(); + } + + + public static void main(String[] args){ +// MessageEvent messageEvent = new MessageEvent(); +// EventDrivenConsumer eventDrivenConsumer = new EventDrivenConsumer(0,"test",new MockDesGrpcEventHandler()); +// eventDrivenConsumer.fireEvent(messageEvent); + + } + +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDriverRule.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDriverRule.java new file mode 100644 index 0000000000..27aaeb8802 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/EventDriverRule.java @@ -0,0 +1,8 @@ +package org.fedai.osx.broker.consumer; + +import org.fedai.osx.broker.queue.TransferQueue; + +@FunctionalInterface +public interface EventDriverRule { + boolean isMatch(TransferQueue queue); +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/GrpcEventHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/GrpcEventHandler.java new file mode 100644 index 0000000000..5b432a87f3 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/GrpcEventHandler.java @@ -0,0 +1,157 @@ +package org.fedai.osx.broker.consumer; + + +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.ptp.TargetMethod; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; + +public abstract class GrpcEventHandler { + + Logger logger = LoggerFactory.getLogger(GrpcEventHandler.class); + public GrpcEventHandler(String provider){ + this.provider = provider; + } + protected TransferStatus transferStatus = TransferStatus.INIT; + protected String provider; + protected String srcPartyId; + protected String desPartyId; + protected String 
sessionId; + protected String srcComponent; + protected String desComponent; + protected String topic; + protected String backTopic; + protected RouterInfo backRouterInfo; + protected FateContext context; + + public void sendBackException(ExceptionInfo e){ + if(transferStatus==TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, MessageFlag.COMPELETED, sessionId,e.toString().getBytes(StandardCharsets.UTF_8) ); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + + }else{ + logger.error("!!!!!!!!!transferStatus is {}",transferStatus); + } + }; + + public void sendBackCompleted(){ + if(transferStatus== TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8)); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + }else{ + logger.error("!!!!!!!!!transferStatus is {}",transferStatus); + } + } + + public void sendBackMsg(byte[] data){ + if(transferStatus== TransferStatus.TRANSFERING) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, + TargetMethod.PRODUCE_MSG.name(), backTopic, MessageFlag.SENDMSG, sessionId, data); + TransferUtil.redirect(context,inboundBuilder.build(),backRouterInfo,true); + }else{ + logger.error("!!!!!!!!!transferStatus is {}",transferStatus); + } + } + + protected void init(MessageExt message){ + + if(transferStatus==TransferStatus.INIT){ + try { + +// messageEvent.setDesComponent(message.getProperty(Dict.DES_COMPONENT)); +// messageEvent.setSrcComponent(message.getProperty(Dict.SOURCE_COMPONENT)); +// messageEvent.setSrcPartyId(message.getSrcPartyId()); +// messageEvent.setDesPartyId(message.getDesPartyId()); +// messageEvent.setSessionId(message.getProperty(Dict.SESSION_ID)); + context = new FateContext(); + topic = message.getTopic(); + desComponent = message.getProperty(Dict.DES_COMPONENT); + srcComponent = message.getProperty(Dict.SOURCE_COMPONENT); + srcPartyId = message.getSrcPartyId(); + desPartyId = message.getDesPartyId(); + sessionId = message.getProperty(Dict.SESSION_ID); + if (topic.startsWith(Dict.STREAM_SEND_TOPIC_PREFIX)) { + backTopic = topic.replaceAll(Dict.STREAM_SEND_TOPIC_PREFIX, Dict.STREAM_BACK_TOPIC_PREFIX); + } else if (topic.startsWith(Dict.STREAM_BACK_TOPIC_PREFIX)) { + backTopic = topic.replaceAll(Dict.STREAM_BACK_TOPIC_PREFIX, Dict.STREAM_SEND_TOPIC_PREFIX); + } + backRouterInfo = ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER).route(desPartyId,"",srcPartyId,""); + handleInit(message); + transferStatus = TransferStatus.TRANSFERING; + }catch(Throwable e){ + logger.error("grpc event handler init error",e); + transferStatus = TransferStatus.ERROR; + } + } + + + } + + + + public void onEvent(MessageExt messageExt) throws Exception { + + // String topic = event.getTopic(); + +// messageEvent.setDesComponent(message.getProperty(Dict.DES_COMPONENT)); +// messageEvent.setSrcComponent(message.getProperty(Dict.SOURCE_COMPONENT)); +// messageEvent.setSrcPartyId(message.getSrcPartyId()); +// messageEvent.setDesPartyId(message.getDesPartyId()); +// messageEvent.setSessionId(message.getProperty(Dict.SESSION_ID)); + + // logger.info("======event {}",event); + init(messageExt); + 
if(transferStatus==TransferStatus.TRANSFERING) { +// EventDrivenConsumer consumer = ServiceContainer.consumerManager.getEventDrivenConsumer(topic); +// TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(new FateContext(), -1); +// +// if (transferQueueConsumeResult.getCode().equals(StatusCode.SUCCESS)) { +// long index = transferQueueConsumeResult.getRequestIndex(); +// //ack 的位置需要调整 +// consumer.ack(index); +// MessageExt messageExt = transferQueueConsumeResult.getMessage(); +// + int flag = messageExt.getFlag(); + // logger.info("message flag {}", flag); + switch (flag) { + //msg + case 0: + handleMessage(messageExt); + break; + //error + case 1: + handleError(messageExt); + break; + //completed + case 2: + handleComplete(messageExt); + break; + default: + ; + } +// } else { +// // logger.warn("consume error {}", transferQueueConsumeResult); +// } + } + } + + protected abstract void handleMessage(MessageExt message); + protected abstract void handleError(MessageExt message); + protected abstract void handleComplete(MessageExt message); + protected abstract void handleInit(MessageExt message); + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/LocalQueueConsumer.java similarity index 84% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/LocalQueueConsumer.java index 4521f4a909..27321e7558 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/LocalQueueConsumer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/LocalQueueConsumer.java @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. 
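Review note on the backTopic derivation in GrpcEventHandler.init above: String.replaceAll compiles its first argument as a regex and substitutes every match, not just a leading prefix, so a startsWith/substring swap is the safer shape (the buildBackTopic helper later in this patch already takes that form). A minimal sketch, assuming prefix constants like the Dict.STREAM_*_TOPIC_PREFIX strings used above:

final class TopicPrefixSketch {
    // Swap a known literal prefix without regex semantics.
    static String swapPrefix(String topic, String fromPrefix, String toPrefix) {
        if (topic != null && topic.startsWith(fromPrefix)) {
            return toPrefix + topic.substring(fromPrefix.length());
        }
        return topic;
    }
}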
*/ -package com.osx.broker.consumer; +package org.fedai.osx.broker.consumer; -import com.osx.broker.ServiceContainer; -import com.osx.broker.message.SelectMappedBufferResult; -import com.osx.broker.queue.Consumer; -import com.osx.broker.queue.TransferQueue; -import com.osx.core.constant.StatusCode; -import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; -import com.osx.core.exceptions.AckIndexException; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.message.SelectMappedBufferResult; +import org.fedai.osx.broker.queue.Consumer; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.exceptions.AckIndexException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,6 +63,7 @@ public boolean checkMsgIsArrive(long consumeOffset) { TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(transferId); if (transferQueue != null) { long indexFileOffset = transferQueue.getIndexQueue().getLogicOffset().get(); + logger.info("topic {} need consume {} , {} inqueue",transferId,consumeOffset, indexFileOffset); return consumeOffset <= indexFileOffset; } return false; @@ -124,4 +125,18 @@ public synchronized TransferQueue.TransferQueueConsumeResult consume(Context con } + @Override + public void init() { + + } + + @Override + public void start() { + + } + + @Override + public void destroy() { + + } } diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/MessageEvent.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/MessageEvent.java new file mode 100644 index 0000000000..580937e5af --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/MessageEvent.java @@ -0,0 +1,13 @@ +package org.fedai.osx.broker.consumer; + +import lombok.Data; + +@Data +public class MessageEvent { + String srcPartyId; + String desPartyId; + String srcComponent; + String desComponent; + String topic; + String sessionId ; +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/RedirectConsumer.java similarity index 92% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/RedirectConsumer.java index 76e85dcd01..d38277b290 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/RedirectConsumer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/RedirectConsumer.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.consumer; +package org.fedai.osx.broker.consumer; -import com.osx.core.constant.TransferStatus; -import com.osx.core.router.RouterInfo; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.core.constant.TransferStatus; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/SourceGrpcEventHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/SourceGrpcEventHandler.java new file mode 100644 index 0000000000..3838914ef1 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/SourceGrpcEventHandler.java @@ -0,0 +1,65 @@ +package org.fedai.osx.broker.consumer; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.grpc.stub.StreamObserver; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.utils.JsonUtil; + +/** + * 放在源头,用于接听远端返回 + */ +public class SourceGrpcEventHandler extends GrpcEventHandler{ + + com.google.protobuf.Parser parser; + StreamObserver respStreamObserver; + + + public SourceGrpcEventHandler(StreamObserver respStreamObserver, + com.google.protobuf.Parser parser){ + super(MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + this.parser=parser; + this.respStreamObserver = respStreamObserver; + } + + @Override + protected void handleMessage(MessageExt message) { + + try { + Object data = parser.parseFrom(message.getBody()); + respStreamObserver.onNext(data); + } catch (InvalidProtocolBufferException e) { + logger.error(""); + } + } + + @Override + protected void handleError(MessageExt message) { + try { + ExceptionInfo exceptionInfo = JsonUtil.json2Object(message.getBody(), ExceptionInfo.class); + respStreamObserver.onError(new Throwable(exceptionInfo.getMessage())); + }finally { + String topic =message.getTopic(); + ServiceContainer.transferQueueManager.onCompleted(topic); + } + + } + + @Override + protected void handleComplete(MessageExt message) { + try { + respStreamObserver.onCompleted(); + }finally { + String topic =message.getTopic(); + ServiceContainer.transferQueueManager.onCompleted(topic); + } + + } + + @Override + protected void handleInit(MessageExt message) { + + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/StreamConsumer.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/StreamConsumer.java index e9cf06491b..d2705c6bea 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/StreamConsumer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/StreamConsumer.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.consumer; +package org.fedai.osx.broker.consumer; public class StreamConsumer extends LocalQueueConsumer { diff --git a/java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/UnaryConsumer.java similarity index 53% rename from java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/UnaryConsumer.java index e549831697..f273b77bd2 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/consumer/UnaryConsumer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/consumer/UnaryConsumer.java @@ -13,26 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.consumer; - -import com.osx.broker.ServiceContainer; -import com.osx.broker.queue.TransferQueue; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.utils.FlowLogUtil; +package org.fedai.osx.broker.consumer; + import io.grpc.stub.StreamObserver; import lombok.Data; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ErrorMessageUtil; +import org.fedai.osx.core.exceptions.TransferQueueNotExistException; +import org.fedai.osx.core.utils.FlowLogUtil; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.servlet.http.HttpServletResponse; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; -import static com.osx.broker.util.TransferUtil.buildResponse; - public class UnaryConsumer extends LocalQueueConsumer { Logger logger = LoggerFactory.getLogger(UnaryConsumer.class); @@ -42,11 +44,11 @@ public UnaryConsumer(long consumerId, String transferId) { super(consumerId, transferId); TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(transferId); if (transferQueue != null) { - transferQueue.registeDestoryCallback(() -> { + transferQueue.registerDestoryCallback(() -> { ServiceContainer.consumerManager.onComplete(transferId); }); } - longPullingQueue = new ConcurrentLinkedQueue(); + longPullingQueue = new ConcurrentLinkedQueue<>(); } public int getLongPullingQueueSize() { @@ -70,12 +72,29 @@ public synchronized int answerLongPulling() { TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(transferId); List reputList = null; while (this.longPullingQueue.size() > 0) { + LongPullingHold longPullingHold = this.longPullingQueue.poll(); try { - long indexFileOffset = transferQueue.getIndexQueue().getLogicOffset().get(); - LongPullingHold longPullingHold = this.longPullingQueue.poll(); - //StreamObserver streamObserver = longPullingHold.getStreamObserver(); + io.grpc.Context grpcContext = longPullingHold.getGrpcContext(); + if(grpcContext!=null){ + if(grpcContext.isCancelled()){ + logger.error("topic {} consumer grpc context is cancelled",transferId); + continue; + } + } + long current= System.currentTimeMillis(); long needOffset = longPullingHold.getNeedOffset(); - Context context = longPullingHold.getContext(); + if(transferQueue==null){ + // TODO: 2023/7/24 
这里需要通知阻塞的客户端,最好是由队列清理时主动通知客户端 + longPullingHold.throwException(new TransferQueueNotExistException()); + continue; + } + + if( longPullingHold.getExpireTimestamp()>0&¤t>longPullingHold.getExpireTimestamp()){ + handleExpire(longPullingHold); + continue; + } + + FateContext context = longPullingHold.getContext(); context.setActionType(ActionType.LONG_PULLING_ANSWER.getAlias()); TransferQueue.TransferQueueConsumeResult consumeResult = null; if (needOffset <= 0) { @@ -92,17 +111,15 @@ public synchronized int answerLongPulling() { * client 传入的offset 小于等于index,可以消费 */ consumeResult = this.consume(context, needOffset); - } } if (consumeResult != null) { if (consumeResult.getMessage() != null && consumeResult.getMessage().getBody() != null) context.setDataSize(consumeResult.getMessage().getBody().length); - Osx.Outbound consumeResponse = buildResponse(StatusCode.SUCCESS, "success", consumeResult); + Osx.Outbound consumeResponse = TransferUtil.buildResponse(StatusCode.SUCCESS, "success", consumeResult); answerCount++; - longPullingHold.getStreamObserver().onNext(consumeResponse); - longPullingHold.getStreamObserver().onCompleted(); + longPullingHold.answer(consumeResponse); context.setTopic(transferQueue.getTransferId()); context.setReturnCode(StatusCode.SUCCESS); context.setRequestMsgIndex(consumeResult.getRequestIndex()); @@ -115,10 +132,10 @@ public synchronized int answerLongPulling() { if (reputList == null) reputList = new ArrayList<>(); reputList.add(longPullingHold); - } - } catch (Exception igore) { - + } catch (Exception e) { + logger.error("topic {} answer long pulling error ",transferId,e); + longPullingHold.throwException(e); } } if (reputList != null) { @@ -127,11 +144,49 @@ public synchronized int answerLongPulling() { return answerCount; } + private void handleExpire(LongPullingHold longPullingHold){ + Osx.Outbound consumeResponse = TransferUtil.buildResponse(StatusCode.CONSUME_MSG_TIMEOUT, "CONSUME_MSG_TIMEOUT", null); + longPullingHold.answer(consumeResponse); + } + @Data public static class LongPullingHold { - Context context; + Logger logger = LoggerFactory.getLogger(LongPullingHold.class); + FateContext context; + io.grpc.Context grpcContext; StreamObserver streamObserver; + HttpServletResponse httpServletResponse; + long expireTimestamp; long needOffset; + + public void answer(Osx.Outbound consumeResponse){ + + + if(streamObserver!=null) { + + streamObserver.onNext(consumeResponse); + streamObserver.onCompleted(); + }else if(httpServletResponse!=null){ + TransferUtil.writeHttpRespose(httpServletResponse,consumeResponse.getCode(),consumeResponse.getMessage(),consumeResponse.getPayload()!=null?consumeResponse.getPayload().toByteArray():null); + } + } + public void throwException(Throwable throwable){ + logger.info("============ answer throw exception========"); + try { + if (streamObserver != null) { + streamObserver.onError(ErrorMessageUtil.toGrpcRuntimeException(throwable)); + streamObserver.onCompleted(); + } else if (httpServletResponse != null) { + + // TODO: 2023/7/24 http 处理未添加 + // TransferUtil.writeHttpRespose(httpServletResponse,consumeResponse.getCode(),consumeResponse.getMessage(),consumeResponse.getPayload()!=null?consumeResponse.getPayload().toByteArray():null); + } + }catch(Exception e){ + logger.error("send error back to consumer , occury error",e); + } + } + + } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/BaseProto.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/BaseProto.java similarity index 90% rename from 
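Review note on the answerLongPulling rework above: each queued hold is now screened for a cancelled gRPC context and an expiry deadline before any consume happens, which is the right order. A compact, self-contained sketch of those guards (only io.grpc.Context is a real dependency; the method and class names are illustrative):

import io.grpc.Context;

final class LongPollGuardSketch {
    // Mirrors the added checks: skip holds whose caller has gone away,
    // and treat expireTimestamp <= 0 as "no deadline".
    static boolean shouldAnswer(Context grpcContext, long expireTimestamp, long nowMillis) {
        if (grpcContext != null && grpcContext.isCancelled()) {
            return false;
        }
        return expireTimestamp <= 0 || nowMillis <= expireTimestamp;
    }
}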
java/osx/broker/src/main/java/com/osx/broker/eggroll/BaseProto.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/BaseProto.java index 2759e49047..de3a5a9dc3 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/BaseProto.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/BaseProto.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; public abstract class BaseProto { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ClusterManagerClient.java similarity index 93% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ClusterManagerClient.java index 7c0c772d4b..ab30d75125 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ClusterManagerClient.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ClusterManagerClient.java @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.exceptions.SysException; import com.webank.eggroll.core.command.Command; import com.webank.eggroll.core.meta.Meta; +import org.fedai.osx.core.exceptions.RemoteRpcException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,8 +90,8 @@ public ErStore getOrCreateStore(ErStore input) { try { Meta.Store oriStore = Meta.Store.parseFrom(result.get(0)); resultErStore = ErStore.parseFromPb(oriStore); - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); + } catch (InvalidProtocolBufferException igore) { + } } return resultErStore; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandClient.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandClient.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandClient.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandClient.java index fb98d14754..059a828b04 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandClient.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandClient.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.protobuf.AbstractMessageLite; import com.webank.eggroll.core.command.Command; @@ -65,7 +65,7 @@ public Command.CommandResponse call(CommandURI commandUri, BaseProto... baseProt .addAllArgs(Arrays.stream(baseProtos). 
map((element) -> ((AbstractMessageLite) element.toProto()).toByteString()).collect(Collectors.toList())) .build(); - logger.info("===call {} {} id {}", erEndpoint.host, erEndpoint.port, id); + ManagedChannel managedChannel = buildManagedChannel(erEndpoint.host, erEndpoint.port); CommandServiceGrpc.CommandServiceBlockingStub stub = CommandServiceGrpc.newBlockingStub(managedChannel); diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandURI.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandURI.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandURI.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandURI.java index f60c41e2d0..e7a0e94317 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/CommandURI.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/CommandURI.java @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; -import com.osx.core.constant.Dict; +package org.fedai.osx.broker.eggroll; + import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.constant.Dict; import java.net.URI; import java.net.URLDecoder; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErEndpoint.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErEndpoint.java index 2e5bc670db..ad7e1fd6fb 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErEndpoint.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErEndpoint.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.webank.eggroll.core.meta.Meta; public class ErEndpoint extends BaseProto { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErFunctor.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErFunctor.java index 886611f486..356fc6f27b 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErFunctor.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErFunctor.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.protobuf.ByteString; import com.webank.eggroll.core.meta.Meta; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErJob.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErJob.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java index 28d2913327..6058d27cd2 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErJob.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.common.collect.Lists; import com.webank.eggroll.core.meta.Meta; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErPartition.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartition.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErPartition.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartition.java index f4281bd3c0..c010631080 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErPartition.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartition.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.webank.eggroll.core.meta.Meta; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErProcessor.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErProcessor.java index b0a1054b9d..23e2711ae5 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErProcessor.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErProcessor.java @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; -import com.osx.core.constant.Dict; +package org.fedai.osx.broker.eggroll; + import com.webank.eggroll.core.meta.Meta; +import org.fedai.osx.core.constant.Dict; import java.util.concurrent.ConcurrentHashMap; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErRollSiteHeader.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErRollSiteHeader.java index ebb02948cc..15dbfe0cf7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErRollSiteHeader.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErRollSiteHeader.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.common.collect.Lists; import com.webank.eggroll.core.transfer.Transfer; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSession.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSession.java similarity index 88% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSession.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSession.java index 23b4260d98..e0845e1a21 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSession.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSession.java @@ -13,19 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.config.MetaInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; import java.util.Map; -import static com.osx.core.config.MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP; -import static com.osx.core.config.MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT; - public class ErSession { Logger logger = LoggerFactory.getLogger(ErSession.class); @@ -45,44 +44,28 @@ public ErSession(String sessionId, boolean createIfNotExists) { this.sessionId = sessionId; this.createIfNotExists = createIfNotExists; - clusterManagerClient = new ClusterManagerClient(new CommandClient(new ErEndpoint(PROPERTY_EGGROLL_CLUSTER_MANANGER_IP, PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT.intValue()))); + clusterManagerClient = new ClusterManagerClient(new CommandClient(new ErEndpoint(MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP, MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT.intValue()))); ErSessionMeta erSessionMetaArgs = new ErSessionMeta(); - erSessionMetaArgs.setId(sessionId); erSessionMetaArgs.setName(name); erSessionMetaArgs.setStatus(status.name()); erSessionMetaArgs.setTag(tag); erSessionMetaArgs.setProcessors(this.processors); erSessionMetaArgs.setOptions(options); - logger.info("create ErSession ============{}", erSessionMetaArgs); + if (createIfNotExists) { if (processors.isEmpty()) { erSessionMeta = clusterManagerClient.getOrCreateSession(erSessionMetaArgs); } else { - - erSessionMeta = clusterManagerClient.registerSession(erSessionMetaArgs); } } else { erSessionMeta = clusterManagerClient.getSession(erSessionMetaArgs); - } - - logger.info("===============dddddd=============={} ", erSessionMeta); - processors = erSessionMeta.getProcessors(); - status = SessionStatus.valueOf(erSessionMeta.getStatus()); - // processors.foreach(p => { -// val processorType = p.processorType -// if (processorType.toLowerCase().startsWith("egg_")) { -// eggs_buffer.getOrElseUpdate(p.serverNodeId, ArrayBuffer[ErProcessor]()) += p -// } else if (processorType.toLowerCase().startsWith("roll_")) { -// rolls_buffer += p -// } else { -// throw new IllegalArgumentException(s"processor type ${processorType} not supported in roll pair") -// } -// }) - + if(StringUtils.isNotEmpty(erSessionMeta.getStatus())) { + status = SessionStatus.valueOf(erSessionMeta.getStatus()); + } processors.forEach((processor -> { if (processor.getProcessorType().toLowerCase().startsWith("egg_")) { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java 
b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSessionMeta.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSessionMeta.java index c76665fa6d..78e8684ce2 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErSessionMeta.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSessionMeta.java @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; -import com.osx.core.utils.JsonUtil; +package org.fedai.osx.broker.eggroll; + import com.webank.eggroll.core.meta.Meta; +import org.fedai.osx.core.utils.JsonUtil; import java.util.List; import java.util.Map; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStore.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStore.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java index 9a402300ea..105b99dd98 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStore.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; + import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.osx.core.utils.JsonUtil; import com.webank.eggroll.core.meta.Meta; +import org.fedai.osx.core.utils.JsonUtil; import java.util.List; import java.util.Map; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java index 7fbf16f34a..4ee3975727 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErStoreLocator.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; -import com.osx.core.utils.JsonUtil; +package org.fedai.osx.broker.eggroll; + import com.webank.eggroll.core.meta.Meta; import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.utils.JsonUtil; public class ErStoreLocator extends BaseProto { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErTask.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErTask.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/ErTask.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErTask.java index cf42a4930f..3c957e99a1 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/ErTask.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErTask.java @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
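Review note on the ErSession constructor changes a little above: guarding SessionStatus.valueOf behind StringUtils.isNotEmpty avoids the empty-string case, but valueOf still throws IllegalArgumentException for any unexpected status string returned by the cluster manager. A small defensive parse, JDK-only and illustrative:

import java.util.Optional;

final class EnumParseSketch {
    // Returns empty for null, blank, or unknown names instead of throwing.
    static <E extends Enum<E>> Optional<E> parse(Class<E> type, String raw) {
        if (raw == null || raw.isEmpty()) {
            return Optional.empty();
        }
        try {
            return Optional.of(Enum.valueOf(type, raw));
        } catch (IllegalArgumentException e) {
            return Optional.empty();
        }
    }
}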
*/ -package com.osx.broker.eggroll; -import com.osx.core.constant.Dict; +package org.fedai.osx.broker.eggroll; + import com.webank.eggroll.core.meta.Meta; +import org.fedai.osx.core.constant.Dict; import java.util.ArrayList; import java.util.List; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/EventDriverMsgManager.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/EventDriverMsgManager.java new file mode 100644 index 0000000000..274c19395d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/EventDriverMsgManager.java @@ -0,0 +1,59 @@ +package org.fedai.osx.broker.eggroll; + +import com.google.common.collect.Lists; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.callback.CreateUserCallback; +import org.fedai.osx.broker.callback.MsgEventDispatchCallback; +import org.fedai.osx.broker.consumer.ConsumerManager; +import org.fedai.osx.broker.queue.TransferQueueManager; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.frame.Lifecycle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EventDriverMsgManager implements Lifecycle { + + Logger logger = LoggerFactory.getLogger(EventDriverMsgManager.class); + ConsumerManager consumerManager=null; + TransferQueueManager transferQueueManager=null; + public EventDriverMsgManager(ConsumerManager consumerManager,TransferQueueManager transferQueueManager){ + this.consumerManager = consumerManager; + this.transferQueueManager = transferQueueManager; + } + + + + + @Override + public void init() { + MsgEventDispatchCallback dispatchCallback = new MsgEventDispatchCallback(); + ServiceContainer.transferQueueManager.addMsgCallBackRule((queue -> { + if(queue.getTransferId().startsWith(Dict.STREAM_SEND_TOPIC_PREFIX)){ + return true; + } + return false; + }), Lists.newArrayList(new CreateUserCallback(PushEventHandler.class),dispatchCallback)); + ServiceContainer.transferQueueManager.addMsgCallBackRule((queue -> { + if(queue.getTransferId().startsWith(Dict.STREAM_BACK_TOPIC_PREFIX)){ + return true; + } + return false; + }), Lists.newArrayList(dispatchCallback)); + + } + + + + + + @Override + public void start() { + + + } + + @Override + public void destroy() { + + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/IdUtils.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/IdUtils.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/IdUtils.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/IdUtils.java index 431cd008db..ee2e6e6040 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/IdUtils.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/IdUtils.java @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
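Review note on EventDriverMsgManager.init above: both registered rules only test a topic prefix, so a tiny factory keeps each registration down to one expression. EventDriverRule and TransferQueue are the types added in this patch and getTransferId() is used the same way above; the factory itself is only a sketch:

final class TopicRuleSketch {
    // Build a prefix-matching rule once instead of repeating the lambda per registration.
    static EventDriverRule topicPrefixRule(String prefix) {
        return queue -> queue.getTransferId() != null && queue.getTransferId().startsWith(prefix);
    }
}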
*/ -package com.osx.broker.eggroll; -import com.osx.broker.util.TimeUtils; +package org.fedai.osx.broker.eggroll; + import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.broker.util.TimeUtils; public class IdUtils { private static String job = "job"; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/MetaCommnads.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/MetaCommnads.java index 4adba482dd..8151e65b4c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/MetaCommnads.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/MetaCommnads.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public class MetaCommnads { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PartitionerTypes.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PartitionerTypes.java index 0001956e0a..8baa78d177 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PartitionerTypes.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PartitionerTypes.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public enum PartitionerTypes { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PushEventHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PushEventHandler.java new file mode 100644 index 0000000000..8761026ee3 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PushEventHandler.java @@ -0,0 +1,292 @@ +package org.fedai.osx.broker.eggroll; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import com.webank.eggroll.core.command.Command; +import com.webank.eggroll.core.meta.Meta; +import com.webank.eggroll.core.transfer.Transfer; +import com.webank.eggroll.core.transfer.TransferServiceGrpc; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.consumer.GrpcEventHandler; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.*; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.ptp.TargetMethod; +import 
org.fedai.osx.core.utils.ToStringUtils; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +public class PushEventHandler extends GrpcEventHandler { + Logger logger = LoggerFactory.getLogger(PushEventHandler.class); + public PushEventHandler(){ + super(MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + } + TransferStatus transferStatus= TransferStatus.INIT; + FateContext context = new FateContext(); + RouterInfo routerInfo ; + Proxy.Metadata metadata; + String brokerTag; + ErRollSiteHeader rsHeader = null; + CountDownLatch finishLatch; + StreamObserver putBatchSinkPushReqSO; + String topic = null; + String backTopic = null; + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub backBlockingStub; + String desRole = null; + String srcRole = null; + String sessionId = null; + RouterInfo revertRouterInfo; + + protected void handleError(MessageExt messageExt){ + //todo + // 需要构建新异常 + try { + + if (putBatchSinkPushReqSO != null) { + putBatchSinkPushReqSO.onError(new Exception()); + } + }finally { + String topic = messageExt.getTopic(); + ServiceContainer.transferQueueManager.onCompleted(topic); + } + } + + protected void handleComplete(MessageExt messageExt){ + try { + if (putBatchSinkPushReqSO != null) { + putBatchSinkPushReqSO.onCompleted(); + } + }finally { + String topic = messageExt.getTopic(); + ServiceContainer.transferQueueManager.onCompleted(topic); + } + + + } + + @Override + protected void handleInit(MessageExt event) { + + } + + protected void handleMessage(MessageExt messageExt){ + try { + Proxy.Packet packet=null; + try { + packet = Proxy.Packet.parseFrom(messageExt.getBody()); + }catch (Exception e){ + logger.error("parse packet error {}",new String(messageExt.getBody())); + } + if (transferStatus.equals(TransferStatus.INIT)) { + //初始化 + try { + initEggroll(packet,messageExt); + }catch(Exception e){ + logger.error("init eggroll error",e); + transferStatus=TransferStatus.ERROR; + } + } + if (!transferStatus.equals(TransferStatus.TRANSFERING)) { + throw new RemoteRpcException("eggroll init error"); + } + + Transfer.TransferHeader.Builder transferHeaderBuilder = Transfer.TransferHeader.newBuilder(); + Transfer.TransferHeader tbHeader = transferHeaderBuilder.setId((int) metadata.getSeq()) + .setTag(brokerTag) + .setExt(packet.getHeader().getExt()).build(); + Transfer.TransferBatch.Builder transferBatchBuilder = Transfer.TransferBatch.newBuilder(); + Transfer.TransferBatch tbBatch = transferBatchBuilder.setHeader(tbHeader) + .setData(packet.getBody().getValue()) + .build(); + putBatchSinkPushReqSO.onNext(tbBatch); + + }catch(Exception e){ + logger.error("handle msg error : "+ messageExt.getTopic(),e); + if(backBlockingStub!=null) { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(), + backTopic, MessageFlag.ERROR, sessionId, ErrorMessageUtil.buildRemoteRpcErrorMsg(1343,"kkkkk").getBytes()); + Osx.Outbound outbound = backBlockingStub.invoke(inboundBuilder.build()); + }else{ + logger.error("back stub is null"); + } + } + } + + private void initEggroll(Proxy.Packet firstRequest,MessageExt messageExt) throws Exception { + if (StringUtils.isEmpty(MetaInfo.PROPERTY_EGGROLL_CLUSTER_MANANGER_IP)) { + throw new 
SysException("eggroll cluter manager ip is not found"); + } + + topic = messageExt.getTopic(); + backTopic= buildBackTopic(topic); + metadata = firstRequest.getHeader(); + ByteString encodedRollSiteHeader = metadata.getExt(); + rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader)); + Integer partitionId = rsHeader.getPartitionId(); + brokerTag = "putBatch-" + rsHeader.getRsKey("#", "__rsk") + "-" + partitionId; + String oneLineStringMetadata = ToStringUtils.toOneLineString(metadata); + context.setActionType(ActionType.PUSH_EGGROLL.getAlias()); + String rsKey = rsHeader.getRsKey("#", "__rsk"); + sessionId = String.join("_", rsHeader.getRollSiteSessionId(), rsHeader.getDstRole(), rsHeader.getDstPartyId()); + context.setSessionId(sessionId); + desPartyId = metadata.getDst().getPartyId(); + desRole = metadata.getDst().getRole(); + srcRole = metadata.getSrc().getRole(); + srcPartyId = metadata.getSrc().getPartyId(); + //String srcPartyId, String srcRole, String dstPartyId, String desRole + revertRouterInfo = ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER).route(desPartyId,desRole,srcPartyId,srcRole); + if(revertRouterInfo==null){ + throw new NoRouterInfoException(srcPartyId+" can not found route info"); + } + if(Protocol.grpc.equals(revertRouterInfo.getProtocol())) { + ManagedChannel backChannel = GrpcConnectionFactory.createManagedChannel(revertRouterInfo, true); + backBlockingStub = PrivateTransferProtocolGrpc.newBlockingStub(backChannel); + context.putData(Dict.BLOCKING_STUB,backBlockingStub); + } + + + ErSession session = null; + try { + session = PutBatchSinkUtil.sessionCache.get(sessionId); + } catch (ExecutionException e) { + logger.error("get session error ", e); + } + if (!SessionStatus.ACTIVE.name().equals(session.getErSessionMeta().getStatus())) { + logger.error(""); + IllegalStateException error = new IllegalStateException("eggroll session "+sessionId+" status is "+session.getErSessionMeta().getStatus()); + // onError(error); + throw error; + } + + String namespace = rsHeader.getRollSiteSessionId(); + String name = rsKey; + RollPairContext ctx = new RollPairContext(session); + Map rpOptions = Maps.newHashMap(); + rpOptions.putAll(rsHeader.getOptions()); + rpOptions.put(Dict.TOTAL_PARTITIONS_SNAKECASE, rsHeader.getTotalPartitions().toString()); + + if (rsHeader.getDataType().equals("object")) { + rpOptions.put(Dict.SERDES, SerdesTypes.EMPTY.name()); + } else { + rpOptions.put(Dict.SERDES, rsHeader.getOptions().getOrDefault("serdes", SerdesTypes.PICKLE.name())); + } + + // table creates here + RollPair rp = ctx.load(namespace, name, rpOptions); + ErPartition partition = rp.getStore().getPartition(partitionId); + ErProcessor egg = ctx.getErSession().routeToEgg(partition); + String jobId = IdUtils.generateJobId(ctx.getErSession().getSessionId(), brokerTag, "-"); + Map jobOptions = new HashMap<>(); + + jobOptions.putAll(rsHeader.getOptions()); + jobOptions.put(SessionConfKeys.CONFKEY_SESSION_ID, ctx.getErSession().getSessionId()); + ErJob job = new ErJob( + jobId, + RollPair.PUT_BATCH, + Lists.newArrayList(rp.getStore()), + Lists.newArrayList(rp.getStore()), + Lists.newArrayList(), + jobOptions); + + ErTask task = new ErTask(brokerTag, + RollPair.PUT_BATCH, + Lists.newArrayList(partition), + Lists.newArrayList(partition), + job); + + Future commandFuture = RollPairContext.executor.submit(() -> { + CommandClient commandClient = new CommandClient(egg.getCommandEndpoint()); + Command.CommandResponse 
commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task); + long begin = System.currentTimeMillis(); + try { + Meta.Task taskMeta = Meta.Task.parseFrom(commandResponse.getResultsList().get(0)); + ErTask erTask = ErTask.parseFromPb(taskMeta); + long now = System.currentTimeMillis(); + return erTask; + } catch (InvalidProtocolBufferException igore) { + + } + return null; + }); + routerInfo = new RouterInfo(); + context.setRouterInfo(routerInfo); + routerInfo.setHost(egg.getTransferEndpoint().getHost()); + routerInfo.setPort(egg.getTransferEndpoint().getPort()); + context.setSrcPartyId(routerInfo.getSourcePartyId()); + context.setDesPartyId(routerInfo.getDesPartyId()); + ManagedChannel eggChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,false); + TransferServiceGrpc.TransferServiceStub stub = TransferServiceGrpc.newStub(eggChannel); + StreamObserver eggSiteServicerPushRespSO; + putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, new StreamObserver(){ + + TransferStatus transferStatus = TransferStatus.INIT; + + private void init(){ + transferStatus= TransferStatus.TRANSFERING; + } + + @Override + public void onNext(Proxy.Metadata metadata) { + //将其对调后再查路由 + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId,srcPartyId,TargetMethod.PRODUCE_MSG.name(), + backTopic,MessageFlag.SENDMSG,sessionId, metadata.toByteString().toByteArray()); + TransferUtil.redirect(context,inboundBuilder.build(),revertRouterInfo,true); + } + + @Override + public void onError(Throwable throwable) { + ExceptionInfo exceptionInfo = new ExceptionInfo(); + exceptionInfo.setMessage(throwable.getMessage()); + String message = throwable.getMessage(); + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(), + backTopic, MessageFlag.SENDMSG, sessionId, exceptionInfo.toString().getBytes(StandardCharsets.UTF_8)); + TransferUtil.redirect(context,inboundBuilder.build(),revertRouterInfo,true); + + } + + @Override + public void onCompleted() { + /** + * 完成回调 + */ + try { + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(provider,desPartyId, srcPartyId, TargetMethod.PRODUCE_MSG.name(), + backTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8)); + Osx.Outbound result =TransferUtil.redirect(context, inboundBuilder.build(), revertRouterInfo,true); + }catch (Exception e){ + logger.error("receive completed error",e); + } + } + }, finishLatch)); + transferStatus= TransferStatus.TRANSFERING; + } + + private String buildBackTopic(String oriTopic){ + int length = Dict.STREAM_SEND_TOPIC_PREFIX.length(); + return Dict.STREAM_BACK_TOPIC_PREFIX+oriTopic.substring(length); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java similarity index 92% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java index b521b96523..e19ce3a4e7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkPushRespSO.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java @@ -13,10 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
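Review note on the protobuf parsing above (ClusterManagerClient.getOrCreateStore and the task future inside PushEventHandler.initEggroll): the InvalidProtocolBufferException is now swallowed by an empty catch block, which silently yields null and hides the failure; logging and rethrowing keeps the signal without the old printStackTrace. A minimal sketch using only protobuf and slf4j (the helper name and wrapper exception are illustrative):

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Parser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

final class ProtoParseSketch {
    private static final Logger LOG = LoggerFactory.getLogger(ProtoParseSketch.class);

    // Parse or fail loudly: log what was being parsed, then surface the cause.
    static <T> T parseOrThrow(Parser<T> parser, ByteString payload, String what) {
        try {
            return parser.parseFrom(payload);
        } catch (InvalidProtocolBufferException e) {
            LOG.error("failed to parse {}", what, e);
            throw new IllegalStateException("failed to parse " + what, e);
        }
    }
}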
*/ -package com.osx.broker.eggroll; -import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.core.exceptions.ParameterException; -import com.osx.core.utils.ToStringUtils; +package org.fedai.osx.broker.eggroll; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import com.webank.eggroll.core.transfer.Transfer; import io.grpc.stub.StreamObserver; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkUtil.java similarity index 71% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkUtil.java index 5ca1fc72a5..f66d7c1f8c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/PutBatchSinkUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -40,18 +40,4 @@ public ErSession load(String sessionId) throws Exception { } ); - -// object PutBatchSinkUtils { -// val sessionCache: LoadingCache[String, ErSession] = CacheBuilder.newBuilder -// .maximumSize(2000) -// .expireAfterWrite(10, TimeUnit.MINUTES) -// .concurrencyLevel(100) -// .recordStats -// .softValues -// .build(new CacheLoader[String, ErSession]() { -// override def load(key: String): ErSession = { -// new ErSession(sessionId = key, createIfNotExists = false) -// } -// }) -// } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPair.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPair.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java index 0f0b8a31aa..bb2f25a37d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPair.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import java.util.Map; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java index 7100eee593..e598437dc7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/RollPairContext.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.osx.core.constant.Dict; +import org.fedai.osx.core.constant.Dict; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SerdesTypes.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SerdesTypes.java index 3a2b0d2f73..661f1ce97c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SerdesTypes.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SerdesTypes.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public enum SerdesTypes { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionCommands.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionCommands.java index b79138fea4..9bc09dadcd 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionCommands.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionCommands.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public class SessionCommands { static String prefix = "v1/cluster-manager/session"; diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionConfKeys.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionConfKeys.java index ee840094ff..ed3f0a3766 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionConfKeys.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionConfKeys.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public class SessionConfKeys { diff --git a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionStatus.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionStatus.java index b09188b610..f1e003447e 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/eggroll/SessionStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/SessionStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.eggroll; +package org.fedai.osx.broker.eggroll; public enum SessionStatus { NEW, ACTIVE, CLOSED, KILLED, ERROR diff --git a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterMetricStatistics.java similarity index 92% rename from java/osx/broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterMetricStatistics.java index 0bcb3920ab..dca634e6e2 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterMetricStatistics.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterMetricStatistics.java @@ -1,8 +1,8 @@ -package com.osx.broker.flow; +package org.fedai.osx.broker.flow; -import com.osx.core.flow.ClusterMetric; -import com.osx.core.utils.AssertUtil; +import org.fedai.osx.core.flow.ClusterMetric; +import org.fedai.osx.core.utils.AssertUtil; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; diff --git a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterRuleUtil.java similarity index 82% rename from java/osx/broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterRuleUtil.java index 5b89cceb66..f7303f5f2f 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/flow/ClusterRuleUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/flow/ClusterRuleUtil.java @@ -1,4 +1,4 @@ -package com.osx.broker.flow; +package org.fedai.osx.broker.flow; public final class ClusterRuleUtil { diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ContextPrepareInterceptor.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ContextPrepareInterceptor.java index d475a053de..5dab96dc3f 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ContextPrepareInterceptor.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ContextPrepareInterceptor.java @@ -1,4 +1,4 @@ -package com.osx.broker.grpc; +package org.fedai.osx.broker.grpc; import io.grpc.*; @@ -12,6 +12,7 @@ public ServerCall.Listener interceptCall(ServerCall transport( public void invoke(Osx.Inbound request, io.grpc.stub.StreamObserver responseObserver) { - + DebugUtil.printGrpcParams(request); Map metaDataMap = request.getMetadataMap(); String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); TechProvider techProvider = ServiceContainer.techProviderRegister.select(techProviderCode); @@ -76,6 +78,7 @@ private void init(Osx.Inbound inbound) { String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); techProvider = ServiceContainer.techProviderRegister.select(techProviderCode); if (techProvider != null) { + DebugUtil.printGrpcParams(inbound); requestObserver = techProvider.processGrpcTransport(inbound, responseObserver); } else { //抛出异常 diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ProxyGrpcService.java similarity index 70% rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java 
rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ProxyGrpcService.java index 0ebb3f1b0c..d2011eb052 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ProxyGrpcService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ProxyGrpcService.java @@ -13,17 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.grpc; +package org.fedai.osx.broker.grpc; -import com.osx.broker.service.PushService; -import com.osx.broker.service.UnaryCallService; -import com.osx.broker.util.ContextUtil; -import com.osx.core.context.Context; -import com.osx.core.service.InboundPackage; -import com.osx.core.service.OutboundPackage; import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.broker.interceptor.RouterInterceptor; +import org.fedai.osx.broker.interceptor.UnaryCallHandleInterceptor; +import org.fedai.osx.broker.service.PushService; +import org.fedai.osx.broker.service.UnaryCallService; +import org.fedai.osx.broker.util.ContextUtil; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.OutboundPackage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,21 +35,22 @@ public class ProxyGrpcService extends DataTransferServiceGrpc.DataTransferServic Logger logger = LoggerFactory.getLogger(ProxyGrpcService.class); UnaryCallService unaryCallService; PushService pushService; - public ProxyGrpcService(PushService pushService, - UnaryCallService unaryCallService + public ProxyGrpcService( ) { - this.pushService = pushService; - this.unaryCallService = unaryCallService; + this.pushService = new PushService(); + this.unaryCallService =new UnaryCallService(); + unaryCallService .addPreProcessor(new UnaryCallHandleInterceptor()). 
+ addPreProcessor(new RouterInterceptor()); + } public io.grpc.stub.StreamObserver push( io.grpc.stub.StreamObserver responseObserver) { try { - Context context = ContextUtil.buildContext(); - InboundPackage data = new InboundPackage<>(); - PushRequestDataWrap pushRequestDataWrap = new PushRequestDataWrap(); - pushRequestDataWrap.setStreamObserver(responseObserver); - data.setBody(pushRequestDataWrap); + FateContext context = ContextUtil.buildFateContext(Protocol.grpc); + context.setNeedPrintFlowLog(false); + InboundPackage data = new InboundPackage<>(); + data.setBody(responseObserver); OutboundPackage outboundPackage = pushService.service(context, data); return outboundPackage.getData(); } catch (Exception e) { @@ -58,7 +62,7 @@ public io.grpc.stub.StreamObserver responseObserver) { - Context context = ContextUtil.buildContext(); + FateContext context = ContextUtil.buildFateContext(Protocol.grpc); InboundPackage data = new InboundPackage<>(); data.setBody(request); context.setDataSize(request.getSerializedSize()); diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PullRequestDataWrap.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PullRequestDataWrap.java index 52c1ee38b4..c49801f0c4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/grpc/PullRequestDataWrap.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PullRequestDataWrap.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.grpc; +package org.fedai.osx.broker.grpc; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PushRequestDataWrap.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PushRequestDataWrap.java new file mode 100644 index 0000000000..88e38bc3ec --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/PushRequestDataWrap.java @@ -0,0 +1,40 @@ +///* +// * Copyright 2019 The FATE Authors. All Rights Reserved. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. 
+// */ +//package org.fedai.osx.broker.grpc; +// +//import com.webank.ai.eggroll.api.networking.proxy.Proxy; +//import io.grpc.stub.StreamObserver; +// +//public class PushRequestDataWrap { +// Proxy.Packet packet; +// StreamObserver streamObserver; +// +// public Proxy.Packet getPacket() { +// return packet; +// } +// +// public void setPacket(Proxy.Packet packet) { +// this.packet = packet; +// } +// +// public StreamObserver getStreamObserver() { +// return streamObserver; +// } +// +// public void setStreamObserver(StreamObserver streamObserver) { +// this.streamObserver = streamObserver; +// } +//} diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java similarity index 68% rename from java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java index 9c84a3587e..42ac28b8e4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/grpc/QueuePushReqStreamObserver.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java @@ -13,27 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.grpc; +package org.fedai.osx.broker.grpc; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.broker.ServiceContainer; -import com.osx.broker.eggroll.*; -import com.osx.broker.ptp.PtpForwardPushRespSO; -import com.osx.broker.util.TransferUtil; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.Dict; -import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; -import com.osx.core.exceptions.*; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.FlowLogUtil; -import com.osx.core.utils.ToStringUtils; import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import com.webank.eggroll.core.command.Command; @@ -43,6 +28,24 @@ import io.grpc.ManagedChannel; import io.grpc.stub.StreamObserver; import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.eggroll.*; +import org.fedai.osx.broker.ptp.PtpForwardPushRespSO; +import org.fedai.osx.broker.router.RouterService; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.*; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.ptp.SourceMethod; +import org.fedai.osx.core.ptp.TargetMethod; +import org.fedai.osx.core.utils.FlowLogUtil; +import org.fedai.osx.core.utils.ToStringUtils; import org.ppc.ptp.Osx; import org.ppc.ptp.PrivateTransferProtocolGrpc; import org.slf4j.Logger; @@ -58,7 +61,7 @@ public class QueuePushReqStreamObserver implements StreamObserver 
static public ConcurrentHashMap queueIdMap = new ConcurrentHashMap<>(); static AtomicInteger seq = new AtomicInteger(0); Logger logger = LoggerFactory.getLogger(QueuePushReqStreamObserver.class); - Context context; + FateContext context; ErRollSiteHeader rsHeader = null; TransferStatus transferStatus = TransferStatus.INIT; CountDownLatch finishLatch = new CountDownLatch(1); @@ -73,16 +76,20 @@ public class QueuePushReqStreamObserver implements StreamObserver private Class backRespSOClass; private String transferId; private Integer queueId; + private RouterService routerService; - public QueuePushReqStreamObserver(Context context, StreamObserver backRespSO, + public QueuePushReqStreamObserver(Context context,RouterService routerService, StreamObserver backRespSO, Class backRespSOClass ) { + this.context =(FateContext) context; + this.routerService = routerService; this.backRespSOClass = backRespSOClass; this.backRespSO = backRespSO; - this.context = context.subContext(); - this.context.setNeedPrintFlowLog(true); + //this.context = context.subContext(); + //this.context.setNeedPrintFlowLog(true); this.context.setServiceName("pushTransfer"); + } public StreamObserver getForwardPushReqSO() { @@ -95,12 +102,12 @@ public void setForwardPushReqSO(StreamObserver forwardPushReqSO) { public void init(Proxy.Packet packet) throws Exception { + TransferUtil.assableContextFromProxyPacket(context,packet); Proxy.Metadata metadata = packet.getHeader(); - String desPartyId = metadata.getDst().getPartyId(); - String srcPartyId = metadata.getSrc().getPartyId(); + String desPartyId = context.getDesPartyId(); + String srcPartyId = context.getSrcPartyId(); ByteString encodedRollSiteHeader = metadata.getExt(); rsHeader = ErRollSiteHeader.parseFromPb(Transfer.RollSiteHeader.parseFrom(encodedRollSiteHeader)); - Integer partitionId = rsHeader.getPartitionId(); brokerTag = "putBatch-" + rsHeader.getRsKey("#", "__rsk") + "-" + partitionId; context.setSessionId(rsHeader.getRollSiteSessionId()); @@ -113,10 +120,9 @@ public void init(Proxy.Packet packet) throws Exception { * 检查目的地是否为自己 */ if (!isDst) { - routerInfo = ServiceContainer.fateRouterService.route(packet); + routerInfo =routerService.route(context.getSrcPartyId(),context.getSrcComponent(),context.getDesPartyId(),context.getDesComponent()); if (routerInfo != null) { this.transferId = routerInfo.getResource(); - } else { throw new NoRouterInfoException("no router"); } @@ -129,48 +135,51 @@ public void init(Proxy.Packet packet) throws Exception { context.setRouterInfo(routerInfo); context.setSrcPartyId(routerInfo.getSourcePartyId()); context.setDesPartyId(routerInfo.getDesPartyId()); - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true); - if (TransferUtil.isOldVersionFate(routerInfo.getVersion())) { - DataTransferServiceGrpc.DataTransferServiceStub stub = DataTransferServiceGrpc.newStub(managedChannel); - ForwardPushRespSO forwardPushRespSO = new ForwardPushRespSO(context, backRespSO,backRespSOClass, () -> { - finishLatch.countDown(); - }, (t) -> { - finishLatch.countDown(); - }); - forwardPushReqSO = stub.push(forwardPushRespSO); - } else { - PtpForwardPushRespSO ptpForwardPushRespSO = new PtpForwardPushRespSO(context, backRespSO, backRespSOClass, () -> { - finishLatch.countDown(); - }, (t) -> { - finishLatch.countDown(); - }); - - PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub = PrivateTransferProtocolGrpc.newStub(managedChannel); - - StreamObserver ptpForwardPushReqSO = 
stub.transport(ptpForwardPushRespSO); - - forwardPushReqSO = new StreamObserver() { - @Override - public void onNext(Proxy.Packet packet) { - Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(packet, TargetMethod.PUSH.name()); - ptpForwardPushReqSO.onNext(inbound); - } - @Override - public void onError(Throwable throwable) { - ptpForwardPushReqSO.onError(throwable); - } - @Override - public void onCompleted() { - ptpForwardPushReqSO.onCompleted(); - } - }; + if (routerInfo.getProtocol().equals(Protocol.http)) { + //this side initiates the transfer and uses a queue instead of streaming, so a local queue must be created to receive the answers + forwardPushReqSO = QueueStreamBuilder.createStreamFromOrigin(context, backRespSO, Proxy.Packet.parser(), + routerInfo, srcPartyId, desPartyId, rsHeader.getRollSiteSessionId(),finishLatch); + } else { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(), true); + if (TransferUtil.isOldVersionFate(routerInfo.getVersion())) { + DataTransferServiceGrpc.DataTransferServiceStub stub = DataTransferServiceGrpc.newStub(managedChannel); + ForwardPushRespSO forwardPushRespSO = new ForwardPushRespSO(context, backRespSO, backRespSOClass, () -> { + finishLatch.countDown(); + }, (t) -> { + finishLatch.countDown(); + }); + forwardPushReqSO = stub.push(forwardPushRespSO); + } else { + PtpForwardPushRespSO ptpForwardPushRespSO = new PtpForwardPushRespSO(context, backRespSO, backRespSOClass, () -> { + finishLatch.countDown(); + }, (t) -> { + finishLatch.countDown(); + }); + PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub = PrivateTransferProtocolGrpc.newStub(managedChannel); + StreamObserver ptpForwardPushReqSO = stub.transport(ptpForwardPushRespSO); + forwardPushReqSO = new StreamObserver() { + @Override + public void onNext(Proxy.Packet packet) { + Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(packet, MetaInfo.PROPERTY_FATE_TECH_PROVIDER, TargetMethod.PUSH.name(), SourceMethod.PUSH.name()).build(); + ptpForwardPushReqSO.onNext(inbound); + } + + @Override + public void onError(Throwable throwable) { + ptpForwardPushReqSO.onError(throwable); + } + + @Override + public void onCompleted() { + ptpForwardPushReqSO.onCompleted(); + } + }; + } } } transferStatus = TransferStatus.TRANSFERING; - - } private void initEggroll(Proxy.Packet firstRequest) { @@ -191,7 +200,7 @@ private void initEggroll(Proxy.Packet firstRequest) { logger.error("get session error ", e); } if (!SessionStatus.ACTIVE.name().equals(session.getErSessionMeta().getStatus())) { - IllegalStateException error = new IllegalStateException("session=${sessionId} with illegal status.
expected=${SessionStatus.ACTIVE}, actual=${session.sessionMeta.status}"); + SessionInitException error = new SessionInitException("eggroll session "+sessionId+" invalid status : "+session.getErSessionMeta().getStatus()); onError(error); throw error; } @@ -242,8 +251,8 @@ private void initEggroll(Proxy.Packet firstRequest) { ErTask erTask = ErTask.parseFromPb(taskMeta); long now = System.currentTimeMillis(); return erTask; - } catch (InvalidProtocolBufferException e) { - e.printStackTrace(); + } catch (InvalidProtocolBufferException igore) { + } return null; }); @@ -283,14 +292,13 @@ public void onNext(Proxy.Packet value) { forwardPushReqSO.onNext(value); } } - - } catch (Exception e) { - logger.error("push error", e); + logger.error("push error1", e); ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); context.setException(e); context.setReturnCode(exceptionInfo.getCode()); - throw new BaseException(exceptionInfo.getCode(), exceptionInfo.getMessage()); + throw ErrorMessageUtil.toGrpcRuntimeException(e); + } finally { FlowLogUtil.printFlowLog(context); } @@ -310,28 +318,11 @@ public void onError(Throwable t) { * 2.销毁队列 */ if (isDst) { - //transferQueue.onError(t); - - putBatchSinkPushReqSO.onError(t); } else { - -// if(MetaInfo.PROPERTY_USE_QUEUE_MODEL){ -// if(transferQueue!=null){ -// AbstractServiceAdaptor.ExceptionInfo exceptionInfo = new AbstractServiceAdaptor.ExceptionInfo(); -// exceptionInfo.setMessage(t.getMessage()); -// exceptionInfo.setThrowable(t); -// MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(transferId,exceptionInfo.toString().getBytes(StandardCharsets.UTF_8), -// queueId,MessageFlag.ERROR,routerInfo.getSourcePartyId(),routerInfo.getDesPartyId()); -// transferQueue.putMessage(messageExtBrokerInner); -// } -// }else - - { if (forwardPushReqSO != null) { forwardPushReqSO.onError(t); } - } } @@ -342,26 +333,11 @@ public void onCompleted() { logger.info("transferId {} receive completed", transferId); if (isDst) { -// if(transferQueue!=null) { -// transferQueue.setWriteOver(true); -// } if (putBatchSinkPushReqSO != null) { putBatchSinkPushReqSO.onCompleted(); } } else { - if (forwardPushReqSO != null) { - -// if(MetaInfo.PROPERTY_USE_QUEUE_MODEL){ -// /** -// * 由pushConsumer去通知,因为要保证顺序,保证之前的数据传递完,所以只能放在队列最后串行执行 -// */ -// MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(transferId,null,queueId,MessageFlag.COMPELETED, -// routerInfo.getSourcePartyId(),routerInfo.getDesPartyId()); -// PutMessageResult putMessageResult = transferQueue.putMessage(messageExtBrokerInner); -// }else - - { forwardPushReqSO.onCompleted(); try { if (!finishLatch.await(MetaInfo.PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT, TimeUnit.SECONDS)) { @@ -373,12 +349,7 @@ public void onCompleted() { needPrintFlow = false; } } - } -// if(needPrintFlow){ -// context.setActionType("push"); -// context.printFlowLog(); -// } - logger.info("receive completed !!!!"); + } } diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueueStreamBuilder.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueueStreamBuilder.java new file mode 100644 index 0000000000..626d2d7d33 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueueStreamBuilder.java @@ -0,0 +1,130 @@ +package org.fedai.osx.broker.grpc; + +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Parser; +import io.grpc.stub.StreamObserver; +import 
org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.consumer.SourceGrpcEventHandler; +import org.fedai.osx.broker.queue.CreateQueueResult; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ErrorMessageUtil; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.exceptions.RemoteRpcException; +import org.fedai.osx.core.ptp.TargetMethod; +import org.fedai.osx.core.utils.JsonUtil; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + + +public class QueueStreamBuilder { + + + ConcurrentHashMap backRegister = new ConcurrentHashMap() ; + + + + /** + * called at the originating end of the stream + * @param respStreamObserver + * @param parser + * @param srcPartyId + * @param desPartyId + * @param sessionId + * @return + */ + + private static AtomicInteger count= new AtomicInteger(0); + + private static Logger logger = LoggerFactory.getLogger(QueueStreamBuilder.class); + public static StreamObserver createStreamFromOrigin(FateContext context , + StreamObserver respStreamObserver, + Parser parser, + RouterInfo routerInfo, + String srcPartyId, + String desPartyId, + String sessionId, + CountDownLatch countDownLatch + ){ + + //String uuid = UUID.randomUUID().toString(); + int temp = count.addAndGet(1); + long now = System.currentTimeMillis(); + //srcPartyId+"_"+desPartyId + String backTopic = Dict.STREAM_BACK_TOPIC_PREFIX +now+ "_"+sessionId+"_"+temp; + String sendTopic = Dict.STREAM_SEND_TOPIC_PREFIX +now+"_"+sessionId+"_"+temp; + context.setTopic(sendTopic); + context.setActionType(ActionType.MSG_REDIRECT.getAlias()); + CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(backTopic, sessionId, true); + if (createQueueResult.getTransferQueue() == null) { + throw new RemoteRpcException("create queue error"); + } + TransferQueue answerQueue = createQueueResult.getTransferQueue(); + ServiceContainer.consumerManager.createEventDrivenConsumer(backTopic,new SourceGrpcEventHandler(respStreamObserver,parser)); + StreamObserver forwardPushReqSO = new StreamObserver() { + + @Override + public void onNext(AbstractMessage message) { + try { + context.setMessageFlag(MessageFlag.SENDMSG.name()); + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(), sendTopic, MessageFlag.SENDMSG, sessionId, message.toByteArray()); + Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true); + TransferUtil.checkResponse(outbound); + }catch(Exception e){ + throw ErrorMessageUtil.toGrpcRuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) { + try { + context.setMessageFlag(MessageFlag.ERROR.name()); + ExceptionInfo exceptionInfo = new ExceptionInfo(); + exceptionInfo.setMessage(throwable.getMessage()); + String errorData = JsonUtil.object2Json(exceptionInfo); + Osx.Inbound.Builder inboundBuilder =
TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(), + sendTopic, MessageFlag.ERROR, sessionId, errorData.getBytes(StandardCharsets.UTF_8)) + .putMetadata(Osx.Metadata.MessageFlag.name(), MessageFlag.ERROR.name()); + Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true); + TransferUtil.checkResponse(outbound); + countDownLatch.countDown(); + }catch (Exception e){ + throw ErrorMessageUtil.toGrpcRuntimeException(e); + } + } + + @Override + public void onCompleted() { + try { + context.setMessageFlag(MessageFlag.COMPELETED.name()); + Osx.Inbound.Builder inboundBuilder = TransferUtil.buildInbound(MetaInfo.PROPERTY_FATE_TECH_PROVIDER, srcPartyId, desPartyId, TargetMethod.PRODUCE_MSG.name(), + sendTopic, MessageFlag.COMPELETED, sessionId, "completed".getBytes(StandardCharsets.UTF_8)) + .putMetadata(Osx.Metadata.MessageFlag.name(), MessageFlag.COMPELETED.name()); + Osx.Outbound outbound = TransferUtil.redirect(context, inboundBuilder.build(), routerInfo, true); + + TransferUtil.checkResponse(outbound); + countDownLatch.countDown(); + + } catch (Exception e) { + throw ErrorMessageUtil.toGrpcRuntimeException(e); + } + } + }; + return forwardPushReqSO; + + }; + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ServiceExceptionHandler.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ServiceExceptionHandler.java index a0e0d378b5..37cbcbd642 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/grpc/ServiceExceptionHandler.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/ServiceExceptionHandler.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.grpc; +package org.fedai.osx.broker.grpc; import io.grpc.*; import org.slf4j.Logger; diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/DispatchServlet.java similarity index 62% rename from java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/DispatchServlet.java index 2e42728670..a03954c1de 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/http/DispatchServlet.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/DispatchServlet.java @@ -13,14 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
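Editor's note: the new QueueStreamBuilder above bridges an outgoing gRPC packet stream onto the queue/HTTP redirect path when the route's protocol is HTTP. A minimal usage sketch mirroring the call site in QueuePushReqStreamObserver.init(); the context, routerInfo, backRespSO and party/session values are assumed to be populated by the surrounding broker code and are not defined in this patch:
    // Illustrative only: forward Proxy.Packet messages through the queue-backed bridge.
    // context (FateContext), backRespSO, routerInfo, srcPartyId, desPartyId and sessionId
    // are assumed to already be filled in by QueuePushReqStreamObserver.
    CountDownLatch finishLatch = new CountDownLatch(1);
    StreamObserver forwardPushReqSO = QueueStreamBuilder.createStreamFromOrigin(
            context, backRespSO, Proxy.Packet.parser(),
            routerInfo, srcPartyId, desPartyId, sessionId, finishLatch);
    forwardPushReqSO.onNext(packet);      // each packet is redirected as a PRODUCE_MSG inbound on the send topic
    forwardPushReqSO.onCompleted();       // sends a COMPELETED marker and counts down the latch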
*/ -package com.osx.broker.http; +package org.fedai.osx.broker.http; -import com.osx.broker.ServiceContainer; -import com.osx.core.constant.PtpHttpHeader; -import com.osx.core.provider.TechProvider; -import com.osx.tech.provider.TechProviderRegister; import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.http.HttpHeader; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.util.DebugUtil; +import org.fedai.osx.core.constant.PtpHttpHeader; +import org.fedai.osx.core.provider.TechProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,52 +32,51 @@ public class DispatchServlet extends HttpServlet { Logger logger = LoggerFactory.getLogger(DispatchServlet.class); + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + //handle GET requests + DebugUtil.printHttpParams(req); String protocol = req.getProtocol(); if (!protocol.endsWith("1.1")) { resp.sendError(405, "http.method_get_not_supported"); } - String techProviderCode =req.getHeader(PtpHttpHeader.TechProviderCode); - if(StringUtils.isNotEmpty(techProviderCode)){ + String techProviderCode = req.getHeader(PtpHttpHeader.TechProviderCode); + if (StringUtils.isNotEmpty(techProviderCode)) { TechProvider techProvider = ServiceContainer.techProviderRegister.select(techProviderCode); - if(techProvider!=null) { + if (techProvider != null) { techProvider.processHttpInvoke(req, resp); - }else{ - resp.sendError(404,"tech-provider-code invalid"); + } else { + resp.sendError(404, "tech-provider-code invalid"); } - }else{ - resp.sendError(404,"tech-provider-code invalid"); + } else { + resp.sendError(404, "tech-provider-code invalid"); } - String requestUri =req.getRequestURI(); - logger.info("receive request uri {}",requestUri); + String requestUri = req.getRequestURI(); + logger.info("receive request uri {}", requestUri); } protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - String requestUri =req.getRequestURI(); - logger.info("receive request uri {}",requestUri); + //handle POST requests + DebugUtil.printHttpParams(req); + String requestUri = req.getRequestURI(); + //logger.info("receive request uri {}",requestUri); String protocol = req.getProtocol(); if (!protocol.endsWith("1.1")) { resp.sendError(405, "http.method_get_not_supported"); } - String techProviderCode =req.getHeader(PtpHttpHeader.TechProviderCode); - if(StringUtils.isNotEmpty(techProviderCode)){ + String techProviderCode = req.getHeader(PtpHttpHeader.TechProviderCode); + if (StringUtils.isNotEmpty(techProviderCode)) { TechProvider techProvider = ServiceContainer.techProviderRegister.select(techProviderCode); - if(techProvider!=null) { + if (techProvider != null) { techProvider.processHttpInvoke(req, resp); - }else{ - resp.sendError(404,"tech-provider-code invalid"); + } else { + resp.sendError(404, "tech-provider-code invalid"); } - }else{ - resp.sendError(404,"tech-provider-code invalid"); + } else { + resp.sendError(404, "tech-provider-code invalid"); } - - - - - } - } diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpClientPool.java similarity index 74% rename from java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpClientPool.java index 3dc5f24bfe..711bcb942c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/http/HttpClientPool.java +++ 
b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpClientPool.java @@ -14,14 +14,11 @@ * limitations under the License. */ -package com.osx.broker.http; +package org.fedai.osx.broker.http; import com.google.common.collect.Maps; import com.google.protobuf.ByteString; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.PtpHttpHeader; -import com.osx.core.utils.JsonUtil; +import org.apache.commons.lang3.ObjectUtils; import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.client.config.RequestConfig; @@ -44,6 +41,10 @@ import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; import org.apache.http.ssl.SSLContextBuilder; import org.apache.http.util.EntityUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.PtpHttpHeader; +import org.fedai.osx.core.utils.JsonUtil; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,14 +59,27 @@ public class HttpClientPool { private static final Logger logger = LoggerFactory.getLogger(HttpClientPool.class); private static PoolingHttpClientConnectionManager poolConnManager; - private static RequestConfig requestConfig; private static CloseableHttpClient httpClient; - private static void config(HttpRequestBase httpRequestBase, Map headers) { + static void config(HttpRequestBase httpRequestBase, Map headers) { + Integer reqTimeout = null; + Integer connectionTimeout = null; + Integer socketTimeout = null; + + if (MetaInfo.PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP != null) { + Map methodConfig = MetaInfo.PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP.get(headers.get(PtpHttpHeader.SourceMethod)); + if (methodConfig != null) { + reqTimeout = methodConfig.get(Dict.METHOD_CONFIG_REQ_TIMEOUT); + connectionTimeout = methodConfig.get(Dict.METHOD_CONFIG_CONNECTION_TIMEOUT); + socketTimeout = methodConfig.get(Dict.METHOD_CONFIG_SOCKET_TIMEOUT); + + } + } + RequestConfig requestConfig = RequestConfig.custom() - .setConnectionRequestTimeout(MetaInfo.HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) - .setConnectTimeout(MetaInfo.HTTP_CLIENT_CONFIG_CONN_TIME_OUT) - .setSocketTimeout(MetaInfo.HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); + .setConnectionRequestTimeout(ObjectUtils.firstNonNull(reqTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT)) + .setConnectTimeout(ObjectUtils.firstNonNull(connectionTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT)) + .setSocketTimeout(ObjectUtils.firstNonNull(socketTimeout, MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT)).build(); httpRequestBase.addHeader(Dict.CONTENT_TYPE, Dict.CONTENT_TYPE_JSON_UTF8); if (headers != null) { headers.forEach((key, value) -> { @@ -85,14 +99,8 @@ public static void initPool() { Dict.HTTPS, sslsf).build(); poolConnManager = new PoolingHttpClientConnectionManager( socketFactoryRegistry); - poolConnManager.setMaxTotal(MetaInfo.HTTP_CLIENT_INIT_POOL_MAX_TOTAL); - poolConnManager.setDefaultMaxPerRoute(MetaInfo.HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); - int socketTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT; - int connectTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT; - int connectionRequestTimeout = MetaInfo.HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT; - requestConfig = RequestConfig.custom().setConnectionRequestTimeout( - connectionRequestTimeout).setSocketTimeout(socketTimeout).setConnectTimeout( - connectTimeout).build(); + 
poolConnManager.setMaxTotal(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL); + poolConnManager.setDefaultMaxPerRoute(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); httpClient = createConnection(); } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException ex) { logger.error("init http client pool failed:", ex); @@ -103,16 +111,21 @@ public static CloseableHttpClient getConnection() { } public static CloseableHttpClient createConnection() { + RequestConfig requestConfig = RequestConfig.custom() + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); CloseableHttpClient httpClient = HttpClients.custom() .setConnectionManager(poolConnManager) .setDefaultRequestConfig(requestConfig) .evictExpiredConnections() - .evictIdleConnections(5, TimeUnit.SECONDS) + .evictIdleConnections(MetaInfo.PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME, TimeUnit.SECONDS) .setRetryHandler(new DefaultHttpRequestRetryHandler(0, false)) .build(); return httpClient; } public static Osx.Outbound sendPtpPost(String url, byte[] body, Map headers) { + HttpPost httpPost = new HttpPost(url); config(httpPost, headers); if(body!=null) { @@ -145,10 +158,8 @@ public static String sendGet(String url, Map headers) { private static String getResponse(HttpRequestBase request) { CloseableHttpResponse response = null; try { - response = httpClient.execute(request, - HttpClientContext.create()); + response = httpClient.execute(request, HttpClientContext.create()); HttpEntity entity = response.getEntity(); - String result = EntityUtils.toString(entity, Dict.CHARSET_UTF8); EntityUtils.consume(entity); return result; @@ -166,15 +177,12 @@ private static String getResponse(HttpRequestBase request) { } } - - private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); CloseableHttpResponse response = null; try { - response = httpClient.execute(request, - HttpClientContext.create()); + response = httpClient.execute(request, HttpClientContext.create()); HttpEntity entity = response.getEntity(); byte[] payload = EntityUtils.toByteArray(entity); Header[] headers = response.getAllHeaders(); @@ -187,8 +195,11 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { } if(payload!=null) outboundBuilder.setPayload(ByteString.copyFrom(payload)); - if(headMap.get(PtpHttpHeader.ReturnCode)!=null) + if(headMap.get(PtpHttpHeader.ReturnCode)!=null){ outboundBuilder.setCode(headMap.get(PtpHttpHeader.ReturnCode)); + }else{ + logger.error("http response has no return code {}", headers); + }; if(headMap.get(PtpHttpHeader.ReturnMessage)!=null) outboundBuilder.setMessage(headMap.get(PtpHttpHeader.ReturnMessage)); @@ -196,6 +207,7 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { return outboundBuilder.build(); } catch (IOException ex) { logger.error("get http response failed:", ex); + ex.printStackTrace(); return null; } finally { try { @@ -211,9 +223,9 @@ private static Osx.Outbound getPtpHttpResponse(HttpRequestBase request) { public static String transferPost(String url, Map requestData) { HttpPost httpPost = new HttpPost(url); RequestConfig requestConfig = RequestConfig.custom() - .setConnectionRequestTimeout(MetaInfo.HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT) - 
.setConnectTimeout(MetaInfo.HTTP_CLIENT_TRAN_CONN_TIME_OUT) - .setSocketTimeout(MetaInfo.HTTP_CLIENT_TRAN_SOCK_TIME_OUT).build(); + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); httpPost.addHeader(Dict.CONTENT_TYPE, Dict.CONTENT_TYPE_JSON_UTF8); httpPost.setConfig(requestConfig); StringEntity stringEntity = new StringEntity(JsonUtil.object2Json(requestData), Dict.CHARSET_UTF8); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java new file mode 100644 index 0000000000..5bff4beee5 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java @@ -0,0 +1,208 @@ +package org.fedai.osx.broker.http; + +import com.google.common.collect.Maps; +import com.google.protobuf.ByteString; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.conn.socket.PlainConnectionSocketFactory; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.conn.ssl.TrustSelfSignedStrategy; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpRequestRetryHandler; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.PtpHttpHeader; +import org.fedai.osx.core.utils.OSXCertUtils; +import org.fedai.osx.core.utils.OsxX509TrustManager; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import java.io.IOException; +import java.security.*; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +public class HttpsClientPool { + private static final Logger logger = LoggerFactory.getLogger(HttpsClientPool.class); + private static final Map httpsClientPool = new HashMap<>(); + + public static CloseableHttpClient getConnection(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + String certKey = buildCertKey(caPath, clientCertPath, clientKeyPath); + CloseableHttpClient httpClient = httpsClientPool.get(certKey); + if (httpClient == null) { + httpClient = createConnection(caPath, clientCertPath, clientKeyPath); + httpsClientPool.put(certKey, httpClient); + } + return httpClient; + } + + private static String buildCertKey(String caPath, String clientCertPath, String clientKeyPath) { + return caPath + "_" + clientCertPath + "_" + clientKeyPath; 
+ } + + public static CloseableHttpClient createConnection(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + RequestConfig requestConfig = RequestConfig.custom() + .setConnectionRequestTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT) + .setConnectTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT) + .setSocketTimeout(MetaInfo.PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT).build(); + CloseableHttpClient httpClient = null; + try { + SSLContextBuilder builder = new SSLContextBuilder(); + builder.loadTrustMaterial(null, new TrustSelfSignedStrategy()); + SSLConnectionSocketFactory sslsf; + if (MetaInfo.PROPERTY_HTTP_SSL_HOSTNAME_VERIFY) { + sslsf = new SSLConnectionSocketFactory(OSXCertUtils.getSSLContext(caPath, clientCertPath, clientKeyPath)); + } else { + sslsf = new SSLConnectionSocketFactory(OSXCertUtils.getSSLContext(caPath, clientCertPath, clientKeyPath), OsxX509TrustManager.HostnameVerifier2.getInstance()); + } + Registry socketFactoryRegistry = RegistryBuilder.create().register( + Dict.HTTP, PlainConnectionSocketFactory.getSocketFactory()).register( + Dict.HTTPS, sslsf).build(); + PoolingHttpClientConnectionManager poolConnManager = new PoolingHttpClientConnectionManager( + socketFactoryRegistry); + poolConnManager.setMaxTotal(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL); + poolConnManager.setDefaultMaxPerRoute(MetaInfo.PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE); + httpClient = HttpClients.custom() + .setSSLSocketFactory(sslsf) + .setConnectionManager(poolConnManager) + .setDefaultRequestConfig(requestConfig) + .evictExpiredConnections() + .evictIdleConnections(MetaInfo.PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME, TimeUnit.SECONDS) + .setRetryHandler(new DefaultHttpRequestRetryHandler(0, false)) + .build(); + } catch (NoSuchAlgorithmException | KeyStoreException | KeyManagementException ex) { + logger.error("init https client pool failed:", ex); + } + return httpClient; + } + + public static Osx.Outbound sendPtpPost(String url, byte[] body, Map headers, String caPath, String clientCertPath, String clientKeyPath) throws Exception { + + HttpPost httpPost = new HttpPost(url); + HttpClientPool.config(httpPost, headers); + if (body != null) { + ByteArrayEntity byteArrayEntity = new ByteArrayEntity(body); + httpPost.setEntity(byteArrayEntity); + } + return getPtpHttpsResponse(httpPost, caPath, clientCertPath, clientKeyPath); + } + + @SuppressWarnings("unused") + public static String sendPost(String url, byte[] body, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + HttpPost httpPost = new HttpPost(url); + HttpClientPool.config(httpPost, headers); + ByteArrayEntity byteArrayEntity = new ByteArrayEntity(body); + httpPost.setEntity(byteArrayEntity); + return getResponse(httpPost, caPath, clientCertPath, clientKeyPath); + } + + public static String get(String url, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + return sendGet(url, headers, caPath, clientCertPath, clientKeyPath); + } + + public static String get(String url, String caPath, String clientCertPath, String clientKeyPath) { + return sendGet(url, null, caPath, clientCertPath, clientKeyPath); + } + + public static String sendGet(String url, Map headers, String caPath, String clientCertPath, String clientKeyPath) { + HttpGet httpGet = new HttpGet(url); + HttpClientPool.config(httpGet, headers); + return getResponse(httpGet, caPath, clientCertPath, clientKeyPath); + } + + private static String getResponse(HttpRequestBase request, 
String caPath, String clientCertPath, String clientKeyPath) { + CloseableHttpResponse response = null; + try { + response = getConnection(caPath, clientCertPath, clientKeyPath).execute(request, HttpClientContext.create()); + HttpEntity entity = response.getEntity(); + String result = EntityUtils.toString(entity, Dict.CHARSET_UTF8); + EntityUtils.consume(entity); + return result; + } catch (Exception ex) { + logger.error("get https response failed:", ex); + return null; + } finally { + try { + if (response != null) { + response.close(); + } + } catch (IOException ex) { + logger.error("get https response failed:", ex); + } + } + } + + private static Osx.Outbound getPtpHttpsResponse(HttpRequestBase request, String caPath, String clientCertPath, String clientKeyPath) throws Exception { + Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); + CloseableHttpResponse response = null; + try { + response = getConnection(caPath, clientCertPath, clientKeyPath).execute(request, HttpClientContext.create()); + HttpEntity entity = response.getEntity(); + byte[] payload = EntityUtils.toByteArray(entity); + Header[] headers = response.getAllHeaders(); + Map headMap = Maps.newHashMap(); + if (headers != null) { + for (Header temp : headers) { + headMap.put(temp.getName(), temp.getValue()); + } + } + if (payload != null) + outboundBuilder.setPayload(ByteString.copyFrom(payload)); + if (headMap.get(PtpHttpHeader.ReturnCode) != null) + outboundBuilder.setCode(headMap.get(PtpHttpHeader.ReturnCode)); + if (headMap.get(PtpHttpHeader.ReturnMessage) != null) + outboundBuilder.setMessage(headMap.get(PtpHttpHeader.ReturnMessage)); + + EntityUtils.consume(entity); + return outboundBuilder.build(); + } catch (IOException ex) { + logger.error("get https response failed:", ex); + ex.printStackTrace(); + throw ex; + } finally { + try { + if (response != null) { + response.close(); + } + } catch (IOException ex) { + logger.error("get https response failed:", ex); + } + } + } + + @SuppressWarnings("unused") + private static SSLSocketFactory getSslFactory(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = OSXCertUtils.getKeyStore(caPath, clientCertPath, clientKeyPath); + // Initialize the ssl context object + SSLContext sslContext = SSLContext.getInstance("SSL"); + TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; + // Load client certificate + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); + // Initialize the factory + return sslContext.getSocketFactory(); + } + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/PtpHttpResponse.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/PtpHttpResponse.java index 70b07fc46a..0398c17983 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/http/PtpHttpResponse.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/PtpHttpResponse.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
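Editor's note: the new HttpsClientPool above caches one pooled CloseableHttpClient per (CA, client certificate, client key) triple and exposes PTP helpers over mutual TLS. A hedged usage sketch of sendPtpPost; the URL, certificate paths, payload and header values below are placeholders for illustration, not values taken from this patch:
    // Illustrative only: post a PTP message over mutual-TLS HTTP using the pooled client.
    // sendPtpPost declares "throws Exception", so callers handle or propagate it.
    Map headers = new HashMap();
    headers.put(PtpHttpHeader.TechProviderCode, "placeholder-provider-code"); // selects the TechProvider on the remote broker
    Osx.Outbound outbound = HttpsClientPool.sendPtpPost(
            "https://peer-broker.example:9883/ptp/invoke",   // placeholder URL
            payloadBytes, headers,
            "/path/to/ca.crt", "/path/to/client.crt", "/path/to/client.key");
    // outbound.getCode() / getMessage() are filled from the ReturnCode / ReturnMessage response headers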
*/ -package com.osx.broker.http; +package org.fedai.osx.broker.http; import lombok.Data; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PcpHandleInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PcpHandleInterceptor.java new file mode 100644 index 0000000000..34cd215c9a --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PcpHandleInterceptor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.interceptor; + +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.Interceptor; +import org.fedai.osx.core.service.OutboundPackage; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PcpHandleInterceptor implements Interceptor { + Logger logger = LoggerFactory.getLogger(PcpHandleInterceptor.class); + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) { + Osx.Inbound inbound = inboundPackage.getBody(); + TransferUtil.assableContextFromInbound(context,inbound); + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PushHandleInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PushHandleInterceptor.java new file mode 100644 index 0000000000..6b5005175c --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/PushHandleInterceptor.java @@ -0,0 +1,19 @@ +//package org.fedai.osx.broker.interceptor; +// +//import org.fedai.osx.broker.grpc.PushRequestDataWrap; +//import org.fedai.osx.core.context.Context; +//import org.fedai.osx.core.service.InboundPackage; +//import org.fedai.osx.core.service.Interceptor; +//import com.webank.ai.eggroll.api.networking.proxy.Proxy; +// +//import static org.fedai.osx.broker.util.TransferUtil.assableContextFromProxyPacket; +// +//public class PushHandleInterceptor implements Interceptor { +// +// public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { +// PushRequestDataWrap pushRequestDataWrap =inboundPackage.getBody(); +// Proxy.Packet packet = pushRequestDataWrap.getPacket(); +//// assableContextFromProxyPacket(context ,packet); +// } +// +//} diff --git a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/RouterInterceptor.java similarity index 53% rename from java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/RouterInterceptor.java index 038f508dc8..4f24575f0c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/interceptor/RouterInterceptor.java +++ 
b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/RouterInterceptor.java @@ -13,40 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.interceptor; -import com.osx.broker.router.FateRouterService; -import com.osx.core.context.Context; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; -import com.osx.core.service.Interceptor; -import com.osx.core.service.OutboundPackage; -import org.ppc.ptp.Osx; +package org.fedai.osx.broker.interceptor; + +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.router.FateRouterService; +import org.fedai.osx.broker.router.RouterService; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.Interceptor; +import org.fedai.osx.core.service.OutboundPackage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class RouterInterceptor implements Interceptor { Logger logger = LoggerFactory.getLogger(RouterInterceptor.class); - - public RouterInterceptor(FateRouterService fateRouterService){ + public RouterInterceptor(){ this.fateRouterService = fateRouterService; } FateRouterService fateRouterService; - - @Override - public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { - + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + String routerKey = buildRouterKey(context); + RouterService routerService = ServiceContainer.routerRegister.getRouterService(routerKey); String sourcePartyId = context.getSrcPartyId(); String desPartyId = context.getDesPartyId(); String sourceComponentName = context.getSrcComponent(); String desComponentName = context.getDesComponent(); - RouterInfo routerInfo = fateRouterService.route(sourcePartyId,sourceComponentName,desPartyId,desComponentName); - logger.info("============== {} {} {} {} ============",sourcePartyId,sourceComponentName,desPartyId,desComponentName); - if(logger.isDebugEnabled()) { - logger.debug("RouterInterceptor return {}", routerInfo); - } + RouterInfo routerInfo = routerService.route(sourcePartyId,sourceComponentName,desPartyId,desComponentName); +// logger.info("router===================={} =============={}",routerService,routerInfo); context.setRouterInfo(routerInfo); - + } + private String buildRouterKey (Context context){ + return context.getTechProviderCode(); } } diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/TokenValidatorInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/TokenValidatorInterceptor.java new file mode 100644 index 0000000000..d2701754c2 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/TokenValidatorInterceptor.java @@ -0,0 +1,28 @@ +package org.fedai.osx.broker.interceptor; + +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.security.TokenValidator; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.Interceptor; +import org.fedai.osx.core.service.OutboundPackage; + +public class TokenValidatorInterceptor implements Interceptor { + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, 
OutboundPackage outboundPackage) throws Exception { + if (MetaInfo.PROPERTY_OPEN_TOKEN_VALIDATOR) { + TokenValidator tokenValidator = ServiceContainer.tokenValidatorRegister.getTokenValidator(getValidatorKey(context), Dict.DEFAULT); + if (tokenValidator != null) { + tokenValidator.validate(context, context.getToken()); + } + } + } + + private String getValidatorKey(Context context) { + String srcPartyId = context.getSrcPartyId(); + return srcPartyId; + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/UnaryCallHandleInterceptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/UnaryCallHandleInterceptor.java new file mode 100644 index 0000000000..87c338eb13 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/interceptor/UnaryCallHandleInterceptor.java @@ -0,0 +1,18 @@ +package org.fedai.osx.broker.interceptor; + + +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.Interceptor; +import org.fedai.osx.core.service.OutboundPackage; + +public class UnaryCallHandleInterceptor implements Interceptor { + + @Override + public void doProcess(Context context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + Proxy.Packet packet = inboundPackage.getBody(); + TransferUtil.assableContextFromProxyPacket(context, packet); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AllocateMappedFileService.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AllocateMappedFileService.java index 5f930d163e..34f309d6c9 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/AllocateMappedFileService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AllocateMappedFileService.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; -import com.osx.broker.queue.MappedFile; -import com.osx.broker.util.UtilAll; -import com.osx.core.frame.ServiceThread; +import org.fedai.osx.broker.queue.MappedFile; +import org.fedai.osx.broker.util.UtilAll; +import org.fedai.osx.core.frame.ServiceThread; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageHandler.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageHandler.java index 25364d7316..52c55bebf1 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageHandler.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageHandler.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.message; +package org.fedai.osx.broker.message; import java.nio.ByteBuffer; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageResult.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageResult.java index 9312afc352..09f8fcfd78 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageResult.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; /** * When write a message to the commit log, returns results diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageStatus.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageStatus.java index cad4c35bf8..2b40337d4f 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/AppendMessageStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/AppendMessageStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; public enum AppendMessageStatus { PUT_OK, diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/DefaultAppendMessageHandler.java similarity index 65% rename from java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/DefaultAppendMessageHandler.java index ec97d5724d..870a5bb418 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/DefaultAppendMessageHandler.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/DefaultAppendMessageHandler.java @@ -13,55 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.message; +package org.fedai.osx.broker.message; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; public class DefaultAppendMessageHandler implements AppendMessageHandler { // File at the end of the minimum fixed length empty private static final int END_FILE_MIN_BLANK_LENGTH = 4 + 4; - private final ByteBuffer msgIdMemory; - private final ByteBuffer msgIdV6Memory; // Store the message content private final ByteBuffer msgStoreItemMemory; // The maximum length of the message private final int maxMessageSize; - // Build Message Key - private final StringBuilder keyBuilder = new StringBuilder(); - private final StringBuilder msgIdBuilder = new StringBuilder(); Logger log = LoggerFactory.getLogger(DefaultAppendMessageHandler.class); public DefaultAppendMessageHandler(final int size) { - this.msgIdMemory = ByteBuffer.allocate(4 + 4 + 8); - this.msgIdV6Memory = ByteBuffer.allocate(16 + 4 + 8); + this.msgStoreItemMemory = ByteBuffer.allocate(size + END_FILE_MIN_BLANK_LENGTH); this.maxMessageSize = size; } - protected static int calMsgLength(int sysFlag, int srcPartyIdLength, int desPartyIdLength, int bodyLength, int topicLength, int propertiesLength) { - int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; - int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 8 : 20; final int msgLen = 4 //TOTALSIZE - + 4 //MAGICCODE - + 4 //BODYCRC - + 4 //QUEUEID + 4 //FLAG - //+ 8 //QUEUEOFFSET + 1 + (srcPartyIdLength > 0 ? srcPartyIdLength : 0) - - // + 8 //PHYSICALOFFSET + 1 + (desPartyIdLength > 0 ? desPartyIdLength : 0) + 4 //SYSFLAG + 8 //BORNTIMESTAMP - // + bornhostLength //BORNHOST - // + 8 //STORETIMESTAMP - // + storehostAddressLength //STOREHOSTADDRESS - // + 4 //RECONSUMETIMES - // + 8 //Prepared Transaction Offset + 4 + (bodyLength > 0 ? bodyLength : 0) //BODY + 2 + topicLength //TOPIC + 2 + (propertiesLength > 0 ? propertiesLength : 0) //propertiesLength @@ -81,94 +62,53 @@ public AppendMessageResult doAppend(final long fileFromOffset, final ByteBuffer String msgId = Long.toString(wroteOffset); Long queueOffset = new Long(0); final byte[] propertiesData = - msgInner.getPropertiesString() == null ? null : msgInner.getPropertiesString().getBytes(MessageDecoder.CHARSET_UTF8); - - final int propertiesLength = propertiesData == null ? 0 : propertiesData.length; + msgInner.getProperties()==null ? null : MessageDecoder.messageProperties2String (msgInner.getProperties()).getBytes(StandardCharsets.UTF_8); + final int propertiesLength = propertiesData==null? 0 : propertiesData.length; if (propertiesLength > Short.MAX_VALUE) { log.warn("putMessage message properties length too long. length={}", propertiesData.length); return new AppendMessageResult(AppendMessageStatus.PROPERTIES_SIZE_EXCEEDED); } - final byte[] topicData = msgInner.getTopic().getBytes(MessageDecoder.CHARSET_UTF8); - - final byte[] srcPartyId = - msgInner.getSrcPartyId() == null ? null : msgInner.getSrcPartyId().getBytes(MessageDecoder.CHARSET_UTF8); + final byte[] srcPartyId = msgInner.getSrcPartyId() == null ? null : msgInner.getSrcPartyId().getBytes(MessageDecoder.CHARSET_UTF8); final int srcPartyIdLength = srcPartyId != null ? srcPartyId.length : 0; - - final byte[] desPartyId = - msgInner.getDesPartyId() == null ? null : msgInner.getDesPartyId().getBytes(MessageDecoder.CHARSET_UTF8); + final byte[] desPartyId = msgInner.getDesPartyId() == null ? 
null : msgInner.getDesPartyId().getBytes(MessageDecoder.CHARSET_UTF8); final int desPartyIdLength = desPartyId != null ? desPartyId.length : 0; - - final int topicLength = topicData.length; - final int bodyLength = msgInner.getBody() == null ? 0 : msgInner.getBody().length; - final int msgLen = calMsgLength(msgInner.getSysFlag(), srcPartyIdLength, desPartyIdLength, bodyLength, topicLength, propertiesLength); - // Exceeds the maximum message if (msgLen > this.maxMessageSize) { + log.error("msg length {} bigger than {}",msgLen,this.maxMessageSize); return new AppendMessageResult(AppendMessageStatus.MESSAGE_SIZE_EXCEEDED); } - // Determines whether there is sufficient free space if ((msgLen + END_FILE_MIN_BLANK_LENGTH) > maxBlank) { this.resetByteBuffer(this.msgStoreItemMemory, maxBlank); // 1 TOTALSIZE this.msgStoreItemMemory.putInt(maxBlank); -// // 2 MAGICCODE -// this.msgStoreItemMemory.putInt(1111); - // 3 The remaining space may be any value - // Here the length of the specially set maxBlank final long beginTimeMills = System.currentTimeMillis(); byteBuffer.put(this.msgStoreItemMemory.array(), 0, maxBlank); return new AppendMessageResult(AppendMessageStatus.END_OF_FILE, wroteOffset, maxBlank, msgId, msgInner.getStoreTimestamp(), queueOffset, System.currentTimeMillis() - beginTimeMills); } - // Initialization of storage space this.resetByteBuffer(msgStoreItemMemory, msgLen); - // 1 TOTALSIZE this.msgStoreItemMemory.putInt(msgLen); - // log.info("msgLen {}",msgLen); - // 2 MAGICCODE - this.msgStoreItemMemory.putInt(1000); - // 3 BODYCRC - this.msgStoreItemMemory.putInt(msgInner.getBodyCRC()); - // 4 QUEUEID - this.msgStoreItemMemory.putInt(msgInner.getQueueId()); // 5 FLAG this.msgStoreItemMemory.putInt(msgInner.getFlag()); // 6 QUEUEOFFSET - this.msgStoreItemMemory.put((byte) srcPartyIdLength); if (srcPartyId != null) this.msgStoreItemMemory.put(srcPartyId); - this.msgStoreItemMemory.put((byte) desPartyIdLength); if (desPartyId != null) this.msgStoreItemMemory.put(desPartyId); - - // this.msgStoreItemMemory.putLong(fileFromOffset + byteBuffer.position()); // 8 SYSFLAG this.msgStoreItemMemory.putInt(msgInner.getSysFlag()); // 9 BORNTIMESTAMP this.msgStoreItemMemory.putLong(msgInner.getBornTimestamp()); -// // 10 BORNHOST -// this.resetByteBuffer(bornHostHolder, bornHostLength); -// this.msgStoreItemMemory.put(msgInner.getBornHostBytes(bornHostHolder)); -// // 11 STORETIMESTAMP -// this.msgStoreItemMemory.putLong(msgInner.getStoreTimestamp()); -// // 12 STOREHOSTADDRESS -// this.resetByteBuffer(storeHostHolder, storeHostLength); -// this.msgStoreItemMemory.put(msgInner.getStoreHostBytes(storeHostHolder)); -// // 13 RECONSUMETIMES -// this.msgStoreItemMemory.putInt(msgInner.getReconsumeTimes()); -// // 14 Prepared Transaction Offset -// this.msgStoreItemMemory.putLong(msgInner.getPreparedTransactionOffset()); - // 15 BODY this.msgStoreItemMemory.putInt(bodyLength); if (bodyLength > 0) this.msgStoreItemMemory.put(msgInner.getBody()); @@ -177,13 +117,10 @@ public AppendMessageResult doAppend(final long fileFromOffset, final ByteBuffer this.msgStoreItemMemory.put(topicData); // 17 PROPERTIES this.msgStoreItemMemory.putShort((short) propertiesLength); - if (propertiesLength > 0) + if (propertiesLength > 0) { this.msgStoreItemMemory.put(propertiesData); - - final long beginTimeMills = System.currentTimeMillis(); - // Write messages to the queue buffer + } byteBuffer.put(this.msgStoreItemMemory.array(), 0, msgLen); - AppendMessageResult result = new AppendMessageResult(AppendMessageStatus.PUT_OK, 
wroteOffset, msgLen, msgId, msgInner.getStoreTimestamp(), queueOffset, 0); return result; @@ -193,6 +130,4 @@ private void resetByteBuffer(final ByteBuffer byteBuffer, final int limit) { byteBuffer.flip(); byteBuffer.limit(limit); } - - } \ No newline at end of file diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/Message.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/Message.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/message/Message.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/Message.java index 58d8d22ece..84ac360cb4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/Message.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/Message.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; -import com.osx.broker.util.MessageConst; +import org.fedai.osx.broker.util.MessageConst; import java.io.Serializable; import java.util.Arrays; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageDecoder.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageDecoder.java new file mode 100644 index 0000000000..0ef023f3c2 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageDecoder.java @@ -0,0 +1,406 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.message; + +import com.google.common.collect.Maps; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.util.UtilAll; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; + +public class MessageDecoder { + + + static Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + + public final static Charset CHARSET_UTF8 = Charset.forName("UTF-8"); + public final static int MESSAGE_MAGIC_CODE_POSTION = 4; + public final static int MESSAGE_FLAG_POSTION = 16; + public final static int MESSAGE_PHYSIC_OFFSET_POSTION = 28; + // public final static int MESSAGE_STORE_TIMESTAMP_POSTION = 56; + public final static int MESSAGE_MAGIC_CODE = -626843481; + public static final char NAME_VALUE_SEPARATOR = 1; + public static final char PROPERTY_SEPARATOR = 2; + public static final int PHY_POS_POSITION = 4 + 4 + 4 + 4 + 4 + 8; + public static final int QUEUE_OFFSET_POSITION = 4 + 4 + 4 + 4 + 4; + public static final int SYSFLAG_POSITION = 4 + 4 + 4 + 4 + 4 + 8 + 8; + + + public static String createMessageId(final ByteBuffer input, final ByteBuffer addr, final long offset) { + input.flip(); + int msgIDLength = addr.limit() == 8 ? 
16 : 28; + input.limit(msgIDLength); + + input.put(addr); + input.putLong(offset); + + return UtilAll.bytes2string(input.array()); + } + + public static MessageExtBrokerInner buildMessageExtBrokerInner(String topic, byte[] body, + String msgCode, MessageFlag flag, String srcPartyId, String desPartyId) { + MessageExtBrokerInner messageExtBrokerInner = new MessageExtBrokerInner(); + messageExtBrokerInner.setBody(body); + messageExtBrokerInner.setTopic(topic); + messageExtBrokerInner.setFlag(flag.getFlag()); + messageExtBrokerInner.setBornTimestamp(System.currentTimeMillis()); + messageExtBrokerInner.setDesPartyId(desPartyId); + messageExtBrokerInner.setSrcPartyId(srcPartyId); + messageExtBrokerInner.setProperties(Maps.newHashMap()); + messageExtBrokerInner.setMsgId(msgCode); + return messageExtBrokerInner; + } + +// public static String createMessageId(SocketAddress socketAddress, long transactionIdhashCode) { +// InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; +// int msgIDLength = inetSocketAddress.getAddress() instanceof Inet4Address ? 16 : 28; +// ByteBuffer byteBuffer = ByteBuffer.allocate(msgIDLength); +// byteBuffer.put(inetSocketAddress.getAddress().getAddress()); +// byteBuffer.putInt(inetSocketAddress.getPort()); +// byteBuffer.putLong(transactionIdhashCode); +// byteBuffer.flip(); +// return UtilAll.bytes2string(byteBuffer.array()); +// } + +// public static MessageId decodeMessageId(final String msgId) throws UnknownHostException { +// SocketAddress address; +// long offset; +// int ipLength = msgId.length() == 32 ? 4 * 2 : 16 * 2; +// +// byte[] ip = UtilAll.string2bytes(msgId.substring(0, ipLength)); +// byte[] port = UtilAll.string2bytes(msgId.substring(ipLength, ipLength + 8)); +// ByteBuffer bb = ByteBuffer.wrap(port); +// int portInt = bb.getInt(0); +// address = new InetSocketAddress(InetAddress.getByAddress(ip), portInt); +// +// // offset +// byte[] data = UtilAll.string2bytes(msgId.substring(ipLength + 8, ipLength + 8 + 16)); +// bb = ByteBuffer.wrap(data); +// offset = bb.getLong(0); +// +// return new MessageId(address, offset); +// } + + /** + * Just decode properties from msg buffer. + * + * @param byteBuffer msg commit log buffer. + */ +// public static Map decodeProperties(ByteBuffer byteBuffer) { +// int sysFlag = byteBuffer.getInt(SYSFLAG_POSITION); +// int bornhostLength = (sysFlag & MessageSysFlag.BORNHOST_V6_FLAG) == 0 ? 8 : 20; +// int storehostAddressLength = (sysFlag & MessageSysFlag.STOREHOSTADDRESS_V6_FLAG) == 0 ? 
8 : 20; +// int bodySizePosition = 4 // 1 TOTALSIZE +// + 4 // 2 MAGICCODE +// + 4 // 3 BODYCRC +// + 4 // 4 QUEUEID +// + 4 // 5 FLAG +// + 8 // 6 QUEUEOFFSET +// + 8 // 7 PHYSICALOFFSET +// + 4 // 8 SYSFLAG +// + 8 // 9 BORNTIMESTAMP +// + bornhostLength // 10 BORNHOST +// + 8 // 11 STORETIMESTAMP +// + storehostAddressLength // 12 STOREHOSTADDRESS +// + 4 // 13 RECONSUMETIMES +// + 8; // 14 Prepared Transaction Offset +// int topicLengthPosition = bodySizePosition + 4 + byteBuffer.getInt(bodySizePosition); +// +// byte topicLength = byteBuffer.get(topicLengthPosition); +// +// short propertiesLength = byteBuffer.getShort(topicLengthPosition + 1 + topicLength); +// +// byteBuffer.position(topicLengthPosition + 1 + topicLength + 2); +// +// if (propertiesLength > 0) { +// byte[] properties = new byte[propertiesLength]; +// byteBuffer.get(properties); +// String propertiesString = new String(properties, CHARSET_UTF8); +// Map map = string2messageProperties(propertiesString); +// return map; +// } +// return null; +// } + + public static MessageExt decode(ByteBuffer byteBuffer) { + return decode(byteBuffer, true, true, false); + } + +// public static MessageExt clientDecode(ByteBuffer byteBuffer, final boolean readBody) { +// return decode(byteBuffer, readBody, true, true); +// } + + public static MessageExt decode(ByteBuffer byteBuffer, final boolean readBody) { + return decode(byteBuffer, readBody, true, false); + } + + public static MessageExt decode( + ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody) { + return decode(byteBuffer, readBody, deCompressBody, false); + } + + public static MessageExt decode( + ByteBuffer byteBuffer, final boolean readBody, final boolean deCompressBody, final boolean isClient) { + try { + + MessageExt msgExt= new MessageExt(); + // 1 TOTALSIZE + int storeSize = byteBuffer.getInt(); + msgExt.setStoreSize(storeSize); + +// // 2 MAGICCODE +// byteBuffer.getInt(); +// +// // 3 BODYCRC +// int bodyCRC = byteBuffer.getInt(); +// msgExt.setBodyCRC(bodyCRC); +// +// // 4 QUEUEID +// int queueId = byteBuffer.getInt(); +// msgExt.setQueueId(queueId); + + // 5 FLAG + int flag = byteBuffer.getInt(); + msgExt.setFlag(flag); + + // 6 QUEUEOFFSET + int srcPartyIdLength = byteBuffer.get(); + if (srcPartyIdLength > 0) { + byte[] srcPartyBytes = new byte[srcPartyIdLength]; + byteBuffer.get(srcPartyBytes); + String srcPartyId = new String(srcPartyBytes); + msgExt.setSrcPartyId(srcPartyId); + } + +// long queueOffset = byteBuffer.getLong(); +// msgExt.setQueueOffset(queueOffset); + + // 7 PHYSICALOFFSET +// long physicOffset = byteBuffer.getLong(); +// msgExt.setCommitLogOffset(physicOffset); + + + int desPartyIdLength = byteBuffer.get(); + if (desPartyIdLength > 0) { + byte[] desPartyIdBytes = new byte[desPartyIdLength]; + byteBuffer.get(desPartyIdBytes); + String desPartyId = new String(desPartyIdBytes); + msgExt.setDesPartyId(desPartyId); + } + + + // 8 SYSFLAG + int sysFlag = byteBuffer.getInt(); + msgExt.setSysFlag(sysFlag); + + // 9 BORNTIMESTAMP + long bornTimeStamp = byteBuffer.getLong(); + msgExt.setBornTimestamp(bornTimeStamp); + + + // 15 BODY + int bodyLen = byteBuffer.getInt(); + if (bodyLen > 0) { + if (readBody) { + byte[] body = new byte[bodyLen]; + byteBuffer.get(body); + msgExt.setBody(body); + } else { + byteBuffer.position(byteBuffer.position() + bodyLen); + } + } + + // 16 TOPIC + short topicLen = byteBuffer.getShort(); + byte[] topic = new byte[(int) topicLen]; + byteBuffer.get(topic); + msgExt.setTopic(new String(topic, 
CHARSET_UTF8)); + + // 17 properties + short propertiesLength = byteBuffer.getShort(); + + if (propertiesLength > 0) { + byte[] properties = new byte[propertiesLength]; + byteBuffer.get(properties); + String propertiesString = new String(properties, CHARSET_UTF8); + Map map = string2messageProperties(propertiesString); + msgExt.setProperties(map); + + } + return msgExt; + } catch (Exception e) { + e.printStackTrace(); + byteBuffer.position(byteBuffer.limit()); + } + + return null; + } + +// public static List decodes(ByteBuffer byteBuffer) { +// return decodes(byteBuffer, true); +// } + +// public static List decodes(ByteBuffer byteBuffer, final boolean readBody) { +// List msgExts = new ArrayList(); +// while (byteBuffer.hasRemaining()) { +// MessageExt msgExt = clientDecode(byteBuffer, readBody); +// if (null != msgExt) { +// msgExts.add(msgExt); +// } else { +// break; +// } +// } +// return msgExts; +// } + + public static String messageProperties2String(Map properties) { + StringBuilder sb = new StringBuilder(); + if (properties != null) { + for (final Map.Entry entry : properties.entrySet()) { + final String name = entry.getKey(); + final String value = entry.getValue(); + + if (value == null) { + continue; + } + sb.append(name); + sb.append(NAME_VALUE_SEPARATOR); + sb.append(value); + sb.append(PROPERTY_SEPARATOR); + } + } + return sb.toString(); + } + + public static Map string2messageProperties(final String properties) { + Map map = new HashMap(); + if (properties != null) { + String[] items = properties.split(String.valueOf(PROPERTY_SEPARATOR)); + for (String i : items) { + String[] nv = i.split(String.valueOf(NAME_VALUE_SEPARATOR)); + if (2 == nv.length) { + map.put(nv[0], nv[1]); + } + } + } + + return map; + } + +// public static byte[] encodeMessage(Message message) { +// //only need flag, body, properties +// byte[] body = message.getBody(); +// int bodyLen = body.length; +// String properties = messageProperties2String(message.getProperties()); +// byte[] propertiesBytes = properties.getBytes(CHARSET_UTF8); +// //note properties length must not more than Short.MAX +// short propertiesLength = (short) propertiesBytes.length; +// int sysFlag = message.getFlag(); +// int storeSize = 4 // 1 TOTALSIZE +// + 4 // 2 MAGICCOD +// + 4 // 3 BODYCRC +// + 4 // 4 FLAG +// + 4 + bodyLen // 4 BODY +// + 2 + propertiesLength; +// ByteBuffer byteBuffer = ByteBuffer.allocate(storeSize); +// // 1 TOTALSIZE +// byteBuffer.putInt(storeSize); +// +// // 2 MAGICCODE +// byteBuffer.putInt(0); +// +// // 3 BODYCRC +// byteBuffer.putInt(0); +// +// // 4 FLAG +// int flag = message.getFlag(); +// byteBuffer.putInt(flag); +// +// // 5 BODY +// byteBuffer.putInt(bodyLen); +// byteBuffer.put(body); +// +// // 6 properties +// byteBuffer.putShort(propertiesLength); +// byteBuffer.put(propertiesBytes); +// +// return byteBuffer.array(); +// } + +// public static Message decodeMessage(ByteBuffer byteBuffer) throws Exception { +// Message message = new Message(); +// +// // 1 TOTALSIZE +// byteBuffer.getInt(); +// +// // 2 MAGICCODE +// byteBuffer.getInt(); +// +// // 3 BODYCRC +// byteBuffer.getInt(); +// +// // 4 FLAG +// int flag = byteBuffer.getInt(); +// message.setFlag(flag); +// +// // 5 BODY +// int bodyLen = byteBuffer.getInt(); +// byte[] body = new byte[bodyLen]; +// byteBuffer.get(body); +// message.setBody(body); +// +// // 6 properties +// short propertiesLen = byteBuffer.getShort(); +// byte[] propertiesBytes = new byte[propertiesLen]; +// byteBuffer.get(propertiesBytes); +// 
message.setProperties(string2messageProperties(new String(propertiesBytes, CHARSET_UTF8))); +// +// return message; +// } + +// public static byte[] encodeMessages(List messages) { +// //TO DO refactor, accumulate in one buffer, avoid copies +// List encodedMessages = new ArrayList(messages.size()); +// int allSize = 0; +// for (Message message : messages) { +// byte[] tmp = encodeMessage(message); +// encodedMessages.add(tmp); +// allSize += tmp.length; +// } +// byte[] allBytes = new byte[allSize]; +// int pos = 0; +// for (byte[] bytes : encodedMessages) { +// System.arraycopy(bytes, 0, allBytes, pos, bytes.length); +// pos += bytes.length; +// } +// return allBytes; +// } + +// public static List decodeMessages(ByteBuffer byteBuffer) throws Exception { +// //TO DO add a callback for processing, avoid creating lists +// List msgs = new ArrayList(); +// while (byteBuffer.hasRemaining()) { +// Message msg = decodeMessage(byteBuffer); +// msgs.add(msg); +// } +// return msgs; +// } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExt.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExt.java index 7bd15a4502..23879522b5 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExt.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExt.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; import java.net.Inet4Address; import java.net.InetAddress; @@ -24,16 +24,9 @@ public class MessageExt extends Message { private static final long serialVersionUID = 5720810158625748049L; - private String brokerName; - private int queueId; - private int storeSize; - - // private long queueOffset; - - private int sysFlag; private long bornTimestamp; private SocketAddress bornHost; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExtBrokerInner.java similarity index 82% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExtBrokerInner.java index 5f42c56990..34fff0d178 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageExtBrokerInner.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageExtBrokerInner.java @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.message; +package org.fedai.osx.broker.message; public class MessageExtBrokerInner extends MessageExt { private static final long serialVersionUID = 7256001576878700634L; private String propertiesString; - private long tagsCode; + + + public String getPropertiesString() { return propertiesString; @@ -27,11 +29,4 @@ public void setPropertiesString(String propertiesString) { this.propertiesString = propertiesString; } - public long getTagsCode() { - return tagsCode; - } - - public void setTagsCode(long tagsCode) { - this.tagsCode = tagsCode; - } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageStoreConfig.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageStoreConfig.java index 783449fe8a..e79d0d9d5a 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageStoreConfig.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageStoreConfig.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; import java.io.File; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageSysFlag.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageSysFlag.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageSysFlag.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageSysFlag.java index 41c88cd18e..9e04ab21c4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageSysFlag.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageSysFlag.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; public class MessageSysFlag { public final static int COMPRESSED_FLAG = 0x1; diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/MessageWraper.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageWraper.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/message/MessageWraper.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageWraper.java index a401338699..b7d37e5d16 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/MessageWraper.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/MessageWraper.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.message; +package org.fedai.osx.broker.message; public class MessageWraper { diff --git a/java/osx/broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/SelectMappedBufferResult.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/SelectMappedBufferResult.java index b25f0a0c0e..fc1fa33fb5 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/message/SelectMappedBufferResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/message/SelectMappedBufferResult.java @@ -14,9 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.message; +package org.fedai.osx.broker.message; -import com.osx.broker.queue.MappedFile; +import org.fedai.osx.broker.queue.MappedFile; import java.nio.ByteBuffer; diff --git a/java/osx/broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/metric/ClusterMetricLeapArray.java similarity index 93% rename from java/osx/broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/metric/ClusterMetricLeapArray.java index 130d212ae5..f1996ce170 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/metric/ClusterMetricLeapArray.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/metric/ClusterMetricLeapArray.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.metric; +package org.fedai.osx.broker.metric; -import com.osx.core.flow.ClusterFlowEvent; -import com.osx.core.flow.ClusterMetricBucket; -import com.osx.core.flow.LeapArray; -import com.osx.core.flow.WindowWrap; +import org.fedai.osx.core.flow.ClusterFlowEvent; +import org.fedai.osx.core.flow.ClusterMetricBucket; +import org.fedai.osx.core.flow.LeapArray; +import org.fedai.osx.core.flow.WindowWrap; import java.util.concurrent.atomic.LongAdder; diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/AbstractPtpServiceAdaptor.java similarity index 72% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/AbstractPtpServiceAdaptor.java index 1389b63760..9356e11a5f 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/AbstractPtpServiceAdaptor.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/AbstractPtpServiceAdaptor.java @@ -13,17 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.ptp; +package org.fedai.osx.broker.ptp; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.service.AbstractServiceAdaptor; + +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.service.AbstractServiceAdaptor; import org.ppc.ptp.Osx; -public abstract class AbstractPtpServiceAdaptor extends AbstractServiceAdaptor { +public abstract class AbstractPtpServiceAdaptor extends AbstractServiceAdaptor { @Override - protected Osx.Outbound transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected Osx.Outbound transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { + Osx.Outbound.Builder builder = Osx.Outbound.newBuilder(); builder.setCode(exceptionInfo.getCode()); builder.setMessage(exceptionInfo.getMessage()); diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpAckService.java similarity index 72% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpAckService.java index e18dae12de..89e6771b04 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpAckService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpAckService.java @@ -13,43 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.ptp; +package org.fedai.osx.broker.ptp; -import com.osx.broker.ServiceContainer; -import com.osx.broker.consumer.UnaryConsumer; -import com.osx.broker.queue.TransferQueue; -import com.osx.broker.queue.TransferQueueApplyInfo; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ConsumerNotExistException; -import com.osx.core.exceptions.InvalidRedirectInfoException; -import com.osx.core.exceptions.TransferQueueNotExistException; -import com.osx.core.router.RouterInfo; -import com.osx.core.service.InboundPackage; import org.apache.commons.lang3.StringUtils; - +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.consumer.UnaryConsumer; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.queue.TransferQueueApplyInfo; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ConsumerNotExistException; +import org.fedai.osx.core.exceptions.InvalidRedirectInfoException; +import org.fedai.osx.core.exceptions.TransferQueueNotExistException; +import org.fedai.osx.core.service.InboundPackage; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static com.osx.broker.util.TransferUtil.redirect; - public class PtpAckService extends AbstractPtpServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpAckService.class); @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage 
data) { context.setActionType(ActionType.LOCAL_ACK.getAlias()); Osx.Inbound inbound = data.getBody(); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); - String sessionId = context.getSessionId(); String topic = context.getTopic(); - Long offset = context.getRequestMsgIndex(); +// Long offset = context.getRequestMsgIndex(); + Long offset = (Long)context.getData(Dict.REQUEST_INDEX); TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); - /** + /* * If the queue does not exist locally, check whether it exists on another node in the cluster */ if (transferQueue == null) { @@ -72,7 +70,7 @@ protected Osx.Outbound doService(Context context, InboundPackage da redirectRouterInfo.setHost(redirectIp); redirectRouterInfo.setPort(redirectPort); //context.setRouterInfo(redirectRouterInfo); - return redirect(context, inbound, redirectRouterInfo, false); + return TransferUtil.redirect(context, inbound, redirectRouterInfo,true); } } else { throw new TransferQueueNotExistException(); @@ -80,7 +78,7 @@ protected Osx.Outbound doService(Context context, InboundPackage da } UnaryConsumer unaryConsumer = ServiceContainer.consumerManager.getUnaryConsumer(topic); if (unaryConsumer != null) { - long currentMsgIndex = unaryConsumer.ack(offset); + unaryConsumer.ack(offset); //context.setCurrentMsgIndex(currentMsgIndex); outboundBuilder.setCode(StatusCode.SUCCESS); outboundBuilder.setMessage(Dict.SUCCESS); @@ -88,8 +86,5 @@ protected Osx.Outbound doService(Context context, InboundPackage da } else { throw new ConsumerNotExistException("consumer is not exist"); } - } - - } diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpCancelTransferService.java similarity index 77% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpCancelTransferService.java index 687071fbfd..502e9658e6 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpCancelTransferService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpCancelTransferService.java @@ -13,26 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -package com.osx.broker.ptp; +package org.fedai.osx.broker.ptp; -import com.osx.broker.ServiceContainer; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.service.InboundPackage; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.service.InboundPackage; import org.ppc.ptp.Osx; - import java.util.List; public class PtpCancelTransferService extends AbstractPtpServiceAdaptor { public PtpCancelTransferService() { - this.setServiceName("cansel-unary"); + this.setServiceName("cancel-unary"); } + + @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { String sessionId = context.getSessionId(); String topic = context.getTopic(); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTokenApplyService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTokenApplyService.java new file mode 100644 index 0000000000..0d1d8e1c9d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTokenApplyService.java @@ -0,0 +1,31 @@ +package org.fedai.osx.broker.ptp; + +import com.google.protobuf.ByteString; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.token.TokenRequest; +import org.fedai.osx.core.token.TokenResult; +import org.fedai.osx.core.utils.JsonUtil; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; + +public class PtpClusterTokenApplyService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpClusterTokenApplyService.class); + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + context.setActionType(ActionType.CLUSTER_TOKEN_APPLY.getAlias()); + Osx.Inbound inbound = data.getBody(); + byte[] temp = inbound.getPayload().toByteArray(); + TokenRequest tokenRequest = JsonUtil.json2Object(temp, TokenRequest.class); + TokenResult tokenResult = ServiceContainer.defaultTokenService.requestToken(tokenRequest.getResource(),tokenRequest.getAcquireCount(),tokenRequest.isPrioritized()); + Osx.Outbound.Builder resultBuilder = Osx.Outbound.newBuilder(); + resultBuilder.setPayload(ByteString.copyFrom(JsonUtil.object2Json(tokenResult).getBytes(StandardCharsets.UTF_8))); + return resultBuilder.build(); + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTopicApplyService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTopicApplyService.java new file mode 100644 index 0000000000..3aedd9e3e6 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpClusterTopicApplyService.java @@ -0,0 +1,59 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.ptp; + +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ParameterException; +import org.fedai.osx.core.service.InboundPackage; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class PtpClusterTopicApplyService extends AbstractPtpServiceAdaptor { + Logger logger = LoggerFactory.getLogger(PtpClusterTopicApplyService.class); + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + try { + context.setActionType(ActionType.TOPIC_APPLY.getAlias()); + Osx.Inbound inbound = data.getBody(); + String topic = inbound.getMetadataMap().get(Osx.Metadata.MessageTopic.name()); + String instanceId = inbound.getMetadataMap().get(Osx.Metadata.InstanceId.name()); + String sessionId = inbound.getMetadataMap().get(Osx.Header.SessionID.name()); + if (StringUtils.isEmpty(topic)) { + throw new ParameterException("topic is null"); + } + if (StringUtils.isEmpty(instanceId)) { + throw new ParameterException("instanceId is null"); + } + if (StringUtils.isEmpty(sessionId)) { + throw new ParameterException("sessionId is null"); + } + context.setTopic(topic); + context.setSessionId(sessionId); + Osx.Outbound outbound = ServiceContainer.transferQueueManager.applyFromMaster(topic, sessionId, instanceId); + logger.info("====================PtpClusterTopicApplyService================{}=====", outbound); + return outbound; + }catch(Exception e){ + e.printStackTrace(); + throw e; + } + } + +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpConsumeService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpConsumeService.java new file mode 100644 index 0000000000..ec84c0667a --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpConsumeService.java @@ -0,0 +1,134 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.broker.ptp; + +import com.google.common.base.Preconditions; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.consumer.UnaryConsumer; +import org.fedai.osx.broker.queue.CreateQueueResult; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.queue.TransferQueueApplyInfo; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ParameterException; +import org.fedai.osx.core.exceptions.TransferQueueNotExistException; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.service.InboundPackage; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PtpConsumeService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpConsumeService.class); + public PtpConsumeService() { + this.setServiceName("consume-unary"); + } + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + + context.setActionType(ActionType.DEFUALT_CONSUME.getAlias()); + Osx.Inbound inbound = data.getBody(); + String topic = context.getTopic(); + TransferQueue transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); + if (transferQueue == null) { + if (MetaInfo.isCluster()) { + TransferQueueApplyInfo transferQueueApplyInfo = ServiceContainer.transferQueueManager.queryGlobleQueue(topic); + if (transferQueueApplyInfo == null) { + throw new TransferQueueNotExistException("topic "+topic+" not found" ); +// CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, context.getSessionId(), false); +// if (createQueueResult.getTransferQueue() == null) { +// //redirect +// Osx.TopicInfo topicInfo = Osx.TopicInfo.newBuilder() +// .setTopic(topic) +// .setCreateTimestamp(System.currentTimeMillis()) +// .setIp(createQueueResult.getRedirectIp()) +// .setPort(createQueueResult.getPort()) +// .build(); +// return TransferUtil.buildResponseInner(StatusCode.TRANSFER_QUEUE_REDIRECT,"NEED REDIRECT",topicInfo.toByteArray()).build(); +// } + } else { + String[] args = transferQueueApplyInfo.getInstanceId().split(":"); + String ip = args[0]; + int port = Integer.parseInt(args[1]); + RouterInfo routerInfo = new RouterInfo(); + routerInfo.setHost(ip); + routerInfo.setPort(port); + return redirect(context, routerInfo, inbound); + } + } else { + /** + * In standalone mode, create the queue directly + */ + logger.warn("create topic {} by consume request ", topic); + CreateQueueResult createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, context.getSessionId(), true); + if (createQueueResult.getTransferQueue() == null) { + throw new TransferQueueNotExistException(); + } + } + } + StreamObserver streamObserver = (StreamObserver) context.getData(Dict.RESPONSE_STREAM_OBSERVER); + Long offset = (Long) context.getData(Dict.REQUEST_INDEX); + Preconditions.checkArgument(offset != null); + if (offset == null) { + throw new ParameterException("offset is null"); + } + if (offset > 0) {
context.setActionType(ActionType.CUSTOMER_CONSUME.getAlias()); + } + UnaryConsumer consumer = ServiceContainer.consumerManager.getOrCreateUnaryConsumer(topic); + TransferQueue.TransferQueueConsumeResult transferQueueConsumeResult = consumer.consume(context, offset); + context.setReturnCode(transferQueueConsumeResult.getCode()); + if (transferQueueConsumeResult.getCode().equals(StatusCode.CONSUME_NO_MESSAGE)) { + // the response will be sent by a separate scanning thread + if (offset < 0) { + + UnaryConsumer.LongPullingHold longPullingHold = new UnaryConsumer.LongPullingHold(); + + longPullingHold.setGrpcContext(io.grpc.Context.current()); + longPullingHold.setNeedOffset(offset); + longPullingHold.setStreamObserver(streamObserver); + longPullingHold.setContext(context.subContext()); + String timeOutString = inbound.getMetadataMap().get(Osx.Metadata.Timeout.name()); + if (StringUtils.isNotEmpty(timeOutString)) { + long current = System.currentTimeMillis(); + longPullingHold.setExpireTimestamp(current + Long.valueOf(timeOutString)); + } + consumer.addLongPullingQueue(longPullingHold); + return null; + } + } + Osx.Outbound consumeResponse = TransferUtil.buildResponse(transferQueueConsumeResult.getCode(), "", transferQueueConsumeResult); + return consumeResponse; + + } + private Osx.Outbound redirect(Context context, RouterInfo routerInfo, Osx.Inbound inbound) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); + context.setActionType(ActionType.REDIRECT_CONSUME.getAlias()); + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + return stub.invoke(inbound); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpForwardPushRespSO.java similarity index 92% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpForwardPushRespSO.java index 5a66afc18f..38efd7ff7e 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpForwardPushRespSO.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpForwardPushRespSO.java @@ -13,14 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.ptp; +package org.fedai.osx.broker.ptp; -import com.osx.broker.callback.CompleteCallback; -import com.osx.broker.callback.ErrorCallback; -import com.osx.broker.util.TransferUtil; -import com.osx.core.context.Context; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.callback.CompleteCallback; +import org.fedai.osx.broker.callback.ErrorCallback; +import org.fedai.osx.broker.util.TransferUtil; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpProduceService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpProduceService.java new file mode 100644 index 0000000000..ac70faced5 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpProduceService.java @@ -0,0 +1,213 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.ptp; + +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.message.MessageDecoder; +import org.fedai.osx.broker.message.MessageExtBrokerInner; +import org.fedai.osx.broker.queue.CreateQueueResult; +import org.fedai.osx.broker.queue.PutMessageResult; +import org.fedai.osx.broker.queue.PutMessageStatus; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.DeployMode; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.*; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.Interceptor; +import org.fedai.osx.core.service.OutboundPackage; +import org.fedai.osx.core.utils.FlowLogUtil; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PtpProduceService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpProduceService.class); + + + public PtpProduceService() { + this.addPostProcessor(new Interceptor() { + @Override + public void doProcess(FateContext context, InboundPackage inboundPackage, OutboundPackage outboundPackage) { + TransferQueue transferQueue = (TransferQueue) context.getData(Dict.TRANSFER_QUEUE); + if (transferQueue != null) { + transferQueue.cacheReceivedMsg(inboundPackage.getBody().getMetadataMap().get(Osx.Metadata.MessageCode.name()), outboundPackage); + } + } + }); + } + + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + TransferQueue transferQueue ; + String topic = context.getTopic(); + RouterInfo routerInfo = context.getRouterInfo(); + String srcPartyId = context.getSrcPartyId(); + String sessionId = context.getSessionId(); + Osx.Inbound produceRequest = data.getBody(); + int dataSize = produceRequest.getSerializedSize(); + String resource = TransferUtil.buildResource(produceRequest); + ServiceContainer.tokenApplyService.applyToken(context, resource, dataSize); + ServiceContainer.flowCounterManager.pass(resource, dataSize); + context.setDataSize(dataSize); + if (!MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) { + + //forward to the remote party + Osx.Outbound response = null; + int tryTime = 0; + context.setActionType(ActionType.MSG_REDIRECT.getAlias()); + boolean usePooled = true; + while (tryTime < MetaInfo.PROPERTY_PRODUCE_MSG_MAX_TRY_TIME) { + tryTime++; + + try { + if (tryTime > 1) { + context.setRetryTime(tryTime); + produceRequest = produceRequest.toBuilder().putMetadata(Osx.Metadata.RetryCount.name(), Integer.toString(tryTime)).build(); + usePooled = false; + } + response = TransferUtil.redirect(context, produceRequest, routerInfo,usePooled); + if (response == null) { + continue; + } + + break; + } catch
(RemoteRpcException e) { + logger.error("redirect retry count {}", tryTime); + if (tryTime == MetaInfo.PROPERTY_PRODUCE_MSG_MAX_TRY_TIME) { + throw e; + }else{ + FlowLogUtil.printFlowLog(context); + } + try { + Thread.sleep(MetaInfo.PROPERTY_PRODUCE_MSG_RETRY_INTERVAL); + } catch (InterruptedException ignore) { + + } + } + } + return response; + } else { + /* + * 本地处理 + */ + if (StringUtils.isEmpty(topic)) { + throw new ParameterException(StatusCode.PARAM_ERROR, "topic is null"); + } + if (StringUtils.isEmpty(sessionId)) { + throw new ParameterException(StatusCode.PARAM_ERROR, "sessionId is null"); + } + + context.setActionType(ActionType.MSG_DOWNLOAD.getAlias()); + context.setRouterInfo(null); + + transferQueue = ServiceContainer.transferQueueManager.getQueue(topic); + CreateQueueResult createQueueResult = null; + if (transferQueue == null) { + createQueueResult = ServiceContainer.transferQueueManager.createNewQueue(topic, sessionId, false); + if (createQueueResult == null) { + throw new CreateTopicErrorException("create topic " + topic + " error"); + } + transferQueue = createQueueResult.getTransferQueue(); + } + + + + if (transferQueue != null) { +// ServiceContainer.tokenApplyService.applyToken(context, resource, dataSize); +// ServiceContainer.flowCounterManager.pass(resource, dataSize); + context.putData(Dict.TRANSFER_QUEUE, transferQueue); + String msgCode = produceRequest.getMetadataMap().get(Osx.Metadata.MessageCode.name()); + String retryCountString = produceRequest.getMetadataMap().get(Osx.Metadata.RetryCount.name()); + //此处为处理重复请求 + if (StringUtils.isNotEmpty(msgCode)) { + if (transferQueue.checkMsgIdDuplicate(msgCode)) {//检查消息是不是已经存在于队列里面 + if (StringUtils.isBlank(retryCountString)) {//重复请求,非重试请求 + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.DUP_MSG); + return outBoundBuilder.build(); + } else { + logger.info("receive retry request , topic {} msgcode {} try count {}", topic, msgCode, retryCountString); + } + OutboundPackage cacheReceivedMsg = transferQueue.getReceivedMsgCache(msgCode); + if (cacheReceivedMsg != null) {//返回上次缓存的结果 + return cacheReceivedMsg.getData(); + } else {//重试请求,但是缓存的结果已经过期 + logger.warn("The cached message has expired , msgCode = {}", msgCode); + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.PROCESSED_MSG); + return outBoundBuilder.build(); + } + } + } + + byte[] msgBytes = produceRequest.getPayload().toByteArray(); + String flag = produceRequest.getMetadataMap().get(Osx.Metadata.MessageFlag.name()); + MessageFlag messageFlag = MessageFlag.SENDMSG; + if (StringUtils.isNotEmpty(flag)) { + messageFlag = MessageFlag.valueOf(flag); + } + context.putData(Dict.MESSAGE_FLAG, messageFlag.name()); + MessageExtBrokerInner messageExtBrokerInner = MessageDecoder.buildMessageExtBrokerInner(topic, msgBytes, msgCode, messageFlag, context.getSrcPartyId(), + context.getDesPartyId()); + messageExtBrokerInner.getProperties().put(Dict.SESSION_ID, sessionId); + messageExtBrokerInner.getProperties().put(Dict.SOURCE_COMPONENT, context.getSrcComponent() != null ? context.getSrcComponent() : ""); + messageExtBrokerInner.getProperties().put(Dict.DES_COMPONENT, context.getDesComponent() != null ? 
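The forwarding branch of `PtpProduceService` above retries the redirect up to `PROPERTY_PRODUCE_MSG_MAX_TRY_TIME` times, stamping a `RetryCount` metadata entry and dropping the pooled channel after the first failure, sleeping `PROPERTY_PRODUCE_MSG_RETRY_INTERVAL` between attempts, and rethrowing the `RemoteRpcException` only on the final attempt. The helper below is a generic restatement of that bounded-retry pattern, not the broker's code.

```java
import java.util.concurrent.Callable;

// Illustrative bounded retry with a fixed interval between attempts.
final class RetrySketch {
    static <T> T callWithRetry(Callable<T> call, int maxTries, long intervalMillis) throws Exception {
        Exception last = null;
        for (int attempt = 1; attempt <= maxTries; attempt++) {
            try {
                T result = call.call();
                if (result != null) {
                    return result;            // success: stop retrying
                }
            } catch (Exception e) {
                last = e;                     // remember the failure
                if (attempt == maxTries) {
                    throw e;                  // out of attempts: surface the error
                }
            }
            Thread.sleep(intervalMillis);     // back off before the next attempt
        }
        if (last != null) {
            throw last;
        }
        return null;                          // every attempt returned null
    }
}
```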
context.getDesComponent() : ""); + PutMessageResult putMessageResult = transferQueue.putMessage(messageExtBrokerInner); + if (putMessageResult.getPutMessageStatus() != PutMessageStatus.PUT_OK) { + throw new PutMessageException("put status " + putMessageResult.getPutMessageStatus()); + } + long logicOffset = putMessageResult.getMsgLogicOffset(); + context.putData(Dict.CURRENT_INDEX, transferQueue.getIndexQueue().getLogicOffset().get()); + Osx.Outbound.Builder outBoundBuilder = Osx.Outbound.newBuilder(); + outBoundBuilder.setCode(StatusCode.SUCCESS); + outBoundBuilder.setMessage(Dict.SUCCESS); + return outBoundBuilder.build(); + } else { + /* + * 集群内转发 + */ + if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name())) { + RouterInfo redirectRouterInfo = new RouterInfo(); + String redirectIp = createQueueResult.getRedirectIp(); + int redirectPort = createQueueResult.getPort(); + if (StringUtils.isEmpty(redirectIp) || redirectPort == 0) { + logger.error("invalid redirect info {}:{}", redirectIp, redirectPort); + throw new InvalidRedirectInfoException(); + } + redirectRouterInfo.setHost(redirectIp); + redirectRouterInfo.setPort(redirectPort); + context.putData(Dict.ROUTER_INFO, redirectRouterInfo); + context.setActionType(ActionType.INNER_REDIRECT.getAlias()); + return TransferUtil.redirect(context, produceRequest, redirectRouterInfo,true); + } else { + logger.error("create topic {} error", topic); + throw new ProduceMsgExcption(); + } + } + } + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpPushService.java similarity index 67% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpPushService.java index 4d724509b3..b21815f08c 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpPushService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpPushService.java @@ -13,31 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. 
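The local-produce branch above is made idempotent by the `MessageCode` metadata: a repeated code without a `RetryCount` is acknowledged as a duplicate (`DUP_MSG`), while a retry either gets the cached response for that code or a "processed" answer once the cache entry has expired; only unseen codes are appended to the queue. A compact, hedged restatement of that decision flow (types simplified, names illustrative):

```java
import java.util.Optional;

// Simplified decision flow for duplicate produce requests (illustrative only).
final class IdempotentProduceSketch {
    interface Queue {
        boolean seen(String msgCode);                     // cf. checkMsgIdDuplicate
        Optional<String> cachedResponse(String msgCode);  // cf. getReceivedMsgCache
        String append(byte[] payload);                    // cf. putMessage
    }

    static String handle(Queue queue, String msgCode, boolean isRetry, byte[] payload) {
        if (msgCode != null && queue.seen(msgCode)) {
            if (!isRetry) {
                return "DUP_MSG";                          // duplicate, not a retry
            }
            return queue.cachedResponse(msgCode)
                        .orElse("PROCESSED_MSG");          // retry, but the cached reply expired
        }
        return queue.append(payload);                      // first time: store the message
    }
}
```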
*/ -package com.osx.broker.ptp; -import com.osx.broker.ServiceContainer; -import com.osx.broker.grpc.PushRequestDataWrap; -import com.osx.broker.grpc.QueuePushReqStreamObserver; -import com.osx.broker.util.TransferUtil; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.service.AbstractServiceAdaptor; -import com.osx.core.service.InboundPackage; -import com.osx.core.token.TokenResult; +package org.fedai.osx.broker.ptp; + import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.grpc.QueuePushReqStreamObserver; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.service.AbstractServiceAdaptor; +import org.fedai.osx.core.service.InboundPackage; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PtpPushService extends AbstractServiceAdaptor { +public class PtpPushService extends AbstractServiceAdaptor { Logger logger = LoggerFactory.getLogger(PtpPushService.class); + + @Override - protected StreamObserver doService(Context context, InboundPackage data) { + protected StreamObserver doService(FateContext context, InboundPackage data) { StreamObserver responseStreamObserver = data.getBody(); + context.setNeedPrintFlowLog(false); return new StreamObserver() { Logger logger = LoggerFactory.getLogger(PtpPushService.class); - QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context,responseStreamObserver,Osx.Outbound.class); + QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context, ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER), + responseStreamObserver,Osx.Outbound.class); @Override public void onNext(Osx.Inbound inbound) { int dataSize = inbound.getSerializedSize(); @@ -61,7 +64,7 @@ public void onCompleted() { } @Override - protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected StreamObserver transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { return null; } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpQueryTransferQueueService.java similarity index 93% rename from java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpQueryTransferQueueService.java index 1ce8e0f25f..a6ce02f2e8 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/ptp/PtpQueryTransferQueueService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpQueryTransferQueueService.java @@ -13,19 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. 
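`PtpPushService` above now builds its `QueuePushReqStreamObserver` with the router service resolved from `ServiceContainer.routerRegister` and simply forwards every inbound stream event to it. The underlying pattern is a plain delegating `StreamObserver`; a generic sketch that assumes only the stock `io.grpc.stub.StreamObserver` interface:

```java
import io.grpc.stub.StreamObserver;

// Generic delegating observer: every event is forwarded to an inner observer,
// mirroring how the push service hands Osx.Inbound messages to its queue observer.
final class ForwardingObserver<T> implements StreamObserver<T> {
    private final StreamObserver<T> delegate;

    ForwardingObserver(StreamObserver<T> delegate) {
        this.delegate = delegate;
    }

    @Override public void onNext(T value)      { delegate.onNext(value); }
    @Override public void onError(Throwable t) { delegate.onError(t); }
    @Override public void onCompleted()        { delegate.onCompleted(); }
}
```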
*/ -package com.osx.broker.ptp; +package org.fedai.osx.broker.ptp; -import com.osx.broker.ServiceContainer; -import com.osx.broker.queue.CreateQueueResult; -import com.osx.broker.queue.TransferQueue; -import com.osx.broker.queue.TransferQueueApplyInfo; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.ActionType; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.service.InboundPackage; -import com.osx.core.utils.NetUtils; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.queue.CreateQueueResult; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.broker.queue.TransferQueueApplyInfo; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.utils.NetUtils; import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +40,7 @@ public PtpQueryTransferQueueService() { } @Override - protected Osx.Outbound doService(Context context, InboundPackage data) { + protected Osx.Outbound doService(FateContext context, InboundPackage data) { Osx.Inbound request = data.getBody(); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpStreamTestService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpStreamTestService.java new file mode 100644 index 0000000000..a8d2242bc6 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpStreamTestService.java @@ -0,0 +1,99 @@ +//package org.fedai.osx.broker.ptp; +// +//import com.google.protobuf.Parser; +//import org.fedai.osx.broker.ServiceContainer; +//import org.fedai.osx.broker.grpc.QueueStreamBuilder; +//import org.fedai.osx.broker.grpc.QueuePushReqStreamObserver; +//import org.fedai.osx.broker.util.TransferUtil; +//import org.fedai.osx.core.constant.TransferStatus; +//import org.fedai.osx.core.context.Context; +//import org.fedai.osx.core.exceptions.ExceptionInfo; +//import org.fedai.osx.core.frame.GrpcConnectionFactory; +//import org.fedai.osx.core.router.RouterInfo; +//import org.fedai.osx.core.service.AbstractServiceAdaptor; +//import org.fedai.osx.core.service.InboundPackage; +//import com.webank.ai.eggroll.api.networking.proxy.Proxy; +//import io.grpc.ManagedChannel; +//import io.grpc.stub.StreamObserver; +//import org.ppc.ptp.Osx; +//import org.ppc.ptp.PrivateTransferProtocolGrpc; +//import org.slf4j.Logger; +//import org.slf4j.LoggerFactory; +// +//public class PtpStreamTestService extends AbstractServiceAdaptor { +// +// Logger logger = LoggerFactory.getLogger(PtpStreamTestService.class); +// @Override +// protected StreamObserver doService(Context context, InboundPackage data) { +// +// return new StreamObserver() { +// TransferStatus transferStatus = TransferStatus.INIT; +// StreamObserver responseStreamObserver = data.getBody(); +// StreamObserver reqSb=null; +// boolean isDes = false; +// +//// private void initDes(Osx.Inbound first){ +//// +//// +//// reqSb = HttpStreamBuilder.buildStream(responseStreamObserver, +//// Osx.Outbound.parser(), +//// GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(),true), +//// context.getSrcPartyId(),context.getDesPartyId(),context.getSessionId()); +//// transferStatus = TransferStatus.TRANSFERING; +//// } +// +// private 
void initNotDes(Osx.Inbound first){ +// InboundPackage inboundPackage = new InboundPackage(); +// inboundPackage.setBody(first); +// try { +// ServiceContainer.requestHandleInterceptor.doPreProcess(context, inboundPackage); +// ServiceContainer.routerInterceptor.doPreProcess(context,inboundPackage); +// logger.info("init========={}",context.getRouterInfo()); +// }catch (Exception e){ +// e.printStackTrace(); +// } +// logger.info("ppppppppppppppppppp {}",context.getRouterInfo()); +// reqSb = QueueStreamBuilder.createStreamFromOrigin(context,responseStreamObserver, +// Osx.Outbound.parser(), +// context.getRouterInfo(), +// context.getSrcPartyId(), +// context.getDesPartyId(), +// context.getSessionId(),null); +// transferStatus = TransferStatus.TRANSFERING; +// } +// +// +// @Override +// public void onNext(Osx.Inbound inbound) { +// +//// if(isDes) { +//// if (transferStatus == TransferStatus.INIT) { +//// initDes(inbound); +//// } +//// } +//// else{ +// if(transferStatus==TransferStatus.INIT) { +// initNotDes(inbound); +// } +// // } +// +// if (reqSb != null) { +// reqSb.onNext(inbound); +// } +// } +// @Override +// public void onError(Throwable throwable) { +// reqSb.onError(throwable); +// } +// @Override +// public void onCompleted() { +// logger.info("==============onCompleted=============="); +// } +// }; +// } +// +// @Override +// protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { +// return null; +// } +//} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpUnaryCallService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpUnaryCallService.java new file mode 100644 index 0000000000..4248bdb24a --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/ptp/PtpUnaryCallService.java @@ -0,0 +1,41 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.broker.ptp; + +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.service.InboundPackage; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PtpUnaryCallService extends AbstractPtpServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(PtpUnaryCallService.class); + @Override + protected Osx.Outbound doService(FateContext context, InboundPackage data) { + + context.setActionType(ActionType.UNARY_CALL_NEW.getAlias()); + RouterInfo routerInfo = context.getRouterInfo(); + Osx.Inbound inbound = data.getBody(); + // logger.info("PtpUnaryCallService receive : {}",inbound); + Osx.Outbound outbound = TransferUtil.redirect(context,inbound,routerInfo,true); + return outbound; + } + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/Consumer.java similarity index 79% rename from java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/Consumer.java index ce5c7c9f91..007c263487 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/Consumer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/Consumer.java @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; -import com.osx.core.context.Context; +package org.fedai.osx.broker.queue; -public interface Consumer { + +import org.fedai.osx.api.context.Context; +import org.fedai.osx.core.frame.Lifecycle; + +public interface Consumer extends Lifecycle { public T consume(Context context, long offset); diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/CreateQueueResult.java similarity index 94% rename from java/osx/broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/CreateQueueResult.java index 9c50f615bd..5aa610057d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/CreateQueueResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/CreateQueueResult.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; public class CreateQueueResult { TransferQueue transferQueue; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFile.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFile.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/queue/MappedFile.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFile.java index 2ef37a6951..f27b48a545 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFile.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFile.java @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; -import com.osx.broker.message.*; -import com.osx.broker.util.LibC; -import com.osx.broker.util.UtilAll; import com.sun.jna.NativeLong; import com.sun.jna.Pointer; +import org.fedai.osx.broker.message.*; +import org.fedai.osx.broker.util.LibC; +import org.fedai.osx.broker.util.UtilAll; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import sun.nio.ch.DirectBuffer; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFileQueue.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFileQueue.java index 268cb3f794..7591096f80 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/MappedFileQueue.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/MappedFileQueue.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; -import com.osx.broker.message.AllocateMappedFileService; -import com.osx.broker.message.SelectMappedBufferResult; -import com.osx.broker.util.UtilAll; +import org.fedai.osx.broker.message.AllocateMappedFileService; +import org.fedai.osx.broker.message.SelectMappedBufferResult; +import org.fedai.osx.broker.util.UtilAll; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageLock.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageLock.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageLock.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageLock.java index 331070f8ec..6f6326af5b 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageLock.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageLock.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; public interface PutMessageLock { void lock(); diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageReentrantLock.java similarity index 96% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageReentrantLock.java index b9f9e826fc..5dbf39c71a 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageReentrantLock.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageReentrantLock.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; import java.util.concurrent.locks.ReentrantLock; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageResult.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageResult.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageResult.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageResult.java index 6de6219647..becb7321db 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageResult.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageResult.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; -import com.osx.broker.message.AppendMessageResult; +import org.fedai.osx.broker.message.AppendMessageResult; public class PutMessageResult { private PutMessageStatus putMessageStatus; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageStatus.java similarity index 96% rename from java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageStatus.java index ed637554cd..688af558e7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/PutMessageStatus.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/PutMessageStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; public enum PutMessageStatus { PUT_OK, diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/ReferenceResource.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/ReferenceResource.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/queue/ReferenceResource.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/ReferenceResource.java index 48fee6879a..cda19878c1 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/ReferenceResource.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/ReferenceResource.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; import java.util.concurrent.atomic.AtomicLong; diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueue.java similarity index 65% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueue.java index 3bc3605dc3..a6a4513641 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueue.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueue.java @@ -13,30 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.queue; -import com.osx.broker.ServiceContainer; -import com.osx.broker.callback.CompleteCallback; -import com.osx.broker.callback.DestoryCallback; -import com.osx.broker.callback.ErrorCallback; -import com.osx.broker.message.MessageDecoder; -import com.osx.broker.message.MessageExt; -import com.osx.broker.message.MessageExtBrokerInner; -import com.osx.broker.message.SelectMappedBufferResult; -import com.osx.broker.store.IndexQueue; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.StatusCode; -import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; -import com.osx.core.exceptions.TransferQueueInvalidStatusException; -import com.osx.core.queue.TranferQueueInfo; +package org.fedai.osx.broker.queue; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.callback.CompleteCallback; +import org.fedai.osx.broker.callback.DestoryCallback; +import org.fedai.osx.broker.callback.ErrorCallback; +import org.fedai.osx.broker.callback.MsgEventCallback; +import org.fedai.osx.broker.message.MessageDecoder; +import org.fedai.osx.broker.message.MessageExt; +import org.fedai.osx.broker.message.MessageExtBrokerInner; +import org.fedai.osx.broker.message.SelectMappedBufferResult; +import org.fedai.osx.broker.store.IndexQueue; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.exceptions.PutMessageException; +import org.fedai.osx.core.exceptions.TransferQueueInvalidStatusException; +import org.fedai.osx.core.queue.TranferQueueInfo; +import org.fedai.osx.core.service.OutboundPackage; +import org.ppc.ptp.Osx; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; public class TransferQueue { + + AtomicReferenceArray receivedMsgIds = new AtomicReferenceArray<>(MetaInfo.PROPERTY_TRANSFER_CACHED_MSGID_SIZE); + private Cache> receivedMsgCache; protected final AtomicInteger wrotePosition = new AtomicInteger(0); Logger logger = LoggerFactory.getLogger(TransferQueue.class); String transferId; @@ -46,7 +59,8 @@ public class TransferQueue { volatile TransferStatus transferStatus = TransferStatus.INIT; List errorCallbacks = new ArrayList<>(); List completeCallbacks = new ArrayList<>(); - List destoryCallbacks = new ArrayList(); + List destoryCallbacks = new ArrayList<>(); + List msgCallbacks = new ArrayList<>(); long createTimestamp; long lastStatusChangeTimestamp; long lastWriteTimestamp; @@ -54,6 +68,17 @@ public class TransferQueue { boolean writeOver = false; IndexQueue indexQueue; TransferQueueManager transferQueueManager; + + public boolean isHasEventMsgDestoryCallback() { + return hasEventMsgDestoryCallback; + } + + public void setHasEventMsgDestoryCallback(boolean hasEventMsgDestoryCallback) { + this.hasEventMsgDestoryCallback = hasEventMsgDestoryCallback; + } + + boolean hasEventMsgDestoryCallback = false; + public TransferQueue(String transferId, TransferQueueManager transferQueueManager, String path) { this.transferId = transferId; this.transferQueueManager = transferQueueManager; @@ -61,7 +86,7 @@ public TransferQueue(String transferId, 
TransferQueueManager transferQueueManage this.lastStatusChangeTimestamp = this.createTimestamp; this.lastWriteTimestamp = this.createTimestamp; this.indexQueue = new IndexQueue(transferId, path, MetaInfo.PROPERTY_INDEX_MAP_FILE_SIZE); - + initReceivedMsgCache(); } public String getSessionId() { @@ -96,20 +121,48 @@ public void setIndexQueue(IndexQueue indexQueue) { this.indexQueue = indexQueue; } - public synchronized PutMessageResult putMessage(final MessageExtBrokerInner msg) { + public synchronized boolean checkMsgIdDuplicate(String msgId) { + for (int i = 0; i < receivedMsgIds.length(); i++) { + String tempMsgId = receivedMsgIds.get(i); + if (msgId.equals(tempMsgId)) { + return true; + } + } + return false; + } + + public synchronized PutMessageResult putMessage(final MessageExtBrokerInner msg) { + if (transferStatus == TransferStatus.TRANSFERING) { + String msgId = msg.getMsgId(); this.lastWriteTimestamp = System.currentTimeMillis(); - PutMessageResult putMessageResult = ServiceContainer.messageStore.putMessage(msg); + PutMessageResult putMessageResult = transferQueueManager.messageStore.putMessage(msg); if (putMessageResult.isOk()) { + + int cacheIdx = wrotePosition.addAndGet(1) % MetaInfo.PROPERTY_TRANSFER_CACHED_MSGID_SIZE; + receivedMsgIds.set(cacheIdx, msgId); long beginWriteOffset = putMessageResult.getAppendMessageResult().getWroteOffset(); int size = putMessageResult.getAppendMessageResult().getWroteBytes(); - logger.info("store begin offset {},size {}", beginWriteOffset, size); putMessageResult.setMsgLogicOffset(indexQueue.putMessagePositionInfoWrapper(beginWriteOffset, size)); + //todo 这里需要修改,用另外的队列类型来做,就不再需要持久化 + if (this.msgCallbacks.size() > 0) { + try { + for (MsgEventCallback msgCallback : this.msgCallbacks) { + msgCallback.callback(this, msg); + } + }catch(Exception e){ + e.printStackTrace(); + logger.error("topic {} callback error",msg.getTopic(),e); + throw new PutMessageException("topic " + msg.getTopic() + " callback error"); + } + } } else { - throw new RuntimeException(); + logger.info("topic {} put msg error {}",transferId,putMessageResult.getPutMessageStatus()); + throw new PutMessageException("topic " + msg.getTopic() + " put message error"); } return putMessageResult; } else { + logger.error("topic {} is not ready",transferId); throw new TransferQueueInvalidStatusException("invalid queue status : " + transferStatus); } } @@ -120,13 +173,15 @@ public TransferQueueConsumeResult consumeOneMessage(Context context, long reques if (transferStatus == TransferStatus.TRANSFERING) { this.lastReadTimestamp = System.currentTimeMillis(); long logicIndex = indexQueue.getLogicOffset().get(); - context.setRequestMsgIndex(requestIndex); - context.setCurrentMsgIndex(logicIndex); + + context.putData(Dict.REQUEST_INDEX, requestIndex); + //context.setCurrentMsgIndex(logicIndex); + context.putData(Dict.CURRENT_INDEX, logicIndex); if (requestIndex <= logicIndex) { SelectMappedBufferResult indexBufferResult = this.indexQueue.getIndexBuffer(requestIndex); if (indexBufferResult != null) { long pyOffset = indexBufferResult.getByteBuffer().getLong(); - SelectMappedBufferResult msgBufferResult = ServiceContainer.messageStore.consumeOneMessage(pyOffset); + SelectMappedBufferResult msgBufferResult = this.transferQueueManager.getMessageStore().consumeOneMessage(pyOffset); transferQueueConsumeResult = new TransferQueueConsumeResult(StatusCode.SUCCESS, msgBufferResult, requestIndex, logicIndex); MessageExt message = 
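Duplicate detection in `TransferQueue` above keeps the most recent message codes in a fixed-size `AtomicReferenceArray` (sized by `PROPERTY_TRANSFER_CACHED_MSGID_SIZE`) that is overwritten ring-buffer style as messages are appended, and `checkMsgIdDuplicate` does a linear scan over it. A standalone sketch of that structure; the capacity and names are placeholders:

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

// Fixed-size ring of recently seen ids, overwritten as new messages arrive.
final class RecentIdRingSketch {
    private final AtomicReferenceArray<String> ids;
    private final AtomicInteger writePosition = new AtomicInteger(0);

    RecentIdRingSketch(int capacity) {
        this.ids = new AtomicReferenceArray<>(capacity);
    }

    /** Linear scan, cf. checkMsgIdDuplicate; the ring is small so this stays cheap. */
    synchronized boolean isDuplicate(String msgId) {
        if (msgId == null) {
            return false;
        }
        for (int i = 0; i < ids.length(); i++) {
            if (msgId.equals(ids.get(i))) {
                return true;
            }
        }
        return false;
    }

    /** Record an id after a successful append, cf. putMessage. */
    synchronized void record(String msgId) {
        int idx = writePosition.addAndGet(1) % ids.length();
        ids.set(idx, msgId);
    }
}
```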
MessageDecoder.decode(transferQueueConsumeResult.getSelectMappedBufferResult().getByteBuffer()); transferQueueConsumeResult.setMessage(message); @@ -145,12 +200,12 @@ public TransferQueueConsumeResult consumeOneMessage(Context context, long reques public synchronized void destory() { logger.info("try to destory transfer queue {} ", transferId); this.indexQueue.destroy(); - logger.info("destroy index file"); + logger.info("topic {} destroy index file", transferId); destoryCallbacks.forEach(destoryCallback -> { try { destoryCallback.callback(); } catch (Exception e) { - logger.error("destory call back error", e); + logger.error("topic {} destory call back execute error", transferId, e); } }); } @@ -165,14 +220,13 @@ public void setCreateTimestamp(long createTimestamp) { public synchronized void onCompeleted() { if (transferStatus == TransferStatus.TRANSFERING) { - transferStatus = TransferStatus.FINISH; } completeCallbacks.forEach(completeCallback -> { try { completeCallback.callback(); } catch (Exception e) { - + logger.error("complete call back error", e); } }); } @@ -191,7 +245,7 @@ public synchronized void onError(Throwable throwable) { }); } - public synchronized void registeErrorCallback(ErrorCallback errorCallback) { + public synchronized void registerErrorCallback(ErrorCallback errorCallback) { if (transferStatus == TransferStatus.TRANSFERING) { errorCallbacks.add(errorCallback); } else { @@ -199,20 +253,24 @@ public synchronized void registeErrorCallback(ErrorCallback errorCallback) { } } - public synchronized void registeDestoryCallback(DestoryCallback destoryCallback) { + public synchronized void registerDestoryCallback(DestoryCallback destoryCallback) { if (transferStatus == TransferStatus.TRANSFERING) destoryCallbacks.add(destoryCallback); else throw new TransferQueueInvalidStatusException("status is " + transferStatus); } + public synchronized void registerMsgCallback(List msgCallbacks) { + if (transferStatus == TransferStatus.TRANSFERING) { + this.msgCallbacks.addAll(msgCallbacks); + } else + throw new TransferQueueInvalidStatusException("status is " + transferStatus); + } + public TransferStatus getTransferStatus() { return transferStatus; } - // public void setTransferStatus(TransferStatus transferStatus) { -// this.transferStatus = transferStatus; -// } public AtomicInteger getWrotePosition() { return wrotePosition; } @@ -256,6 +314,27 @@ public void setLastWriteTimestamp(long lastWriteTimestamp) { this.lastWriteTimestamp = lastWriteTimestamp; } + public void cacheReceivedMsg(String msgId, OutboundPackage outboundPackage) { + + if(StringUtils.isNotEmpty(msgId)) + receivedMsgCache.put(msgId, outboundPackage); + } + + public OutboundPackage getReceivedMsgCache(String sessionId) { + + return receivedMsgCache.getIfPresent(sessionId); + } + + private void initReceivedMsgCache() { + if (receivedMsgCache == null) { + CacheBuilder cacheBuilder = CacheBuilder.newBuilder().maximumSize(MetaInfo.PRODUCE_MSG_CACHE_MAX_SIZE); + if (MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT != null && MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT > 0) { + cacheBuilder.expireAfterWrite(MetaInfo.PRODUCE_MSG_CACHE_TIMEOUT, TimeUnit.MILLISECONDS); + } + receivedMsgCache = cacheBuilder.build(); + } + } + public TranferQueueInfo getTransferQueueInfo() { TranferQueueInfo transferQueueInfo = new TranferQueueInfo(); transferQueueInfo.setTransferId(transferId); diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java 
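The response cache added to `TransferQueue` above (`initReceivedMsgCache`, `cacheReceivedMsg`, `getReceivedMsgCache`) is a Guava cache bounded by `PRODUCE_MSG_CACHE_MAX_SIZE` and, when configured, expiring entries after `PRODUCE_MSG_CACHE_TIMEOUT` milliseconds. Below is a minimal, self-contained use of the same Guava API; the size and timeout literals are placeholders, not the broker's defaults:

```java
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import java.util.concurrent.TimeUnit;

// Placeholder limits below; the broker reads them from MetaInfo at runtime.
final class ResponseCacheSketch {
    private final Cache<String, String> cache = CacheBuilder.newBuilder()
            .maximumSize(10_000)                               // bound the number of cached replies
            .expireAfterWrite(600_000, TimeUnit.MILLISECONDS)  // drop stale replies
            .build();

    void remember(String msgCode, String response) {
        cache.put(msgCode, response);
    }

    String recall(String msgCode) {
        return cache.getIfPresent(msgCode);                    // null when absent or expired
    }
}
```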
b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueApplyInfo.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueApplyInfo.java index 17edc0c351..8a40d9da3d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueApplyInfo.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueApplyInfo.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; -import com.osx.core.utils.JsonUtil; +package org.fedai.osx.broker.queue; +import org.fedai.osx.core.utils.JsonUtil; public class TransferQueueApplyInfo { diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java similarity index 71% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java index 0096e38300..013e88fceb 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueManager.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java @@ -13,43 +13,48 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; +package org.fedai.osx.broker.queue; + import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; -import com.google.protobuf.ByteString; -import com.osx.broker.ServiceContainer; -import com.osx.core.config.MasterInfo; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.DeployMode; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.constant.TransferStatus; -import com.osx.core.context.Context; -import com.osx.core.exceptions.CreateTopicErrorException; -import com.osx.core.exceptions.RemoteRpcException; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.frame.ServiceThread; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.JsonUtil; import io.grpc.ManagedChannel; import io.grpc.StatusRuntimeException; import org.apache.commons.lang3.StringUtils; import org.apache.zookeeper.KeeperException; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.callback.MsgEventCallback; +import org.fedai.osx.broker.consumer.EventDriverRule; +import org.fedai.osx.broker.message.AllocateMappedFileService; +import org.fedai.osx.broker.store.MessageStore; +import org.fedai.osx.core.config.MasterInfo; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.DeployMode; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.exceptions.CreateTopicErrorException; +import org.fedai.osx.core.exceptions.RemoteRpcException; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.frame.ServiceThread; +import org.fedai.osx.core.ptp.TargetMethod; +import 
org.fedai.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.NetUtils; import org.ppc.ptp.Osx; import org.ppc.ptp.PrivateTransferProtocolGrpc; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class TransferQueueManager { @@ -66,13 +71,35 @@ public class TransferQueueManager { volatile Set instanceIds = new HashSet<>(); ConcurrentHashMap transferQueueMap = new ConcurrentHashMap<>(); ConcurrentHashMap> sessionQueueMap = new ConcurrentHashMap<>(); - ConcurrentHashMap transferIdLockMap = new ConcurrentHashMap(); + ConcurrentHashMap transferIdLockMap = new ConcurrentHashMap<>(); + ConcurrentHashMap> msgCallBackRuleMap = new ConcurrentHashMap<>(); + + public MessageStore getMessageStore() { + return messageStore; + } + + public void setMessageStore(MessageStore messageStore) { + this.messageStore = messageStore; + } + + MessageStore messageStore; + AllocateMappedFileService allocateMappedFileService; volatile long transferApplyInfoVersion = -1; + + public MessageStore createMessageStore( + AllocateMappedFileService allocateMappedFileService) { + MessageStore messageStore = new MessageStore(allocateMappedFileService + , MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID + File.separator + "message-store"); + messageStore.start(); + return messageStore; + + } + private ServiceThread cleanTask = new ServiceThread() { @Override public void run() { while (true) { - this.waitForRunning(1000); + this.waitForRunning(MetaInfo.PROPERTY_TRANSFER_QUEUE_CHECK_INTERVAL); checkAndClean(); } } @@ -84,21 +111,25 @@ public String getServiceName() { }; + AllocateMappedFileService createAllocateMappedFileService() { + AllocateMappedFileService allocateMappedFileService = new AllocateMappedFileService(); + allocateMappedFileService.start(); + return allocateMappedFileService; + } + public TransferQueueManager() { + allocateMappedFileService = createAllocateMappedFileService(); + messageStore = createMessageStore(allocateMappedFileService); instanceIds.add(MetaInfo.INSTANCE_ID); if (MetaInfo.isCluster()) { boolean pathExists = ServiceContainer.zkClient.checkExists(ZK_QUEUE_PREFIX); if (!pathExists) { ServiceContainer.zkClient.create(ZK_QUEUE_PREFIX, false); } - List initApplyInfo = ServiceContainer.zkClient.addChildListener(ZK_QUEUE_PREFIX, (path, children) -> { - parseApplyInfo(children); - }); + List initApplyInfo = ServiceContainer.zkClient.addChildListener(ZK_QUEUE_PREFIX, (path, children) -> parseApplyInfo(children)); parseApplyInfo(initApplyInfo); ServiceContainer.zkClient.create(ZK_COMPONENTS_PREFIX + "/" + MetaInfo.INSTANCE_ID, true); - List initInstanceIds = ServiceContainer.zkClient.addChildListener(ZK_COMPONENTS_PREFIX, (path, children) -> { - handleClusterInstanceId(children); - }); + List initInstanceIds = ServiceContainer.zkClient.addChildListener(ZK_COMPONENTS_PREFIX, (path, children) -> handleClusterInstanceId(children)); ServiceContainer.zkClient.addDataListener(MASTER_PATH, (path, data, type) -> { logger.info("master event {} {}", type, data); if (data != null) { @@ -164,8 +195,6 @@ public String getServiceName() { /** * 平衡的策略暂时没有开发 * - * @param instanceId - * @return */ private String 
doClusterBalance(String transferId, String instanceId, @@ -180,12 +209,7 @@ private void doMasterWork() { transferQueueApplyInfoMap.forEach((k, v) -> { String instanceId = v.getInstanceId(); if (instanceIds.contains(instanceId)) { - Integer count = temp.get(instanceId); - if (count == null) { - temp.put(instanceId, 1); - } else { - temp.put(instanceId, count + 1); - } + temp.merge(instanceId, 1, Integer::sum); ; } }); @@ -195,52 +219,46 @@ private void doMasterWork() { if (transferQueueApplyInfoMap.get(k) == null) { masterQueueApplyInfoMap.remove(k); } - }; + } + ; }); } private MasterInfo parseMasterInfo(String masterContent) { - MasterInfo masterInfo = JsonUtil.json2Object(masterContent, MasterInfo.class); - return masterInfo; + return JsonUtil.json2Object(masterContent, MasterInfo.class); } private void handleClusterInstanceId(List children) { this.instanceIds.clear(); this.instanceIds.addAll(children); - if(logger.isInfoEnabled()) { + if (logger.isInfoEnabled()) { logger.info("instance change : {}", instanceIds); } } private synchronized void parseApplyInfo(List children) { - Set childSet = Sets.newHashSet(children); - Set intersecitonSet = Sets.intersection(transferQueueApplyInfoMap.keySet(), childSet); - Set needAddSet = null; - if (intersecitonSet != null) - needAddSet = Sets.difference(childSet, intersecitonSet); - Set needRemoveSet = Sets.difference(transferQueueApplyInfoMap.keySet(), intersecitonSet); - if(logger.isInfoEnabled()) { + Set childSet = Sets.newHashSet(children); + Set intersectionSet = Sets.intersection(transferQueueApplyInfoMap.keySet(), childSet); + Set needAddSet; + needAddSet = Sets.difference(childSet, intersectionSet); + Set needRemoveSet = Sets.difference(transferQueueApplyInfoMap.keySet(), intersectionSet); + if (logger.isInfoEnabled()) { logger.info("cluster apply info add {} remove {}", needAddSet, needRemoveSet); } - if (needRemoveSet != null) { - needRemoveSet.forEach(k -> { - transferQueueApplyInfoMap.remove(k); - }); - } - if (needAddSet != null) { - needAddSet.forEach(k -> { - try { - String content = ServiceContainer.zkClient.getContent(ZK_QUEUE_PREFIX + "/" + k); - TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); - if (transferQueueApplyInfo != null) { - transferQueueApplyInfoMap.put(k, transferQueueApplyInfo); - } - } catch (Exception e) { - logger.error("parse apply info from zk error",e); + needRemoveSet.forEach(k -> transferQueueApplyInfoMap.remove(k)); + needAddSet.forEach(k -> { + try { + String content = ServiceContainer.zkClient.getContent(buildZkPath(k)); + TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); + if (transferQueueApplyInfo != null) { + transferQueueApplyInfoMap.put(k, transferQueueApplyInfo); } - }); - } + } catch (Exception e) { + logger.error("parse apply info from zk error", e); + } + }); } + ; public List cleanByParam(String sessionId, String paramTransferId) { @@ -287,26 +305,25 @@ private void destroyInner(TransferQueue transferQueue) { private void checkAndClean() { long now = System.currentTimeMillis(); + logger.info("the total topic size is {}, total session size is {}", transferQueueMap.size(), sessionQueueMap.size()); transferQueueMap.forEach((transferId, transferQueue) -> { try { long lastReadTimestamp = transferQueue.getLastReadTimestamp(); long lastWriteTimestamp = transferQueue.getLastWriteTimestamp(); - long freeTime = now - (lastReadTimestamp > lastWriteTimestamp ? 
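`parseApplyInfo` above reconciles the local apply-info map against the current ZooKeeper children with set algebra: entries present only on ZooKeeper are fetched and added, entries present only in the local map are removed. The same reconciliation, reduced to its Guava `Sets` core; the map and the fetch function are stand-ins:

```java
import com.google.common.collect.Sets;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Stand-in reconciliation: keep the local map in step with the set of ZK children.
final class ApplyInfoReconcileSketch {
    final Map<String, String> local = new ConcurrentHashMap<>();

    void reconcile(Set<String> zkChildren, Function<String, String> fetchFromZk) {
        Set<String> common   = Sets.intersection(local.keySet(), zkChildren);
        Set<String> toAdd    = Sets.difference(zkChildren, common);      // new on ZK
        Set<String> toRemove = Sets.difference(local.keySet(), common);  // gone from ZK

        // Copy before mutating: Guava's set views are backed by the underlying sets.
        for (String key : Set.copyOf(toRemove)) {
            local.remove(key);
        }
        for (String key : Set.copyOf(toAdd)) {
            String content = fetchFromZk.apply(key);                     // cf. zkClient.getContent(path)
            if (content != null) {
                local.put(key, content);
            }
        }
    }
}
```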
lastReadTimestamp : lastWriteTimestamp); + long freeTime = now - Math.max(lastReadTimestamp, lastWriteTimestamp); if (transferQueue.getTransferStatus() == TransferStatus.ERROR || transferQueue.getTransferStatus() == TransferStatus.FINISH) { destroy(transferId); } - if (freeTime > MetaInfo.PRPPERTY_QUEUE_MAX_FREE_TIME) { - if(logger.isInfoEnabled()) { - logger.info("transfer queue : {} freetime {} need to be destroy", transferId, freeTime); + if (freeTime > MetaInfo.PROPERTY_QUEUE_MAX_FREE_TIME) { + if (logger.isInfoEnabled()) { + logger.info("topic : {} freetime {} need to be destroy", transferId, freeTime); } destroy(transferId); - return; } } catch (Exception igrone) { - + logger.error("transferQueue clean error ", igrone); } }); - } @@ -321,6 +338,7 @@ public List getTransferQueues(List transferIds) { } return result; } + ConcurrentHashMap clusterApplyLockMap = new ConcurrentHashMap(); public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, @@ -333,7 +351,7 @@ public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, } else { long current = System.currentTimeMillis(); TransferQueueApplyInfo newTransferQueueApplyInfo = new TransferQueueApplyInfo(); - String intanceId = doClusterBalance(transferId, instanceId, sessionId); + doClusterBalance(transferId, instanceId, sessionId); newTransferQueueApplyInfo.setTransferId(transferId); newTransferQueueApplyInfo.setInstanceId(instanceId); newTransferQueueApplyInfo.setSessionId(sessionId); @@ -344,18 +362,25 @@ public synchronized TransferQueueApplyInfo handleClusterApply(String transferId, } - public CreateQueueResult createNewQueue(String transferId, String sessionId, boolean forceCreateLocal) { - Preconditions.checkArgument(StringUtils.isNotEmpty(transferId)); - CreateQueueResult createQueueResult = new CreateQueueResult(); + public ReentrantLock getLock(String transferId){ ReentrantLock transferCreateLock = transferIdLockMap.get(transferId); if (transferCreateLock == null) { transferIdLockMap.putIfAbsent(transferId, new ReentrantLock(false)); } transferCreateLock = transferIdLockMap.get(transferId); - transferCreateLock.lock(); - try { + return transferCreateLock; + } + + - boolean exist = this.transferQueueMap.get(transferId) != null ? 
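The reworked `checkAndClean` above computes idle time as `now - Math.max(lastReadTimestamp, lastWriteTimestamp)` and destroys a queue once it has reached ERROR/FINISH or has been idle longer than `PROPERTY_QUEUE_MAX_FREE_TIME`. A reduced sketch of that sweep; the threshold, field names, and `destroy` body here are placeholders:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Periodic sweep over all queues; placeholder types and threshold.
final class IdleQueueSweeperSketch {
    enum Status { TRANSFERING, ERROR, FINISH }

    static final class QueueState {
        volatile long lastReadMillis;
        volatile long lastWriteMillis;
        volatile Status status = Status.TRANSFERING;
    }

    final Map<String, QueueState> queues = new ConcurrentHashMap<>();
    final long maxFreeMillis = 60_000L;   // stand-in for PROPERTY_QUEUE_MAX_FREE_TIME

    void sweep() {
        long now = System.currentTimeMillis();
        queues.forEach((topic, q) -> {
            long freeTime = now - Math.max(q.lastReadMillis, q.lastWriteMillis);
            boolean finished = q.status == Status.ERROR || q.status == Status.FINISH;
            if (finished || freeTime > maxFreeMillis) {
                destroy(topic);           // cf. TransferQueueManager.destroy(topic)
            }
        });
    }

    void destroy(String topic) {
        queues.remove(topic);             // the real code also releases files and ZK nodes
    }
}
```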
true : false; + + public CreateQueueResult createNewQueue(String transferId, String sessionId, boolean forceCreateLocal) { + Preconditions.checkArgument(StringUtils.isNotEmpty(transferId)); + CreateQueueResult createQueueResult = new CreateQueueResult(); + ReentrantLock transferCreateLock= getLock(transferId); + try { + transferCreateLock.lock(); + boolean exist = this.transferQueueMap.get(transferId) != null; if (exist) { createQueueResult.setTransferQueue(this.transferQueueMap.get(transferId)); String[] elements = MetaInfo.INSTANCE_ID.split(":"); @@ -364,7 +389,7 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo return createQueueResult; } if (MetaInfo.PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name()) && !forceCreateLocal) { - /** + /* * 缓存的集群信息中能够找到,直接返回信息 */ if (this.transferQueueApplyInfoMap.get(transferId) != null) { @@ -378,21 +403,20 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo createQueueResult.setRedirectIp(ip); return createQueueResult; } else { - /** + /* * 这种情况存在于本地已删除,而集群信息未同步更新,可能存在延迟,这时重走申请流程 */ } - }; - - Osx.Outbound applyTopicResponse = this.applyFromMaster(transferId,sessionId,MetaInfo.INSTANCE_ID); + } + Osx.Outbound applyTopicResponse = this.applyFromMaster(transferId, sessionId, MetaInfo.INSTANCE_ID); logger.info("apply topic response {}", applyTopicResponse); if (applyTopicResponse != null) { - /** + /* * 从clustermananger 返回的结果中比对instantceId ,如果为本实例,则在本地建Q */ - String applyInstanceId = applyTopicResponse.getMetadataMap().get(Osx.Metadata.InstanceId.name()); + String applyInstanceId = applyTopicResponse.getMetadataMap().get(Osx.Metadata.InstanceId.name()); if (MetaInfo.INSTANCE_ID.equals(applyInstanceId)) { @@ -403,43 +427,42 @@ public CreateQueueResult createNewQueue(String transferId, String sessionId, boo registerTransferQueue(transferId, sessionId); //createQueueResult = applyFromCluster(transferId,sessionId); } else { - if(applyInstanceId!=null) { + if (applyInstanceId != null) { String[] args = applyInstanceId.split(":"); String ip = args[0]; String portString = args[1]; int grpcPort = Integer.parseInt(portString); createQueueResult.setRedirectIp(ip); createQueueResult.setPort(grpcPort); - }else{ + } else { throw new CreateTopicErrorException("apply topic from master error"); } - }; + } } else { throw new RuntimeException(); } } else { - /** + /* * 单机版部署,直接本地建Q */ createQueueResult.setTransferQueue(localCreate(transferId, sessionId)); - String[] args = MetaInfo.INSTANCE_ID.split(":"); - String ip = args[0]; - String portString = args[1]; - createQueueResult.setPort(Integer.parseInt(portString)); - createQueueResult.setRedirectIp(ip); +// String[] args = MetaInfo.INSTANCE_ID.split("_"); +// String ip = args[0]; +// String portString = args[1]; + + createQueueResult.setPort(MetaInfo.PROPERTY_GRPC_PORT); + createQueueResult.setRedirectIp(NetUtils.getLocalHost()); } return createQueueResult; } finally { transferCreateLock.unlock(); + } } private void registerTransferQueue(String transferId, String sessionId) { - StringBuffer sb = new StringBuffer(); - sb.append(ZK_QUEUE_PREFIX).append("/"); - sb.append(transferId); - String path = sb.toString(); + String path = buildZkPath(transferId); TransferQueueApplyInfo transferQueueApplyInfo = new TransferQueueApplyInfo(); transferQueueApplyInfo.setTransferId(transferId); transferQueueApplyInfo.setSessionId(sessionId); @@ -448,25 +471,26 @@ private void registerTransferQueue(String transferId, String sessionId) { try { 
ServiceContainer.zkClient.create(path, JsonUtil.object2Json(transferQueueApplyInfo), true); } catch (KeeperException.NodeExistsException e) { - e.printStackTrace(); + logger.error("register path {} to zk error", path); } } + public String buildZkPath(String transferId) { + return ZK_QUEUE_PREFIX + "/" + transferId; + } + private CreateQueueResult applyFromCluster(String transferId, String sessionId) { CreateQueueResult createQueueResult = null; if (MetaInfo.PROPERTY_USE_ZOOKEEPER) { createQueueResult = new CreateQueueResult(); - StringBuffer sb = new StringBuffer(); - sb.append(ZK_QUEUE_PREFIX).append("/"); - sb.append(transferId); - String path = sb.toString(); + String path = buildZkPath(transferId); boolean exist = ServiceContainer.zkClient.checkExists(path); if (exist) { String content = ServiceContainer.zkClient.getContent(path); TransferQueueApplyInfo transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); } else { - /** + /* * 如何平均 */ TransferQueueApplyInfo transferQueueApplyInfo = new TransferQueueApplyInfo(); @@ -477,10 +501,11 @@ private CreateQueueResult applyFromCluster(String transferId, String sessionId) try { ServiceContainer.zkClient.create(path, JsonUtil.object2Json(transferQueueApplyInfo), true); } catch (KeeperException.NodeExistsException e) { - e.printStackTrace(); + logger.error("register path {} in zk error", path); } String content = ServiceContainer.zkClient.getContent(path); transferQueueApplyInfo = JsonUtil.json2Object(content, TransferQueueApplyInfo.class); + assert transferQueueApplyInfo != null; if (MetaInfo.INSTANCE_ID.equals(transferQueueApplyInfo.getInstanceId())) { createQueueResult.setTransferQueue(localCreate(transferId, sessionId)); } else { @@ -491,33 +516,35 @@ private CreateQueueResult applyFromCluster(String transferId, String sessionId) } } return createQueueResult; - } - public Osx.Outbound applyFromMaster( String topic,String sessionId,String instanceId) { - if (!isMaster()) { + public Osx.Outbound applyFromMaster(String topic, String sessionId, String instanceId) { - RouterInfo routerInfo = this.getMasterAddress(); + if (!isMaster()) { + RouterInfo routerInfo = this.getMasterAddress(); //context.setRouterInfo(routerInfo); - ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, true); PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); try { - Osx.Inbound.Builder builder = Osx.Inbound.newBuilder(); - builder.putMetadata(Osx.Metadata.MessageTopic.name(),topic); - builder.putMetadata(Osx.Metadata.InstanceId.name(),instanceId); - builder.putMetadata(Osx.Header.SessionID.name(),sessionId); + Osx.Inbound.Builder builder = Osx.Inbound.newBuilder(); + builder.putMetadata(Osx.Metadata.MessageTopic.name(), topic); + builder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId); + builder.putMetadata(Osx.Header.SessionID.name(), sessionId); + builder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + builder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.APPLY_TOPIC.name()); return stub.invoke(builder.build()); - }catch(StatusRuntimeException e){ - throw new RemoteRpcException("send to "+routerInfo.toKey()+" error"); + } catch (StatusRuntimeException e) { + logger.error("apply topic {} from master error", topic, e); + throw new RemoteRpcException("send to " + 
routerInfo.toKey() + " error"); } } else { TransferQueueApplyInfo transferQueueApplyInfo = this.handleClusterApply(topic, instanceId, sessionId); Osx.Outbound.Builder outboundBuilder = Osx.Outbound.newBuilder(); - outboundBuilder.getMetadataMap().put(Osx.Metadata.MessageTopic.name(), topic); - outboundBuilder.getMetadataMap().put(Osx.Metadata.InstanceId.name(), instanceId); - outboundBuilder.getMetadataMap().put(Osx.Metadata.Timestamp.name(), Long.toString(transferQueueApplyInfo.getApplyTimestamp())); + outboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), topic); + outboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId); + outboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), Long.toString(transferQueueApplyInfo.getApplyTimestamp())); outboundBuilder.setCode(StatusCode.SUCCESS); outboundBuilder.setMessage(Dict.SUCCESS); return outboundBuilder.build(); @@ -530,37 +557,55 @@ private RouterInfo getMasterAddress() { String[] args = MetaInfo.masterInfo.getInstanceId().split(Dict.COLON); routerInfo.setHost(args[0]); routerInfo.setPort(Integer.parseInt(args[1])); + routerInfo.setProtocol(Protocol.grpc); return routerInfo; } private void unRegisterCluster(String transferId) { - logger.info("unRegister transferId {}", transferId); - if (MetaInfo.isCluster()) { - ServiceContainer.zkClient.delete(ZK_QUEUE_PREFIX + "/" + transferId); + + if (MetaInfo.isCluster() && MetaInfo.isCluster()) { + logger.info("unRegister topic {} from zk", transferId); + ServiceContainer.zkClient.delete(buildZkPath(transferId)); } } + private void setMsgCallBack(TransferQueue transferQueue) { + this.msgCallBackRuleMap.forEach((rule, msgCallbacks) -> { - private TransferQueue localCreate(String transferId, String sessionId) { - logger.info("create local topic {}",transferId); - TransferQueue transferQueue = new TransferQueue(transferId, this, MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID); + if (rule.isMatch(transferQueue)) { + // logger.info("rule {} is mactched",rule); + transferQueue.registerMsgCallback(msgCallbacks); + } else { + // logger.info("rule {} is not matched",rule); + } + }); + } + + ; + + + private TransferQueue localCreate(String topic, String sessionId) { + logger.info("create local topic {}", topic); + TransferQueue transferQueue = new TransferQueue(topic, this, MetaInfo.PROPERTY_TRANSFER_FILE_PATH_PRE + File.separator + MetaInfo.INSTANCE_ID); transferQueue.setSessionId(sessionId); transferQueue.start(); - transferQueue.registeDestoryCallback(() -> { - this.transferQueueMap.remove(transferId); + transferQueue.registerDestoryCallback(() -> { + this.transferQueueMap.remove(topic); if (this.sessionQueueMap.get(sessionId) != null) { - this.sessionQueueMap.get(sessionId).remove(transferId); + this.sessionQueueMap.get(sessionId).remove(topic); } + unRegisterCluster(topic); }); - transferQueueMap.put(transferId, transferQueue); + setMsgCallBack(transferQueue); + transferQueueMap.put(topic, transferQueue); sessionQueueMap.putIfAbsent(sessionId, new HashSet<>()); - sessionQueueMap.get(sessionId).add(transferId); + sessionQueueMap.get(sessionId).add(topic); return transferQueue; } - public TransferQueue getQueue(String transferId) { - return transferQueueMap.get(transferId); + public TransferQueue getQueue(String topic) { + return transferQueueMap.get(topic); } public Map getAllLocalQueue() { @@ -568,17 +613,18 @@ public Map getAllLocalQueue() { } - private void destroy(String transferId) { - 
Preconditions.checkArgument(StringUtils.isNotEmpty(transferId)); - ReentrantLock transferIdLock = this.transferIdLockMap.get(transferId); + private void destroy(String topic) { + logger.info("start clear topic queue , topic = {}",topic); + Preconditions.checkArgument(StringUtils.isNotEmpty(topic)); + ReentrantLock transferIdLock = this.transferIdLockMap.get(topic); if (transferIdLock != null) { transferIdLock.lock(); } try { - TransferQueue transferQueue = getQueue(transferId); + TransferQueue transferQueue = getQueue(topic); if (transferQueue != null) { destroyInner(transferQueue); - transferIdLockMap.remove(transferId); + transferIdLockMap.remove(topic); } } finally { @@ -592,12 +638,10 @@ private void destroy(String transferId) { public void onError(String transferId, Throwable throwable) { TransferQueue transferQueue = transferQueueMap.get(transferId); if (transferQueue != null) { - /** + /* * The problem to handle here: when an error occurs the consumer has not yet attached, and it only attaches after being triggered later */ - errorCallBackExecutor.execute(() -> { - transferQueue.onError(throwable); - }); + errorCallBackExecutor.execute(() -> transferQueue.onError(throwable)); } this.destroy(transferId); } @@ -617,21 +661,23 @@ public TransferQueueApplyInfo queryGlobleQueue(String transferId) { } public void destroyAll() { - logger.info("prepare to destory {}", transferQueueMap); if (MetaInfo.isCluster()) { try { if (this.isMaster()) { ServiceContainer.zkClient.delete(MASTER_PATH); } ServiceContainer.zkClient.close(); - ; } catch (Exception e) { e.printStackTrace(); } - logger.info("unregister component over"); } this.transferQueueMap.forEach((transferId, transferQueue) -> { transferQueue.destory(); }); } + + + public void addMsgCallBackRule(EventDriverRule rule, List callbacks) { + this.msgCallBackRuleMap.put(rule, callbacks); + } } diff --git a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueMonitorService.java similarity index 92% rename from java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueMonitorService.java index bf1cbfbdca..7c707377bf 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/queue/TransferQueueMonitorService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueMonitorService.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.queue; -import com.osx.core.frame.ServiceThread; +package org.fedai.osx.broker.queue; +import org.fedai.osx.core.frame.ServiceThread; + public class TransferQueueMonitorService extends ServiceThread { TransferQueueManager transferQueueManager; public TransferQueueMonitorService(TransferQueueManager transferQueueManager) { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java new file mode 100644 index 0000000000..8523d5b12a --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java @@ -0,0 +1,448 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.router; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.protobuf.InvalidProtocolBufferException; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import com.webank.eggroll.core.transfer.Transfer; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.util.TelnetUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.datasource.FileRefreshableDataSource; +import org.fedai.osx.core.exceptions.CycleRouteInfoException; +import org.fedai.osx.core.exceptions.ErrorMessageUtil; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.exceptions.InvalidRouteInfoException; +import org.fedai.osx.core.flow.PropertyListener; +import org.fedai.osx.core.frame.Lifecycle; +import org.fedai.osx.core.frame.ServiceThread; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.utils.FileUtils; +import org.fedai.osx.core.utils.JsonUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class DefaultFateRouterServiceImpl implements FateRouterService, Lifecycle { + + private static final String IP = "ip"; + private static final String PORT = "port"; + private static final String URL = "url"; + private static final String USE_SSL = "useSSL"; + private static final String HOSTNAME = "hostname"; + private static final String negotiationType = "negotiationType"; + private static final String certChainFile = "certChainFile"; + private static final String privateKeyFile = "privateKeyFile"; + private static final String caFile = "caFile"; + private static final String DEFAULT = "default"; + private static final String VERSION = "version"; + + //Pattern urlIpPort = Pattern.compile("(\\d+\\.\\d+\\.\\d+\\.\\d+)\\:(\\d+)"); + + Pattern urlIpPortPattern = Pattern.compile("((http|ftp|https)://)((([a-zA-Z0-9._-]+)|([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}))(([a-zA-Z]{2,6})|(:[0-9]{1,4})?))"); + + Logger logger = LoggerFactory.getLogger(DefaultFateRouterServiceImpl.class); + Map> routerInfoMap = new ConcurrentHashMap>(); + Map>> endPointMap = new ConcurrentHashMap<>(); + FileRefreshableDataSource fileRefreshableDataSource; + + @Override + public RouterInfo route(Proxy.Packet packet) { + Preconditions.checkArgument(packet != null); + RouterInfo routerInfo = null; + Proxy.Metadata metadata = packet.getHeader(); + Transfer.RollSiteHeader rollSiteHeader = null; + String dstPartyId = null; + try { + rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt()); + if (rollSiteHeader != null) { + dstPartyId = 
rollSiteHeader.getDstPartyId(); + } + } catch (InvalidProtocolBufferException e) { + e.printStackTrace(); + } + if (StringUtils.isEmpty(dstPartyId)) { + dstPartyId = metadata.getDst().getPartyId(); + } + String desRole = metadata.getDst().getRole(); + String srcRole = metadata.getSrc().getRole(); + String srcPartyId = metadata.getSrc().getPartyId(); + routerInfo = this.route(srcPartyId, srcRole, dstPartyId, desRole); + //logger.info("query router info {} to {} {} return {}", srcPartyId, dstPartyId, desRole, routerInfo); + return routerInfo; + } + + private RouterInfo buildRouterInfo(Map endpoint, String srcPartyId, String srcRole, String dstPartyId, String desRole) { + + Preconditions.checkArgument(endpoint != null); + RouterInfo routerInfo = new RouterInfo(); + if (endpoint.get(IP) != null) { + routerInfo.setHost(endpoint.get(IP).toString()); + } + if (endpoint.get(PORT) != null) { + routerInfo.setPort(((Number) endpoint.get(PORT)).intValue()); + } + routerInfo.setDesPartyId(dstPartyId); + routerInfo.setSourcePartyId(srcPartyId); + routerInfo.setVersion(endpoint.get(VERSION) != null ? endpoint.get(VERSION).toString() : null); + routerInfo.setNegotiationType(endpoint.get(negotiationType) != null ? endpoint.get(negotiationType).toString() : ""); + routerInfo.setDesRole(desRole); + Protocol protocol = Protocol.grpc; + if (endpoint.get(Dict.PROTOCOL) != null) { + try { + protocol = Protocol.valueOf(endpoint.get(Dict.PROTOCOL).toString()); + } catch (Exception ignore) { + + } + } + routerInfo.setProtocol(protocol); + routerInfo.setUrl(endpoint.get(Dict.URL) != null ? endpoint.get(Dict.URL).toString() : ""); + routerInfo.setUseSSL(endpoint.get(Dict.USE_SSL) != null && Boolean.parseBoolean(endpoint.get(Dict.USE_SSL).toString())); + routerInfo.setCaFile(endpoint.get(Dict.CA_FILE) != null ? endpoint.get(Dict.CA_FILE).toString() : ""); + routerInfo.setCertChainFile(endpoint.get(Dict.CERT_CHAIN_FILE) != null ? endpoint.get(Dict.CERT_CHAIN_FILE).toString() : ""); + routerInfo.setPrivateKeyFile(endpoint.get(Dict.PRIVATE_KEY_FILE) != null ? 
endpoint.get(Dict.PRIVATE_KEY_FILE).toString() : ""); + if (routerInfo.getProtocol().equals(Protocol.http)) { + if (StringUtils.isEmpty(routerInfo.getUrl())) { + throw new InvalidRouteInfoException(); + } + } + if (endpoint.get(Dict.IS_CYCLE) != null && (Boolean) endpoint.get(Dict.IS_CYCLE)) { + logger.error("router info {} has a cycle invoke", routerInfo.toKey()); + throw new CycleRouteInfoException("router info has a cycle invoke"); + } + return routerInfo; + } + + public RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole) { + // logger.info("try to find routerInfo =={}=={}=={}=={}",srcPartyId,srcRole,dstPartyId,desRole); + RouterInfo routerInfo = null; + Map> partyIdMap = this.endPointMap.containsKey(dstPartyId)?this.endPointMap.get(dstPartyId):this.endPointMap.get(DEFAULT); + if (partyIdMap != null) { + if (StringUtils.isNotEmpty(desRole) && partyIdMap.get(desRole) != null) { + List ips = partyIdMap.getOrDefault(desRole, null); + if (ips != null && ips.size() > 0) { + Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); + routerInfo = buildRouterInfo(endpoint, srcPartyId, srcRole, dstPartyId, desRole); + } + } else { + + List ips = partyIdMap.getOrDefault(DEFAULT, null); + if (ips != null && ips.size() > 0) { + Map endpoint = ips.get((int) (System.currentTimeMillis() % ips.size())); + routerInfo = buildRouterInfo(endpoint, srcPartyId, srcRole, dstPartyId, desRole); + } + if (StringUtils.isNotEmpty(desRole)) { + // logger.warn("role {} is not found,return default router info ",desRole); + } + } + } + + return routerInfo; + } + + + Map>> initRouteTable(Map confJson) { + // BasicMeta.Endpoint.Builder endpointBuilder = BasicMeta.Endpoint.newBuilder(); + Map>> newRouteTable = new ConcurrentHashMap<>(); + // loop through coordinator + + confJson.forEach((k, v) -> { + String coordinatorKey = k.toString(); + Map coordinatorValue = (Map) v; + + Map> serviceTable = newRouteTable.get(coordinatorKey); + if (serviceTable == null) { + serviceTable = new ConcurrentHashMap<>(4); + newRouteTable.put(coordinatorKey, serviceTable); + } + // loop through role in coordinator + for (Object roleEntryObject : coordinatorValue.entrySet()) { + Map.Entry roleEntry = (Map.Entry) roleEntryObject; + String roleKey = roleEntry.getKey().toString(); + if (roleKey.equals("createTime") || roleKey.equals("updateTime")) { + continue; + } + List roleValue = (List) roleEntry.getValue(); + + List endpoints = serviceTable.get(roleKey); + if (endpoints == null) { + endpoints = new ArrayList<>(); + serviceTable.put(roleKey, endpoints); + } + // loop through endpoints + for (Object endpointElement : roleValue) { + Map element = Maps.newHashMap(); + Map endpointJson = (Map) endpointElement; + element.putAll(endpointJson); + endpoints.add(element); + } + } + + }); + + return newRouteTable; + } + + @Override + public void init() { + + } + + public void start() { + String currentPath = getRouterTablePath(); + logger.info("load router file {}", currentPath); + File confFile = new File(currentPath); + FileRefreshableDataSource fileRefreshableDataSource = null; + try { + fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { + // logger.info("read route_table {}", source); + return source; + }); + fileRefreshableDataSource.getProperty().addListener(new RouterTableListener()); + + } catch (FileNotFoundException e) { + logger.error("router file {} is not found", currentPath); + } + /** + * Check whether the route table contains cycles and whether the endpoints are reachable + */ + ServiceThread routerInfoChecker = new 
ServiceThread() { + + @Override + public void run() { + while (true) { + //Map> partyIdMap = this.endPointMap.get(dstPartyId); + endPointMap.forEach((desPartyId, desPoint) -> { + desPoint.forEach((role, routerElementMap) -> { + routerElementMap.forEach(endPoint -> { + + String ip = null; + int port = 0; + Protocol protocol = Protocol.grpc; + try { + if (endPoint.get(Dict.PROTOCOL) != null) { + try { + protocol = Protocol.valueOf(endPoint.get(Dict.PROTOCOL).toString()); + } catch (Exception e) { + logger.warn("route info {}->{} protocol is invalid , please check route_table.json", desPartyId, role); + } + } + ; + if (endPoint.get(Dict.URL) != null) { + String ipPortString = getIpInfoFromUrl(endPoint.get(Dict.URL).toString()); + if (StringUtils.isNotEmpty(ipPortString)) { + ip = ipPortString.split(Dict.COLON)[0]; + String portString = ipPortString.split(Dict.COLON)[1]; + port = Integer.parseInt(portString); + } + } + if (protocol.equals(Protocol.grpc)) { + if (endPoint.get(IP) != null) { + ip = endPoint.get(IP).toString(); + } + if (endPoint.get(PORT) != null) { + port = ((Number) endPoint.get(PORT)).intValue(); + } + } + //if (!MetaInfo.PROPERTY_SELF_PARTY.contains(desPartyId)) { + + boolean isCycle = checkCycle(ip, port); + if (isCycle) { + logger.warn("route info {}->{}->{}->{} is a cycle , please check route_table.json", desPartyId, role, ip, port); + } + endPoint.put(Dict.IS_CYCLE, isCycle); + //} + checkConnected(desPartyId, role, ip, port); + + } catch (Exception ignore) { + ignore.printStackTrace(); + } + } + ); + }); + } + ); + + this.waitForRunning(60000); + } + } + + @Override + public String getServiceName() { + return "cycle_checker"; + } + }; + routerInfoChecker.start(); + } + + private String getRouterTablePath() { + return MetaInfo.PROPERTY_CONFIG_DIR + "/broker/route_table.json"; + } + + @Override + public void destroy() { + + } + + private void checkConnected(String partyId, String role, String ip, int port) { + + if (MetaInfo.PROPERTY_USE_REMOTE_HEALTH_CHECK) { + if (StringUtils.isNotEmpty(ip)) { + + boolean result = TelnetUtil.tryTelnet(ip, port); + if (!result) { + // logger.warn("route info {}->{}->{}->{} unable to connect , please check route_table.json", partyId, role, ip, port); + + } + } + } + } + + private boolean checkCycle(String ip, int port) { + + boolean cycle = false; + + if(MetaInfo.PROPERTY_OPEN_ROUTE_CYCLE_CHECKER) { + String localIp = MetaInfo.INSTANCE_ID.split(":")[0]; + + if (localIp.equals(ip) || Dict.LOCALHOST.equals(ip) || Dict.LOCALHOST2.equals(ip)) { + if (MetaInfo.PROPERTY_GRPC_PORT == (port)) { + cycle = true; + } + if (MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER) { + if (MetaInfo.PROPERTY_GRPC_TLS_PORT == port) { + cycle = true; + } + } + if (MetaInfo.PROPERTY_OPEN_HTTP_SERVER) { + if (MetaInfo.PROPERTY_HTTP_PORT == (port)) { + cycle = true; + } + } + } + } + + return cycle; + } + + + private class RouterTableListener implements PropertyListener { + + @Override + public void configUpdate(String value) { + logger.info("found router_table.json has been changed, update content {}",value); + Map confJson = JsonUtil.json2Object(value, Map.class); + // JsonObject confJson = JsonParser.parseString(value).getAsJsonObject(); + Map content = (Map) confJson.get("route_table"); + endPointMap = initRouteTable(content); + } + + @Override + public void configLoad(String value) { + Map confJson = JsonUtil.json2Object(value, Map.class); + if(confJson!=null){ + + // throw new ConfigErrorException("content of route_table.json is invalid"); + + Map content = (Map) 
confJson.get("route_table"); + endPointMap = initRouteTable(content); + logger.info("load router config {}", JsonUtil.formatJson(JsonUtil.object2Json(endPointMap))); + + }else{ + logger.error("content of route_table.json is invalid , content is {}",value); + + } + } + } + + + public String getIpInfoFromUrl(String url) { + Matcher m = urlIpPortPattern.matcher(url); + String result = ""; + if (m.find()) { + result = m.group(3); + } + return result; + } + + public boolean saveRouterTable(Context context, InboundPackage data) { + try { + String inboundRouteJson = (String) context.getData("route"); + if (StringUtils.isNotBlank(inboundRouteJson)) { + Map routeMap = JsonUtil.object2Objcet(inboundRouteJson, new TypeReference>() { + }); + Map route_table = (Map) routeMap.get("route_table"); + route_table.forEach((partyId, value) -> { + List routeList = (List) value; + for (RouterInfo routerInfo : routeList) { + routerInfo.setProtocol(StringUtils.isBlank(routerInfo.getProtocol().toString()) ? Protocol.grpc : routerInfo.getProtocol()); + } + }); + inboundRouteJson = JsonUtil.object2Json(routeMap); + } + String routerTablePath = getRouterTablePath(); + File routerTableFile = new File(routerTablePath); + if (!routerTableFile.exists()) { + if (!routerTableFile.getParentFile().exists()) { + if (!routerTableFile.getParentFile().mkdirs()) { + logger.warn("mkdir failed : {}", routerTableFile.getParent()); + return false; + } + } + if (!routerTableFile.createNewFile()) { + logger.warn("create router_table.json failed : {}", routerTableFile.getAbsoluteFile()); + return false; + } + } + return FileUtils.writeStr2ReplaceFileSync(JsonUtil.formatJson(inboundRouteJson), routerTablePath); + } catch (Exception e) { + logger.error("save router table failed ", e); + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + context.setReturnCode(exceptionInfo.getCode()); + context.setReturnMsg("save router table failed"); + return false; + } + } + + public static void main(String[] args) { +// System.out.println(MetaInfo.PROPERTY_USER_DIR); +// System.out.println(MetaInfo.PROPERTY_USER_HOME); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("").getPath()); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("route_table.json")); +// System.out.println(Thread.currentThread().getContextClassLoader().getResource("flowRule.json")); + DefaultFateRouterServiceImpl defaultFateRouterService = new DefaultFateRouterServiceImpl(); + defaultFateRouterService.getIpInfoFromUrl("http://127.0.0.1:9000/xxxx"); + + + } + + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/FateRouterService.java similarity index 78% rename from java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/FateRouterService.java index 630c954f44..b192dbeb39 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/router/FateRouterService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/FateRouterService.java @@ -13,16 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.router; +package org.fedai.osx.broker.router; -import com.osx.core.router.RouterInfo; import com.webank.ai.eggroll.api.networking.proxy.Proxy; -public interface FateRouterService { - - RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole); +import org.fedai.osx.api.router.RouterInfo; +public interface FateRouterService extends RouterService{ RouterInfo route(Proxy.Packet packet); - } diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RemoteRouterDataSource.java similarity index 86% rename from java/osx/broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RemoteRouterDataSource.java index a5a3e348ab..d5248e7c9d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/router/RemoteRouterDataSource.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RemoteRouterDataSource.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.router; -import com.osx.core.datasource.AutoRefreshDataSource; -import com.osx.core.datasource.Converter; +package org.fedai.osx.broker.router; +import org.fedai.osx.core.datasource.AutoRefreshDataSource; +import org.fedai.osx.core.datasource.Converter; public class RemoteRouterDataSource extends AutoRefreshDataSource { public RemoteRouterDataSource(Converter configParser) { diff --git a/java/osx/broker/src/main/java/com/osx/broker/router/RouterMetric.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterMetric.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/router/RouterMetric.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterMetric.java index 80e559245d..d7985808ea 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/router/RouterMetric.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterMetric.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.router; -import com.osx.core.utils.JsonUtil; +package org.fedai.osx.broker.router; +import org.fedai.osx.core.utils.JsonUtil; import java.util.concurrent.atomic.AtomicLong; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterRegister.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterRegister.java new file mode 100644 index 0000000000..ec041a4a11 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterRegister.java @@ -0,0 +1,70 @@ +package org.fedai.osx.broker.router; + +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.frame.Lifecycle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +public class RouterRegister implements Lifecycle { + + Logger logger = LoggerFactory.getLogger(RouterRegister.class); + + private final String ROUTER_CONFIG_FILE = "components/router.properties"; + + private ConcurrentMap routerServiceMap = new ConcurrentHashMap<>(); + + public RouterService getRouterService(String key){ + return routerServiceMap.get(key); + } + + @Override + public void init() { + String configDir= MetaInfo.PROPERTY_CONFIG_DIR; + String fileName = configDir+ Dict.SLASH+ROUTER_CONFIG_FILE; + File file = new File(fileName); + Properties config = new Properties(); + try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) { + config.load(inputStream); + }catch (Exception e){ + logger.error("can not found {}",fileName); + } + config.forEach((k,v)->{ + if(v!=null){ + try { + Class genClass = Class.forName(v.toString()); + Object rawObject = genClass.getConstructor().newInstance(); + routerServiceMap.put(k.toString(),(RouterService)rawObject); + if(rawObject instanceof Lifecycle){ + ( (Lifecycle)rawObject).init(); + } + } catch (Exception e) { + logger.error("register router error {} : {}",k,v,e); + } + } + }); + } + + @Override + public void start() { + routerServiceMap.forEach((k,v)->{ + if(v instanceof Lifecycle){ + ( (Lifecycle)v).start(); + } + }); + logger.info("router register : {}",routerServiceMap); + } + + @Override + public void destroy() { + + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterService.java new file mode 100644 index 0000000000..3fe6c3674b --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterService.java @@ -0,0 +1,8 @@ +package org.fedai.osx.broker.router; + + +import org.fedai.osx.api.router.RouterInfo; + +public interface RouterService { + RouterInfo route(String srcPartyId, String srcRole, String dstPartyId, String desRole); +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/MockTokenGenerator.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/MockTokenGenerator.java new file mode 100644 index 0000000000..2c3acbc799 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/MockTokenGenerator.java @@ -0,0 +1,13 @@ +package org.fedai.osx.broker.security; + + +import org.fedai.osx.api.context.Context; + +public class MockTokenGenerator implements TokenGenerator{ + + + @Override + public String 
createNewToken(Context context) { + return "mock"; + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGenerator.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGenerator.java new file mode 100644 index 0000000000..81323d088d --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGenerator.java @@ -0,0 +1,10 @@ +package org.fedai.osx.broker.security; + + +import org.fedai.osx.api.context.Context; + +public interface TokenGenerator { + + String createNewToken(Context context); + +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGeneratorRegister.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGeneratorRegister.java new file mode 100644 index 0000000000..d2d55b6079 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenGeneratorRegister.java @@ -0,0 +1,75 @@ +package org.fedai.osx.broker.security; + +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.frame.Lifecycle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + + +public class TokenGeneratorRegister implements Lifecycle { + + Logger logger = LoggerFactory.getLogger(TokenGeneratorRegister.class); + + final String DEFAULT_KEY = "default"; + + private Map tokenGeneratorMap = new ConcurrentHashMap<>(); + + + @Override + public void init() { + if(MetaInfo.PROPERTY_OPEN_TOKEN_GENERATOR){ + String configFilePath= MetaInfo.PROPERTY_TOKEN_GENERATOR_CONFIG_PATH; + File file = new File(configFilePath); + Properties config = new Properties(); + try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) { + config.load(inputStream); + }catch (Exception e){ + + } + config.forEach((k,v)->{ + if(v!=null){ + try { + Class genClass = Class.forName(v.toString()); + Object rawObject = genClass.getConstructor().newInstance(); + if(!(rawObject instanceof TokenGenerator)){ + logger.error("create token generator err , {} ",v); + return ; + } + tokenGeneratorMap.put(k.toString(),(TokenGenerator)rawObject); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + }); + } + + + } + + @Override + public void start() { + logger.info("register token generator : {}",this.tokenGeneratorMap); + } + + @Override + public void destroy() { + + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidator.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidator.java new file mode 100644 index 0000000000..7ab864e2d8 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidator.java @@ -0,0 +1,8 @@ +package org.fedai.osx.broker.security; + + +import org.fedai.osx.api.context.Context; + +public interface TokenValidator { + public void validate(Context context, String token); +} diff --git 
a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidatorRegister.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidatorRegister.java new file mode 100644 index 0000000000..c6294d0912 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/security/TokenValidatorRegister.java @@ -0,0 +1,75 @@ +package org.fedai.osx.broker.security; + +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.frame.Lifecycle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + +public class TokenValidatorRegister implements Lifecycle { + + Logger logger = LoggerFactory.getLogger(TokenValidatorRegister.class); + + final String DEFAULT_KEY = "default"; + final String TOKEY_VALIDATOR_CONFIG_FILE="token_validator.properties"; + + private Map tokenValidatorMap = new ConcurrentHashMap<>(); + + public TokenValidator getTokenValidator(String key,String defaultKey){ + TokenValidator result = tokenValidatorMap.get(key); + if(result ==null){ + result = tokenValidatorMap.get(defaultKey); + }; + return result; + } + @Override + public void init() { + if(MetaInfo.PROPERTY_OPEN_TOKEN_GENERATOR){ + String configDir= MetaInfo.PROPERTY_CONFIG_DIR; + String fileName = configDir+ Dict.SLASH+TOKEY_VALIDATOR_CONFIG_FILE; + File file = new File(fileName); + Properties config = new Properties(); + try (InputStream inputStream = new BufferedInputStream(new FileInputStream(file))) { + config.load(inputStream); + }catch (Exception e){ + logger.error("parse file {} error",fileName); + } + + config.forEach((k,v)->{ + if(v!=null){ + try { + Class genClass = Class.forName(v.toString()); + Object rawObject = genClass.getConstructor().newInstance(); + if(!(rawObject instanceof TokenValidator)){ + logger.error("parse token validator err , {} ",v); + return ; + } + tokenValidatorMap.put(k.toString(),(TokenValidator)rawObject); + } catch (Exception e) { + logger.error("register token validator error {} : {}",k,v); + } + } + }); + } + } + + @Override + public void start() { + logger.info("register token validator : {}",this.tokenValidatorMap); + } + + @Override + public void destroy() { + this.tokenValidatorMap.clear(); + } + + +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java new file mode 100644 index 0000000000..e8e24f6379 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java @@ -0,0 +1,320 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.broker.server; + +import io.grpc.ServerInterceptors; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; +import org.apache.commons.lang3.StringUtils; +import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.SslConnectionFactory; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.fedai.osx.broker.grpc.ContextPrepareInterceptor; +import org.fedai.osx.broker.grpc.PcpGrpcService; +import org.fedai.osx.broker.grpc.ProxyGrpcService; +import org.fedai.osx.broker.grpc.ServiceExceptionHandler; +import org.fedai.osx.broker.http.DispatchServlet; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.utils.OSXCertUtils; +import org.fedai.osx.core.utils.OsxX509TrustManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.fedai.osx.core.config.MetaInfo.PROPERTY_OPEN_GRPC_TLS_SERVER; + +/** + * http1.X + grpc + */ +public class OsxServer { + + Logger logger = LoggerFactory.getLogger(OsxServer.class); + io.grpc.Server server; + io.grpc.Server tlsServer; + org.eclipse.jetty.server.Server httpServer; + org.eclipse.jetty.server.Server httpsServer; + ProxyGrpcService proxyGrpcService; + PcpGrpcService pcpGrpcService; + + private synchronized void init() { + try { + proxyGrpcService = new ProxyGrpcService(); + pcpGrpcService = new PcpGrpcService(); + server = buildServer(); + if (MetaInfo.PROPERTY_OPEN_HTTP_SERVER) { + logger.info("prepare to create http server"); + httpServer = buildHttpServer(); + if (httpServer == null) { + System.exit(0); + } + if (MetaInfo.PROPERTY_HTTP_USE_TLS) { + logger.info("prepare to create http server with TLS"); + httpsServer = buildHttpsServer(); + if (httpsServer == null) { + System.exit(0); + } + } + } + tlsServer = buildTlsServer(); + }catch(Exception e){ + logger.error("server init error ",e); + e.printStackTrace(); + } + } + + public Server buildHttpServer() { + Server server = new Server(); + try { + HttpConnectionFactory http11 = new HttpConnectionFactory(); + ServerConnector connector; + connector = new ServerConnector(server, MetaInfo.PROPERTY_HTTP_SERVER_ACCEPTOR_NUM, MetaInfo.PROPERTY_HTTP_SERVER_SELECTOR_NUM, http11); + // logger.info("http server try to start listen port {}", MetaInfo.PROPERTY_HTTP_PORT); + connector.setPort(MetaInfo.PROPERTY_HTTP_PORT); + connector.setHost(MetaInfo.PROPERTY_BIND_HOST); + connector.setAcceptQueueSize(MetaInfo.PROPERTY_HTTP_RECEIVE_QUEUE_SIZE); + connector.setAcceptedReceiveBufferSize(MetaInfo.PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE); + server.addConnector(connector); + server.setHandler(buildServlet()); + return server; + } catch (Exception e) { + logger.error("build http server error", e); + } + 
return null; + } + + public Server buildHttpsServer() { + Server server = new Server(); + try { + HttpConnectionFactory http11 = new HttpConnectionFactory(); + ServerConnector connector; + SslContextFactory.Server sslServer = new SslContextFactory.Server(); +// //如果PROPERTY_HTTP_SSL_TRUST_STORE_PATH 为空, 则去读取证书套件,然后生成一个TRUST_STORE + if (StringUtils.isNotBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH)) { + sslServer.setTrustStoreType(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE.toUpperCase()); + sslServer.setKeyStorePath(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH); + sslServer.setTrustStore(OSXCertUtils.getTrustStore(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE)); + if (StringUtils.isAllBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD)) { + throw new IllegalArgumentException("http.ssl.key.store.password/http.ssl.trust.store.password is not set,please check config file"); + } + sslServer.setTrustStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD)); + sslServer.setKeyStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD)); + sslServer.setTrustStoreProvider(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER); + } else { + SSLContext sslContext = SSLContext.getInstance("SSL"); + KeyStore keyStore = OSXCertUtils.getKeyStore(MetaInfo.PROPERTY_SERVER_CA_FILE, MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE, MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE); + TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; + // Load client certificate + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); + sslServer.setSslContext(sslContext); + } + sslServer.setNeedClientAuth(true); + sslServer.setSslSessionTimeout(MetaInfo.PROPERTY_HTTP_SSL_SESSION_TIME_OUT); + SslConnectionFactory tls = new SslConnectionFactory(sslServer, http11.getProtocol()); + connector = new ServerConnector(server, MetaInfo.PROPERTY_HTTP_SERVER_ACCEPTOR_NUM, MetaInfo.PROPERTY_HTTP_SERVER_SELECTOR_NUM, tls, http11); + // logger.info("http server try to start listen port {}", MetaInfo.PROPERTY_HTTP_PORT); + connector.setPort(MetaInfo.PROPERTY_HTTPS_PORT); + connector.setHost(MetaInfo.PROPERTY_BIND_HOST); + connector.setAcceptQueueSize(MetaInfo.PROPERTY_HTTP_RECEIVE_QUEUE_SIZE); + connector.setAcceptedReceiveBufferSize(MetaInfo.PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE); + server.addConnector(connector); + server.setHandler(buildServlet()); +// new Thread(()->{ +// while (true){ +// try { +// logger.info("========================= http连接数 = {}",server.getConnectors().length); +// Thread.sleep(5000); +// } catch (InterruptedException e) { +// e.printStackTrace(); +// } +// } +// }).start(); + return server; + } catch (Exception e) { + logger.error("build https server error = {}", e.getMessage()); + e.printStackTrace(); + } + return null; + } + + ServletContextHandler buildServlet() { + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath(MetaInfo.PROPERTY_HTTP_CONTEXT_PATH); + context.addServlet(DispatchServlet.class, MetaInfo.PROPERTY_HTTP_SERVLET_PATH); + context.setMaxFormContentSize(Integer.MAX_VALUE); + return context; + } + + public boolean start() { + init(); + //grpc + try { + 
server.start(); + logger.info("listen grpc port {} success", MetaInfo.PROPERTY_GRPC_PORT); + } catch (Exception e) { + if (e instanceof IOException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_PORT); + } + e.printStackTrace(); + return false; + } + + //http + try { + if (httpServer != null) { + httpServer.start(); + logger.info("listen http port {} success", MetaInfo.PROPERTY_HTTP_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_HTTP_PORT); + } + e.printStackTrace(); + return false; + } + + //tls + try { + if (tlsServer != null) { + logger.info("grpc tls server try to start, listen port {}", MetaInfo.PROPERTY_GRPC_TLS_PORT); + tlsServer.start(); + logger.info("listen grpc tls port {} success", MetaInfo.PROPERTY_GRPC_TLS_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_GRPC_TLS_PORT); + } + e.printStackTrace(); + return false; + } + + //https + try { + if (httpsServer != null) { + httpsServer.start(); + logger.info("listen https port {} success", MetaInfo.PROPERTY_HTTPS_PORT); + } + } catch (Exception e) { + if (e instanceof java.net.BindException || e.getCause() instanceof java.net.BindException) { + logger.error("port {} already in use, please try to choose another one !!!!", MetaInfo.PROPERTY_HTTPS_PORT); + } + e.printStackTrace(); + return false; + } + return true; + } + + private io.grpc.Server buildTlsServer() { + String certChainFilePath = MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE; + String privateKeyFilePath = MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE; + String trustCertCollectionFilePath = MetaInfo.PROPERTY_SERVER_CA_FILE; + if (PROPERTY_OPEN_GRPC_TLS_SERVER && StringUtils.isNotBlank(certChainFilePath) + && StringUtils.isNotBlank(privateKeyFilePath) && StringUtils.isNotBlank(trustCertCollectionFilePath)) { + try { + SocketAddress address = new InetSocketAddress(MetaInfo.PROPERTY_BIND_HOST, MetaInfo.PROPERTY_GRPC_TLS_PORT); + NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address); + SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(new File(certChainFilePath), new File(privateKeyFilePath)) + .trustManager(new File(trustCertCollectionFilePath)) + .clientAuth(ClientAuth.REQUIRE) + .sessionTimeout(MetaInfo.PROPERTY_GRPC_SSL_SESSION_TIME_OUT) + .sessionCacheSize(MetaInfo.PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE); + logger.info("running in secure mode. 
server crt path: {}, server key path: {}, ca crt path: {}.", + certChainFilePath, privateKeyFilePath, trustCertCollectionFilePath); + //serverBuilder.executor(executor); + nettyServerBuilder.sslContext(GrpcSslContexts.configure(sslContextBuilder, SslProvider.OPENSSL).build()); + nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder.addService(ServerInterceptors.intercept(pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + + + nettyServerBuilder + .executor(Executors.newCachedThreadPool()) + .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) + .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) + .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) + .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) + nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) + nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { + nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + } + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) + nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) + nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) + nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) + nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); + + return nettyServerBuilder.build(); + } catch (SSLException e) { + throw new SecurityException(e); + } + } + return null; + } + + + private io.grpc.Server buildServer() { + SocketAddress address = new InetSocketAddress(MetaInfo.PROPERTY_BIND_HOST, MetaInfo.PROPERTY_GRPC_PORT); + NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address); + nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder.addService(ServerInterceptors.intercept(pcpGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); + nettyServerBuilder + .executor(Executors.newCachedThreadPool()) + .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) + .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) + .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) + .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) + nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) + 
nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) { + nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); + } + if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) + nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) + nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) + nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); + if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) + nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); + return nettyServerBuilder.build(); + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/PushService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/PushService.java similarity index 56% rename from java/osx/broker/src/main/java/com/osx/broker/service/PushService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/PushService.java index c3a3fb645c..1edff14499 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/service/PushService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/PushService.java @@ -13,38 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.service; +package org.fedai.osx.broker.service; -import com.osx.broker.grpc.PushRequestDataWrap; -import com.osx.broker.grpc.QueuePushReqStreamObserver; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.exceptions.SysException; -import com.osx.core.service.AbstractServiceAdaptor; -import com.osx.core.service.InboundPackage; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.stub.StreamObserver; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.broker.grpc.QueuePushReqStreamObserver; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.exceptions.SysException; +import org.fedai.osx.core.service.AbstractServiceAdaptor; +import org.fedai.osx.core.service.InboundPackage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PushService extends AbstractServiceAdaptor { +public class PushService extends AbstractServiceAdaptor { Logger logger = LoggerFactory.getLogger(PushService.class); + + @Override - protected StreamObserver doService(Context context, InboundPackage data + protected StreamObserver doService(FateContext context, InboundPackage data ) { - PushRequestDataWrap pushRequestDataWrap = data.getBody(); - StreamObserver backRespSO = pushRequestDataWrap.getStreamObserver(); - context.setNeedPrintFlowLog(false); + StreamObserver backRespSO = data.getBody(); + // context.setNeedPrintFlowLog(false); QueuePushReqStreamObserver queuePushReqStreamObserver = new QueuePushReqStreamObserver(context, + ServiceContainer.routerRegister.getRouterService(MetaInfo.PROPERTY_FATE_TECH_PROVIDER), backRespSO, Proxy.Metadata.class); return 
queuePushReqStreamObserver; } @Override - protected StreamObserver transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + protected StreamObserver transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { logger.error("PushService error {}", exceptionInfo); throw new SysException(exceptionInfo.toString()); } diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/RegisterService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RegisterService.java similarity index 87% rename from java/osx/broker/src/main/java/com/osx/broker/service/RegisterService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RegisterService.java index 9eae69bed9..8e6a2bd54d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/service/RegisterService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RegisterService.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.service; +package org.fedai.osx.broker.service; -import com.osx.broker.zk.CuratorZookeeperClient; -import com.osx.broker.zk.ZkConfig; +import org.fedai.osx.broker.zk.CuratorZookeeperClient; +import org.fedai.osx.broker.zk.ZkConfig; public class RegisterService { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RouteService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RouteService.java new file mode 100644 index 0000000000..433fb300b0 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/RouteService.java @@ -0,0 +1,28 @@ +package org.fedai.osx.broker.service; + +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.router.DefaultFateRouterServiceImpl; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.service.AbstractServiceAdaptor; +import org.fedai.osx.core.service.InboundPackage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RouteService extends AbstractServiceAdaptor { + + Logger logger = LoggerFactory.getLogger(RouteService.class); + + @Override + protected Proxy.Packet doService(Context context, InboundPackage data) { + DefaultFateRouterServiceImpl defaultFateRouterService = new DefaultFateRouterServiceImpl(); + defaultFateRouterService.saveRouterTable(context, data); + Proxy.Packet.Builder resultBuilder = Proxy.Packet.newBuilder(); + return resultBuilder.build(); + } + + @Override + protected Proxy.Packet transformExceptionInfo(Context context, ExceptionInfo exceptionInfo) { + return null; + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/TokenApplyService.java similarity index 87% rename from java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/TokenApplyService.java index be693393df..00745537d3 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/service/TokenApplyService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/TokenApplyService.java @@ -13,25 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.service; +package org.fedai.osx.broker.service; import com.google.protobuf.ByteString; -import com.osx.broker.ServiceContainer; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StreamLimitMode; -import com.osx.core.context.Context; -import com.osx.core.flow.FlowRule; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.frame.Lifecycle; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; -import com.osx.core.token.TokenRequest; -import com.osx.core.token.TokenResult; -import com.osx.core.token.TokenResultStatus; -import com.osx.core.utils.JsonUtil; import io.grpc.ManagedChannel; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.StreamLimitMode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.flow.FlowRule; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.frame.Lifecycle; +import org.fedai.osx.core.ptp.TargetMethod; +import org.fedai.osx.core.token.TokenRequest; +import org.fedai.osx.core.token.TokenResult; +import org.fedai.osx.core.token.TokenResultStatus; +import org.fedai.osx.core.utils.JsonUtil; import org.ppc.ptp.Osx; import org.ppc.ptp.PrivateTransferProtocolGrpc; import org.slf4j.Logger; @@ -57,7 +57,7 @@ public TokenApplyService() { public PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub buildBlockingStub(String address) { String[] ipports= address.split(":"); - RouterInfo routerInfo = new RouterInfo(); + RouterInfo routerInfo = new RouterInfo(); routerInfo.setHost(ipports[0]); routerInfo.setPort(Integer.parseInt(ipports[1])); ManagedChannel channel = GrpcConnectionFactory.createManagedChannel(routerInfo,true); @@ -66,11 +66,11 @@ public PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub buildBloc } - public void applyToken(Context context, String resource, int count) { + public void applyToken(FateContext context, String resource, int count) { if (MetaInfo.PROPERTY_STREAM_LIMIT_MODE.equals(StreamLimitMode.LOCAL.name()) || MetaInfo.PROPERTY_STREAM_LIMIT_MODE.equals(StreamLimitMode.CLUSTER.name())) { - TokenResult localTokenResult = tryLocalLimit(resource, count); + TokenResult localTokenResult = tryLocalLimit(context,resource, count); logger.info("request token {} count {} result {}", resource, count, localTokenResult); /** * cluster rate limiting @@ -83,17 +83,18 @@ public void applyToken(Context context, String resource, int count) { tryClusterLimit(resource, count); } } - ServiceContainer.flowCounterManager.pass(resource, count); + // ServiceContainer.flowCounterManager.pass(resource, count); } } - private TokenResult tryLocalLimit(String resource, int count) { + private TokenResult tryLocalLimit(FateContext context,String resource, int count) { boolean needLoop = false; int tryTime = 0; TokenResult tokenResult; + int totalSleepMs = 0; do { tokenResult = ServiceContainer.defaultTokenService.requestToken(resource, count, true); @@ -111,8 +112,8 @@ private TokenResult tryLocalLimit(String resource, int count) { logger.info("should wait {} ms", sleepMs); try { Thread.sleep(sleepMs); - } catch (InterruptedException e) { - e.printStackTrace(); + } catch (InterruptedException ignore) { + } needLoop = false; break; @@ -120,9 +121,11 @@ private TokenResult tryLocalLimit(String resource, int count) { try { sleepMs = 
tokenResult.getWaitInMs(); if (sleepMs > 0) { + totalSleepMs+=sleepMs; Thread.sleep(sleepMs); + } - logger.info("should block {} ms try time {}", sleepMs, tryTime); + // logger.info("should block {} ms try time {}", sleepMs, tryTime); } catch (InterruptedException e) { logger.error(""); } @@ -135,7 +138,7 @@ private TokenResult tryLocalLimit(String resource, int count) { } } } while (needLoop && tryTime < MetaInfo.PROPERTY_STREAM_LIMIT_MAX_TRY_TIME); - + context.setSleepTime(totalSleepMs); return tokenResult; } @@ -143,7 +146,7 @@ private TokenResult tryLocalLimit(String resource, int count) { private void tryClusterLimit(String resource, int count) { - TokenRequest tokenRequest = new TokenRequest(); + TokenRequest tokenRequest = new TokenRequest(); tokenRequest.setResource(resource); tokenRequest.setAcquireCount(count); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/UnaryCallService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/UnaryCallService.java new file mode 100644 index 0000000000..095c42830f --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/service/UnaryCallService.java @@ -0,0 +1,106 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.broker.service; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import io.grpc.ManagedChannel; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.ActionType; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.exceptions.NoRouterInfoException; +import org.fedai.osx.core.exceptions.RemoteRpcException; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.ptp.SourceMethod; +import org.fedai.osx.core.ptp.TargetMethod; +import org.fedai.osx.core.service.AbstractServiceAdaptor; +import org.fedai.osx.core.service.InboundPackage; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Compatibility service for requests from older FATE versions. + */ +public class UnaryCallService extends AbstractServiceAdaptor<Proxy.Packet, Proxy.Packet> { + + Logger logger = LoggerFactory.getLogger(UnaryCallService.class); + + public UnaryCallService() { + + } + + @Override + protected Proxy.Packet doService(FateContext context, InboundPackage<Proxy.Packet> data) { + context.setActionType(ActionType.UNARY_CALL.getAlias()); + Proxy.Packet req = (Proxy.Packet) data.getBody(); + Proxy.Packet resp = unaryCall(context, req); + //logger.info("unary req {} resp {}", req, resp); + return resp; + } + + + protected Proxy.Packet transformExceptionInfo(FateContext context, ExceptionInfo exceptionInfo) { + + throw new RemoteRpcException(exceptionInfo.toString()); + + + } + + /** + * Non-streaming (unary) transfer. + * + * @param context request context + * @param req packet to forward + */ + public Proxy.Packet unaryCall(FateContext context, Proxy.Packet req) { + Proxy.Packet result = null; + RouterInfo routerInfo=context.getRouterInfo(); + if(routerInfo==null){ + String sourcePartyId = context.getSrcPartyId(); + String desPartyId = context.getDesPartyId(); + throw new NoRouterInfoException(sourcePartyId+" to "+desPartyId +" found no router info"); + } + if(routerInfo.getProtocol().equals(Protocol.http)){ + Osx.Inbound inbound = TransferUtil.buildInboundFromPushingPacket(req, MetaInfo.PROPERTY_FATE_TECH_PROVIDER, TargetMethod.UNARY_CALL.name(), SourceMethod.OLDUNARY_CALL.name()).build(); + Osx.Outbound outbound = TransferUtil.redirect(context,inbound,routerInfo,true); + if(outbound!=null) { + if (outbound.getCode().equals(StatusCode.SUCCESS)) { + try { + result = Proxy.Packet.parseFrom(outbound.getPayload().toByteArray()); + } catch (InvalidProtocolBufferException e) { + e.printStackTrace(); + } + } else { + throw new RemoteRpcException(outbound.getMessage()); + } + } + }else { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo(), true); + DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel); + result = stub.unaryCall(req); + } + return result; + } + + +}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/IndexQueue.java similarity index 93% rename from java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/IndexQueue.java index 23bb9e4ec7..51325b5117 100644 ---
a/java/osx/broker/src/main/java/com/osx/broker/store/IndexQueue.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/IndexQueue.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.store; +package org.fedai.osx.broker.store; -import com.osx.broker.message.SelectMappedBufferResult; -import com.osx.broker.queue.MappedFile; -import com.osx.broker.queue.MappedFileQueue; +import org.fedai.osx.broker.message.SelectMappedBufferResult; +import org.fedai.osx.broker.queue.MappedFile; +import org.fedai.osx.broker.queue.MappedFileQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +39,7 @@ public class IndexQueue { private long maxPhysicOffset = -1; private volatile long minLogicOffset = 0; private AtomicLong logicOffset = new AtomicLong(0); - + Logger logger = LoggerFactory.getLogger(IndexQueue.class); public IndexQueue( final String transferId, final String storePath, @@ -68,9 +68,7 @@ public boolean load() { public long getLastOffset() { long lastOffset = -1; - int logicFileSize = this.mappedFileSize; - MappedFile mappedFile = this.mappedFileQueue.getLastMappedFile(); if (mappedFile != null) { @@ -143,18 +141,18 @@ public long getMinOffsetInQueue() { public long putMessagePositionInfoWrapper(long offset, int msgSize) { final int maxRetries = 30; - + long resultLogicOffset = -1; for (int i = 0; i < maxRetries; i++) { boolean result = this.putMessagePositionInfo(offset, msgSize, this.logicOffset.get() + 1); if (result) { - return logicOffset.addAndGet(1); - + resultLogicOffset = logicOffset.addAndGet(1); + return resultLogicOffset; } } - return -1; + return resultLogicOffset; } @@ -181,8 +179,8 @@ private boolean putMessagePositionInfo(final long offset, final int size, this.mappedFileQueue.setFlushedWhere(expectLogicOffset); this.mappedFileQueue.setCommittedWhere(expectLogicOffset); this.fillPreBlank(mappedFile, expectLogicOffset); - log.info("fill pre blank space " + mappedFile.getFileName() + " " + expectLogicOffset + " " - + mappedFile.getWrotePosition()); +// log.info("fill pre blank space " + mappedFile.getFileName() + " " + expectLogicOffset + " " +// + mappedFile.getWrotePosition()); } if (cqOffset != 0) { diff --git a/java/osx/broker/src/main/java/com/osx/broker/store/MessageStore.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/MessageStore.java similarity index 93% rename from java/osx/broker/src/main/java/com/osx/broker/store/MessageStore.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/MessageStore.java index 20b567a77b..5c5830a2ea 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/store/MessageStore.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/store/MessageStore.java @@ -13,23 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.store; - -import com.osx.broker.message.*; -import com.osx.broker.queue.*; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.TransferStatus; -import com.osx.core.exceptions.MappedFileException; -import com.osx.core.exceptions.TransferQueueInvalidStatusException; -import com.osx.core.frame.ServiceThread; +package org.fedai.osx.broker.store; + +import org.fedai.osx.broker.message.*; +import org.fedai.osx.broker.queue.*; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.TransferStatus; +import org.fedai.osx.core.exceptions.MappedFileException; +import org.fedai.osx.core.exceptions.TransferQueueInvalidStatusException; +import org.fedai.osx.core.frame.ServiceThread; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; -import static com.osx.core.config.MetaInfo.MAP_FILE_SIZE; - public class MessageStore { protected final AtomicInteger wrotePosition = new AtomicInteger(0); @@ -43,14 +41,14 @@ public class MessageStore { MappedFileQueue mappedFileQueue; ReentrantLock putMessageLock = new ReentrantLock(); long beginTimeInLock; - AppendMessageHandler appendMessageCallback = new DefaultAppendMessageHandler(MAP_FILE_SIZE); + AppendMessageHandler appendMessageCallback = new DefaultAppendMessageHandler(MetaInfo.MAP_FILE_SIZE); AllocateMappedFileService allocateMappedFileService; CleanMappedFileThread cleanMappedFileThread = new CleanMappedFileThread(); public MessageStore(AllocateMappedFileService allocateMappedFileService, String path) { allocateMappedFileService = this.allocateMappedFileService; - mappedFileQueue = new MappedFileQueue(path, MAP_FILE_SIZE, allocateMappedFileService); + mappedFileQueue = new MappedFileQueue(path, MetaInfo.MAP_FILE_SIZE, allocateMappedFileService); this.createTimestamp = System.currentTimeMillis(); this.lastStatusChangeTimestamp = this.createTimestamp; this.lastWriteTimestamp = this.createTimestamp; @@ -153,7 +151,7 @@ public SelectMappedBufferResult getMessage(final long offset, final int size) { if (this.mappedFileQueue != null) { MappedFile mappedFile = this.mappedFileQueue.findMappedFileByOffset(offset, offset == 0); if (mappedFile != null) { - int pos = (int) (offset % MAP_FILE_SIZE); + int pos = (int) (offset % MetaInfo.MAP_FILE_SIZE); return mappedFile.selectMappedBuffer(pos, size); } return null; diff --git a/java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/token/DefaultTokenService.java similarity index 83% rename from java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/token/DefaultTokenService.java index 1b54594650..e494fe5c05 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/token/DefaultTokenService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/token/DefaultTokenService.java @@ -13,17 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
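MessageStore.getMessage above resolves a global store offset to a position inside one fixed-size mapped file: findMappedFileByOffset picks the file and offset % MetaInfo.MAP_FILE_SIZE gives the byte position within it. A small worked example of that arithmetic; the 128 MiB file size below is an assumption for illustration, not the configured value:

```java
public final class MappedOffsetDemo {
    // Assumed mapped-file size for illustration; the broker reads the real value from MetaInfo.MAP_FILE_SIZE.
    private static final long MAP_FILE_SIZE = 128L * 1024 * 1024;

    public static void main(String[] args) {
        long offset = 300_000_000L;                // global offset handed to getMessage
        long fileIndex = offset / MAP_FILE_SIZE;   // which mapped file in the queue holds it
        int pos = (int) (offset % MAP_FILE_SIZE);  // byte position inside that file
        System.out.println("file #" + fileIndex + ", position " + pos);
        // prints: file #2, position 31564544
    }
}
```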
*/ -package com.osx.broker.token; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.flow.*; -import com.osx.core.service.AbstractServiceAdaptor; -import com.osx.core.service.InboundPackage; -import com.osx.core.token.TokenResult; -import com.osx.core.token.TokenResultStatus; +package org.fedai.osx.broker.token; + import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.flow.*; +import org.fedai.osx.core.token.TokenResult; +import org.fedai.osx.core.token.TokenResultStatus; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +43,10 @@ public TokenResult requestToken(String resource, int acquireCount, boolean prior } FlowRule rule = ClusterFlowRuleManager.getFlowRuleByResource(resource); if (rule == null) { - logger.error("resource {} no rule", resource); + //logger.error("resource {} no rule", resource); ClusterMetric clusterMetric = ClusterMetricStatistics.getMetric(resource); if (clusterMetric == null) { - ClusterMetricStatistics.putMetricIfAbsent(resource, new ClusterMetric(MetaInfo.PROPERTY_SAMPLE_COUNT, MetaInfo.PROPERTY_INTERVAL_MS)); + ClusterMetricStatistics.putMetricIfAbsent(resource, new ClusterMetric(MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_COUNT, MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL)); clusterMetric = ClusterMetricStatistics.getMetric(resource); } clusterMetric.add(ClusterFlowEvent.PASS, acquireCount); diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ContextUtil.java similarity index 70% rename from java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ContextUtil.java index fff76b41b5..0b90a14a80 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/ContextUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ContextUtil.java @@ -13,19 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; -import com.osx.broker.grpc.ContextPrepareInterceptor; -import com.osx.core.context.Context; - -import java.util.UUID; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.broker.grpc.ContextPrepareInterceptor; +import org.fedai.osx.core.context.FateContext; public class ContextUtil { - public static Context buildContext() { - Context context = new Context(); + public static FateContext buildFateContext(Protocol protocol) { + FateContext context = new FateContext(); + context.setProtocol(protocol); context.setSourceIp(ContextPrepareInterceptor.sourceIp.get() != null ? 
ContextPrepareInterceptor.sourceIp.get().toString() : ""); - context.setCaseId(UUID.randomUUID().toString()); + return context; } + + }
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/DateUtils.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DateUtils.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/util/DateUtils.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DateUtils.java index 9a60af674a..9976fbb771 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/DateUtils.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DateUtils.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import org.apache.commons.lang3.time.DateFormatUtils;
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DebugUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DebugUtil.java new file mode 100644 index 0000000000..0340a0c4d0 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/DebugUtil.java @@ -0,0 +1,56 @@ +package org.fedai.osx.broker.util; + +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.utils.JsonUtil; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.http.HttpServletRequest; +import java.io.BufferedReader; +import java.io.IOException; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; + +public class DebugUtil { + + static Logger logger = LoggerFactory.getLogger(DebugUtil.class); + + public static void printGrpcParams(Osx.Inbound request) { + try { + if (MetaInfo.PROTOCOL_PARAMS_PRINT) { + logger.info("【{}】====> {}", Protocol.grpc.name(), JsonUtil.object2Json(request.getMetadataMap())); + } + }catch (Exception e){ + logger.error("DebugUtil.printGrpcParams error : ",e); + } + } + + public static void printHttpParams(HttpServletRequest request) { + try { + if (MetaInfo.PROTOCOL_PARAMS_PRINT) { + StringBuilder info = new StringBuilder("【" + Protocol.http.name() + "】====> " + "(uri) = " + request.getRequestURI() + "\n(head) = "); + Enumeration<String> headerNames = request.getHeaderNames(); + Map<String, String> headMap = new HashMap<>(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + headMap.put(headerName, request.getHeader(headerName)); + } + info.append("\n").append(JsonUtil.object2Json(headMap)).append("\n(body) = "); + try (BufferedReader reader = request.getReader()) { + String line; + while ((line = reader.readLine()) != null) { + info.append(line); + } + } catch (IOException e) { + e.printStackTrace(); + } + logger.info(info.toString()); + } + } catch (Exception e) { + logger.error("DebugUtil.printHttpParams error : ",e); + } + } +}
diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/LibC.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/LibC.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/util/LibC.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/LibC.java index 065536e10c..05554877cc 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/LibC.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/LibC.java @@ -13,7 +13,7 @@ * See the License for the specific language governing
permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import com.sun.jna.*; diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/MessageConst.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageConst.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/util/MessageConst.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageConst.java index 7c398f3ba8..d4251451d3 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/MessageConst.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageConst.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import java.util.HashSet; diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/MessageId.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageId.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/util/MessageId.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageId.java index fbc3214931..16c7bde929 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/MessageId.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/MessageId.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import java.net.SocketAddress; diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ResourceUtil.java similarity index 90% rename from java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ResourceUtil.java index a55489f0ac..8c1054485d 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/ResourceUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/ResourceUtil.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.broker.util; +package org.fedai.osx.broker.util; -import com.osx.broker.constants.Direction; -import com.osx.core.router.RouterInfo; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.constants.Direction; public class ResourceUtil { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TelnetUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TelnetUtil.java new file mode 100644 index 0000000000..0718209eef --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TelnetUtil.java @@ -0,0 +1,21 @@ +package org.fedai.osx.broker.util; + +import org.apache.commons.net.telnet.TelnetClient; + +public class TelnetUtil { + + public static boolean tryTelnet(String host ,int port){ + TelnetClient telnetClient = new TelnetClient("vt200"); + telnetClient.setDefaultTimeout(5000); + boolean isConnected = false; + try { + telnetClient.connect(host, port); + isConnected = true; + telnetClient.disconnect(); + } catch (Exception e) { + //e.printStackTrace(); + } + return isConnected; + } + +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TimeUtils.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TimeUtils.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/util/TimeUtils.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TimeUtils.java index 8d36cbc689..c6d340eed7 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/TimeUtils.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TimeUtils.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import org.apache.commons.lang3.StringUtils; diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferExceptionUtil.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferExceptionUtil.java index bac1857f7c..3f0e53d3f9 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/TransferExceptionUtil.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferExceptionUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; public class TransferExceptionUtil { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferUtil.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferUtil.java new file mode 100644 index 0000000000..4767a882ee --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/TransferUtil.java @@ -0,0 +1,564 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
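TelnetUtil above probes whether a TCP endpoint answers within five seconds by opening and immediately closing a Telnet connection. A usage sketch; the host, port, and wrapper class below are placeholders, not values from the patch:

```java
import org.fedai.osx.broker.util.TelnetUtil;

public final class RouteProbeDemo {
    public static void main(String[] args) {
        // Placeholder endpoint; tryTelnet returns true only if the TCP connection succeeds.
        boolean reachable = TelnetUtil.tryTelnet("127.0.0.1", 9370);
        System.out.println("peer reachable: " + reachable);
    }
}
```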
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.fedai.osx.broker.util; + + +import com.google.common.collect.Maps; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; +import com.webank.ai.eggroll.api.networking.proxy.Proxy; +import com.webank.eggroll.core.transfer.Transfer; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.constants.MessageFlag; +import org.fedai.osx.broker.http.HttpClientPool; +import org.fedai.osx.broker.http.HttpsClientPool; +import org.fedai.osx.broker.queue.TransferQueue; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.config.TransferMeta; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.PtpHttpHeader; +import org.fedai.osx.core.constant.Role; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.*; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.ptp.SourceMethod; +import org.fedai.osx.core.utils.AssertUtil; +import org.fedai.osx.core.utils.JsonUtil; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.management.MBeanServer; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.OutputStream; +import java.lang.management.ManagementFactory; +import java.util.Map; + +public class TransferUtil { + + + static Logger logger = LoggerFactory.getLogger(TransferUtil.class); + + + /** + * 2.0之前版本 + * + * @param version + * @return + */ + public static boolean isOldVersionFate(String version) { + + try { + if (StringUtils.isEmpty(version)) + version = MetaInfo.PROPERTY_DEFAULT_CLIENT_VERSION; + String firstVersion = version.substring(0, 1); + if (Integer.parseInt(firstVersion) >= 2) { + return false; + } else { + return true; + } + } catch (NumberFormatException e) { + throw new ConfigErrorException("remote version config error : " + version); + } + + } + + + public static String buildResource(Osx.Inbound inbound) { + String sourceNodeId = inbound.getMetadataMap().get(Osx.Header.SourceNodeID.name()); + String targetNodeId = inbound.getMetadataMap().get(Osx.Header.TargetNodeID.name()); + String sourceInstId = inbound.getMetadataMap().get(Osx.Header.SourceInstID.name()); + if (sourceInstId == null) { + sourceInstId = ""; + } + String targetInstId = inbound.getMetadataMap().get(Osx.Header.TargetInstID.name()); + if (targetInstId == null) { + targetInstId = ""; + } + StringBuffer sb = new StringBuffer(); + sb.append(sourceInstId).append(sourceNodeId).append("_").append(targetInstId).append(targetNodeId); + return sb.toString(); + } + + public static Proxy.Metadata buildProxyMetadataFromOutbound(Osx.Outbound outbound) { + try { + return Proxy.Metadata.parseFrom(outbound.getPayload()); + } catch (InvalidProtocolBufferException e) { + + } + return null; + } + + public static Osx.Outbound buildOutboundFromProxyMetadata(Proxy.Metadata metadata) { + return 
Osx.Outbound.newBuilder().setPayload(metadata.toByteString()).build(); + + } + + public static Proxy.Packet parsePacketFromInbound(Osx.Inbound inbound) { + try { + return Proxy.Packet.parseFrom(inbound.getPayload()); + } catch (InvalidProtocolBufferException e) { + return null; + } + } + + public static Osx.Inbound.Builder buildInbound(String provider, + String srcPartyId, + String desPartyId, + String targetMethod, + String topic, + MessageFlag messageFlag, + String sessionId, + byte[] payLoad) { + + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + inboundBuilder.putMetadata(Osx.Header.Version.name(), MetaInfo.CURRENT_VERSION); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), provider); +// inboundBuilder.putMetadata(Osx.Header.Token.name(), ""); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), srcPartyId); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), desPartyId); +// inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); +// inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); + if (StringUtils.isNotEmpty(sessionId)) { + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionId); + } + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod); +// inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), ""); +// inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); + if (StringUtils.isNotEmpty(topic)) { + inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), topic); + } + if (messageFlag != null) { + inboundBuilder.putMetadata(Osx.Metadata.MessageFlag.name(), messageFlag.name()); + } + if (payLoad != null) { + inboundBuilder.setPayload(ByteString.copyFrom(payLoad)); + } + return inboundBuilder; + + } + + + public static TransferMeta parseTransferMetaFromProxyPacket(Proxy.Packet packet) { + TransferMeta transferMeta = new TransferMeta(); + Proxy.Metadata metadata = packet.getHeader(); + Transfer.RollSiteHeader rollSiteHeader = null; + String dstPartyId = null; + String srcPartyId = null; + String desRole = null; + String srcRole = null; + try { + rollSiteHeader = Transfer.RollSiteHeader.parseFrom(metadata.getExt()); + } catch (InvalidProtocolBufferException e) { + throw new ParameterException("invalid rollSiteHeader"); + } + String sessionId = ""; + if (rollSiteHeader != null) { + dstPartyId = rollSiteHeader.getDstPartyId(); + srcPartyId = rollSiteHeader.getSrcPartyId(); + desRole = rollSiteHeader.getDstRole(); + srcRole = rollSiteHeader.getSrcRole(); + } + if (StringUtils.isEmpty(dstPartyId)) { + dstPartyId = metadata.getDst().getPartyId(); + } + if (StringUtils.isEmpty(desRole)) { + desRole = metadata.getDst().getRole(); + } + if (StringUtils.isEmpty(srcRole)) { + srcRole = metadata.getSrc().getRole(); + } + if (StringUtils.isEmpty(srcPartyId)) { + srcPartyId = metadata.getSrc().getPartyId(); + } + + if (rollSiteHeader != null) { + sessionId = String.join("_", rollSiteHeader.getRollSiteSessionId(), desRole, dstPartyId); + } + if (metadata.getDst() != null) { + transferMeta.setTopic(metadata.getDst().getName()); + } + + transferMeta.setDesPartyId(dstPartyId); + transferMeta.setSrcPartyId(srcPartyId); + transferMeta.setDesRole(desRole); + transferMeta.setSrcRole(srcRole); + transferMeta.setSessionId(sessionId); + return transferMeta; + } + + public static void assableContextFromInbound(Context context, Osx.Inbound request) { + //initContext + Map metaDataMap = request.getMetadataMap(); + String version = metaDataMap.get(Osx.Header.Version.name()); + 
String jobId = metaDataMap.get(Osx.Metadata.JobId.name()); + String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); + String traceId = metaDataMap.get(Osx.Header.TraceID.name()); + String token = metaDataMap.get(Osx.Header.Token.name()); + String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); + String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); + String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); + String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); + String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); + String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); + String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); + String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); + String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId; + String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId; + String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name()); + String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name()); + String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name()); + Long offset = StringUtils.isNotEmpty(offsetString) ? Long.parseLong(offsetString) : null; + context.setTraceId(traceId); + context.setToken(token); + context.setDesPartyId(targetPartyId); + context.setSrcPartyId(sourcePartyId); + context.setTopic(topic); + context.setJobId(jobId); + + + if (context instanceof FateContext) { + ((FateContext) context).setRequestMsgIndex(offset); + ((FateContext) context).setMessageCode(messageCode); + } + context.setSessionId(sessionId); + context.setDesComponent(targetComponentName); + context.setSrcComponent(sourceComponentName); + context.setTechProviderCode(techProviderCode); + if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) { + context.setSelfPartyId(context.getDesPartyId()); + } else { + context.setSelfPartyId(MetaInfo.PROPERTY_SELF_PARTY.toArray()[0].toString()); + } + } + + public static void assableContextFromProxyPacket(Context context, Proxy.Packet packet) { + TransferMeta transferMeta = parseTransferMetaFromProxyPacket(packet); + context.setSrcPartyId(transferMeta.getSrcPartyId()); + context.setDesPartyId(transferMeta.getDesPartyId()); + context.setSrcComponent(transferMeta.getSrcRole()); + context.setDesComponent(transferMeta.getDesRole()); + context.setSessionId(transferMeta.getSessionId()); + context.setTopic(transferMeta.getTopic()); + context.setTechProviderCode(MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + if (MetaInfo.PROPERTY_SELF_PARTY.contains(context.getDesPartyId())) { + context.setSelfPartyId(context.getDesPartyId()); + } else { + context.setSelfPartyId(MetaInfo.PROPERTY_SELF_PARTY.toArray()[0].toString()); + } + + } + + + public static Osx.Inbound.Builder buildInboundFromPushingPacket(Proxy.Packet packet, String provider, String targetMethod, String sourceMethod) { + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + TransferMeta transferMeta = parseTransferMetaFromProxyPacket(packet); + inboundBuilder.setPayload(packet.toByteString()); + inboundBuilder.putMetadata(Osx.Header.Version.name(), MetaInfo.CURRENT_VERSION); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), provider); + inboundBuilder.putMetadata(Osx.Header.Token.name(), ""); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), 
transferMeta.getSrcPartyId()); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), transferMeta.getDesPartyId()); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceMethod.name(), sourceMethod); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), transferMeta.getSessionId()); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), transferMeta.getDesRole()); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); + return inboundBuilder; + + } + + ; + + + static public Osx.Inbound.Builder buildPbFromHttpRequest(Context context, HttpServletRequest request) { + + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + String version = request.getHeader(PtpHttpHeader.Version); + String techProviderCode = request.getHeader(PtpHttpHeader.TechProviderCode); + String traceID = request.getHeader(PtpHttpHeader.TraceID); + String token = request.getHeader(PtpHttpHeader.Token); + String sourceNodeID = request.getHeader(PtpHttpHeader.SourceNodeID); + String targetNodeID = request.getHeader(PtpHttpHeader.TargetNodeID); + String sourceInstID = request.getHeader(PtpHttpHeader.SourceInstID); + String targetInstID = request.getHeader(PtpHttpHeader.TargetInstID); + String sessionID = request.getHeader(PtpHttpHeader.SessionID); + String messageTopic = request.getHeader(PtpHttpHeader.MessageTopic); + String messageCode = request.getHeader(Osx.Metadata.MessageCode.name()); + String retryCount = request.getHeader(Osx.Metadata.RetryCount.name()); + String sourceComponentName = request.getHeader(PtpHttpHeader.SourceComponentName); + String targetComponentName = request.getHeader(PtpHttpHeader.TargetComponentName); + String targetMethod = request.getHeader(PtpHttpHeader.TargetMethod); + String sourceMethod = request.getHeader(PtpHttpHeader.SourceMethod); + String messageOffSet = request.getHeader(PtpHttpHeader.MessageOffSet); + String instanceId = request.getHeader(PtpHttpHeader.InstanceId); + String timestamp = request.getHeader(PtpHttpHeader.Timestamp); + String messageFlag = request.getHeader(PtpHttpHeader.MessageFlag); + String jobId = request.getHeader(PtpHttpHeader.JobId); + context.setSrcPartyId(sourceNodeID); + context.setDesPartyId(targetNodeID); + context.setSessionId(sessionID); + context.setTopic(messageTopic); + context.setActionType(targetMethod); + context.setProtocol(Protocol.http); + inboundBuilder.putMetadata(Osx.Header.Version.name(), version != null ? version : ""); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), techProviderCode != null ? techProviderCode : ""); + inboundBuilder.putMetadata(Osx.Header.Token.name(), token != null ? token : ""); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), sourceNodeID != null ? sourceNodeID : ""); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), targetNodeID != null ? targetNodeID : ""); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), sourceInstID != null ? sourceInstID : ""); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), targetInstID != null ? targetInstID : ""); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), sessionID != null ? sessionID : ""); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), targetMethod != null ? 
targetMethod : ""); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), targetComponentName != null ? targetComponentName : ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), sourceComponentName != null ? sourceComponentName : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), messageTopic != null ? messageTopic : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageOffSet.name(), messageOffSet != null ? messageOffSet : ""); + inboundBuilder.putMetadata(Osx.Metadata.InstanceId.name(), instanceId != null ? instanceId : ""); + inboundBuilder.putMetadata(Osx.Metadata.Timestamp.name(), timestamp != null ? timestamp : ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceMethod.name(), sourceMethod != null ? sourceMethod : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageFlag.name(), messageFlag != null ? messageFlag : ""); + inboundBuilder.putMetadata(Osx.Metadata.JobId.name(), jobId != null ? jobId : ""); + inboundBuilder.putMetadata(Osx.Metadata.MessageCode.name(), messageCode != null ? messageCode : ""); + inboundBuilder.putMetadata(Osx.Metadata.RetryCount.name(), retryCount != null ? retryCount : ""); + return inboundBuilder; + } + + + static public Map parseHttpHeader(Osx.Inbound produceRequest) { + Map metaDataMap = produceRequest.getMetadataMap(); + String version = metaDataMap.get(Osx.Header.Version.name()); + String techProviderCode = metaDataMap.get(Osx.Header.TechProviderCode.name()); + String traceId = metaDataMap.get(Osx.Header.TraceID.name()); + String token = metaDataMap.get(Osx.Header.Token.name()); + String sourceNodeId = metaDataMap.get(Osx.Header.SourceNodeID.name()); + String targetNodeId = metaDataMap.get(Osx.Header.TargetNodeID.name()); + String sourceInstId = metaDataMap.get(Osx.Header.SourceInstID.name()); + String targetInstId = metaDataMap.get(Osx.Header.TargetInstID.name()); + String sessionId = metaDataMap.get(Osx.Header.SessionID.name()); + String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); + String sourceMethod = metaDataMap.get(Osx.Metadata.SourceMethod.name()); + String targetComponentName = metaDataMap.get(Osx.Metadata.TargetComponentName.name()); + String sourceComponentName = metaDataMap.get(Osx.Metadata.SourceComponentName.name()); + String sourcePartyId = StringUtils.isEmpty(sourceInstId) ? sourceNodeId : sourceInstId + "." + sourceNodeId; + String targetPartyId = StringUtils.isEmpty(targetInstId) ? targetNodeId : targetInstId + "." + targetNodeId; + String topic = metaDataMap.get(Osx.Metadata.MessageTopic.name()); + String offsetString = metaDataMap.get(Osx.Metadata.MessageOffSet.name()); + String InstanceId = metaDataMap.get(Osx.Metadata.InstanceId.name()); + String timestamp = metaDataMap.get(Osx.Metadata.Timestamp.name()); + String messageCode = metaDataMap.get(Osx.Metadata.MessageCode.name()); + String messageFlag = metaDataMap.get(Osx.Metadata.MessageFlag.name()); + String jobId = metaDataMap.get(Osx.Metadata.JobId.name()); + + Map header = Maps.newHashMap(); + header.put(PtpHttpHeader.Version, version != null ? version : ""); + header.put(PtpHttpHeader.TechProviderCode, techProviderCode != null ? techProviderCode : ""); + header.put(PtpHttpHeader.TraceID, traceId != null ? traceId : ""); + header.put(PtpHttpHeader.Token, token != null ? token : ""); + header.put(PtpHttpHeader.SourceNodeID, sourceNodeId != null ? sourceNodeId : ""); + header.put(PtpHttpHeader.TargetNodeID, targetNodeId != null ? 
targetNodeId : ""); + header.put(PtpHttpHeader.SourceInstID, sourceInstId != null ? sourceInstId : ""); + header.put(PtpHttpHeader.TargetInstID, targetInstId != null ? targetInstId : ""); + header.put(PtpHttpHeader.SessionID, sessionId != null ? sessionId : ""); + header.put(PtpHttpHeader.MessageTopic, topic != null ? topic : ""); + header.put(PtpHttpHeader.MessageCode, messageCode); + header.put(PtpHttpHeader.SourceComponentName, sourceComponentName != null ? sourceComponentName : ""); + header.put(PtpHttpHeader.TargetComponentName, targetComponentName != null ? targetComponentName : ""); + header.put(PtpHttpHeader.TargetMethod, targetMethod != null ? targetMethod : ""); + header.put(PtpHttpHeader.SourceMethod, sourceMethod != null ? sourceMethod : ""); + header.put(PtpHttpHeader.MessageOffSet, offsetString != null ? offsetString : ""); + header.put(PtpHttpHeader.InstanceId, InstanceId != null ? InstanceId : ""); + header.put(PtpHttpHeader.Timestamp, timestamp != null ? timestamp : ""); + header.put(PtpHttpHeader.MessageFlag, messageFlag != null ? messageFlag : ""); + header.put(PtpHttpHeader.JobId, jobId != null ? jobId : ""); + + return header; + } + + static public Osx.Outbound redirect(FateContext context, Osx.Inbound + produceRequest, RouterInfo routerInfo, boolean usePooled) { + AssertUtil.notNull(routerInfo, context.getDesPartyId() != null ? "des partyId " + context.getDesPartyId() + " router info is null" : " error router info"); + Osx.Outbound result = null; + context.setDataSize(produceRequest.getSerializedSize()); + if (routerInfo.isCycle()) { + throw new CycleRouteInfoException("cycle router info"); + } + if (routerInfo.getProtocol() == null || routerInfo.getProtocol().equals(Protocol.grpc)) { + //来自旧版fateflow的请求,需要用旧版的stub + if (context.isDestination() && Role.fateflow.name().equals(routerInfo.getDesRole()) + && SourceMethod.OLDUNARY_CALL.name().equals(produceRequest.getMetadataMap().get(Osx.Metadata.SourceMethod.name()))) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, usePooled); + DataTransferServiceGrpc.DataTransferServiceBlockingStub stub = DataTransferServiceGrpc.newBlockingStub(managedChannel); + Proxy.Packet request; + try { + request = Proxy.Packet.parseFrom(produceRequest.getPayload().toByteArray()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + Proxy.Packet response = stub.unaryCall(request); + result = Osx.Outbound.newBuilder().setPayload(response.toByteString()).setCode(StatusCode.SUCCESS).build(); + } else { + + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = null; + if (context.getData(Dict.BLOCKING_STUB) == null) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, usePooled); + stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + } else { + stub = (PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub) context.getData(Dict.BLOCKING_STUB); + } + try { + result = stub.invoke(produceRequest); + } catch (StatusRuntimeException e) { + logger.error("redirect error", e); + throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error : " + e.getMessage()); + } + } + // ServiceContainer.tokenApplyService.applyToken(context,routerInfo.getResource(),produceRequest.getSerializedSize()); + } else { + String url = routerInfo.getUrl(); + Map header = parseHttpHeader(produceRequest); + long startTime = System.currentTimeMillis(); + try { + if 
(routerInfo.getProtocol().equals(Protocol.http)) { + + if (routerInfo.isUseSSL()) { + result = HttpsClientPool.sendPtpPost(url, produceRequest.getPayload().toByteArray(), header, routerInfo.getCaFile(), routerInfo.getCertChainFile(), routerInfo.getPrivateKeyFile()); + } else { + + result = HttpClientPool.sendPtpPost(url, produceRequest.getPayload().toByteArray(), header); + } + } + } catch (Exception e) { + + logger.error("sendPtpPost failed : url = {}, startTime = {} , cost = {} ,header = {} , body = {} \n" + , url, startTime, System.currentTimeMillis() - startTime, JsonUtil.object2Json(header), JsonUtil.object2Json(produceRequest.getPayload()), e); + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + result = Osx.Outbound.newBuilder().setCode(exceptionInfo.getCode()).setMessage(exceptionInfo.getMessage()).build(); + } + } + return result; + } + + + public static Osx.Outbound.Builder buildResponseInner(String code, String msgReturn, byte[] content) { + + Osx.Outbound.Builder builder = Osx.Outbound.newBuilder(); + builder.setCode(code); + builder.setMessage(msgReturn); + if (content != null) { + builder.setPayload(ByteString.copyFrom(content)); + } + return builder; + } + + + public static Osx.Outbound buildResponse(String code, String msgReturn, TransferQueue.TransferQueueConsumeResult messageWraper) { + + byte[] content = null; + if (messageWraper != null) { + Osx.Message message = null; + try { + message = Osx.Message.parseFrom(messageWraper.getMessage().getBody()); + } catch (InvalidProtocolBufferException e) { + logger.error("parse message error", e); + } + content = message.toByteArray(); + } + Osx.Outbound.Builder builder = buildResponseInner(code, msgReturn, content); + if (messageWraper != null) { + builder.putMetadata(Osx.Metadata.MessageOffSet.name(), Long.toString(messageWraper.getRequestIndex())); + } + return builder.build(); + } + + public static void checkResponse(Osx.Outbound outbound) { + if (outbound != null) { + String code = outbound.getCode(); + String message = outbound.getMessage(); + if (!StatusCode.SUCCESS.equals(code)) { +// logger.error("================== xxxxxx {}", outbound); + throw new RemoteRpcException("remote code : " + code + " remote msg: " + message); + } + } else { + throw new RemoteRpcException("has no response"); + } + } + + public static void writeHttpRespose(HttpServletResponse response, String code, + String msg, + byte[] content) { + try { + response.setHeader(PtpHttpHeader.ReturnCode, code); + response.setHeader(PtpHttpHeader.MessageCode, msg); + OutputStream outputStream = response.getOutputStream(); + if (content != null) { + outputStream.write(content); + } + outputStream.flush(); + } catch (IOException e) { + logger.error("write http response error", e); + } + } + + + public static void main(String[] args) { + + MBeanServer platformMBeanServer = ManagementFactory.getPlatformMBeanServer(); + + if (platformMBeanServer instanceof com.sun.management.OperatingSystemMXBean) { + com.sun.management.OperatingSystemMXBean osBean = (com.sun.management.OperatingSystemMXBean) platformMBeanServer; + + // 获取连接数 + int connectionCount = osBean.getAvailableProcessors(); + System.out.println("HTTP 连接数: " + connectionCount); + } else { + System.out.println("当前平台不支持获取 HTTP 连接数"); + } + +// TransferUtil a = new TransferUtil(); +// a.testHttps(); + } + + public void testHttps() { + try { + new Thread(() -> { + Osx.Outbound outbound = null; + try { + Thread.sleep(3000); + outbound = 
HttpsClientPool.sendPtpPost("https://127.0.0.1:8088/osx/inbound", new byte[10], null, "D:\\22\\ca.crt", "D:\\22\\174_2.crt", "D:\\22\\174_2.key"); + } catch (Exception e) { + e.printStackTrace(); + } + System.out.println("outbound = " + outbound); + + }).start(); + } catch (Exception e) { + e.printStackTrace(); + } + } +} diff --git a/java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/UtilAll.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/UtilAll.java index ac00f84f50..bda9cbe519 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/util/UtilAll.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/util/UtilAll.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.util; +package org.fedai.osx.broker.util; import org.apache.commons.lang3.StringUtils; import org.apache.commons.validator.routines.InetAddressValidator; @@ -421,9 +421,6 @@ public static boolean isInternalIP(byte[] ip) { throw new RuntimeException("illegal ipv4 bytes"); } - //10.0.0.0~10.255.255.255 - //172.16.0.0~172.31.255.255 - //192.168.0.0~192.168.255.255 if (ip[0] == (byte) 10) { return true; diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/AbstractZookeeperClient.java similarity index 99% rename from java/osx/broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/AbstractZookeeperClient.java index 782dde2f22..464ef249f0 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/AbstractZookeeperClient.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/AbstractZookeeperClient.java @@ -29,7 +29,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; import org.apache.zookeeper.KeeperException; diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ChildListener.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ChildListener.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/zk/ChildListener.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ChildListener.java index 733c5f4e19..4dfd9f09de 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/ChildListener.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ChildListener.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
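The UtilAll hunk above removes the comment that listed the private address ranges matched by isInternalIP. For reference, a standalone sketch of an equivalent RFC 1918 check; the method below is illustrative and not part of the patch:

```java
// 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16 per RFC 1918.
public static boolean isPrivateIPv4(byte[] ip) {
    if (ip == null || ip.length != 4) {
        throw new RuntimeException("illegal ipv4 bytes");
    }
    int b0 = ip[0] & 0xFF;
    int b1 = ip[1] & 0xFF;
    if (b0 == 10) {
        return true;
    }
    if (b0 == 172 && b1 >= 16 && b1 <= 31) {
        return true;
    }
    return b0 == 192 && b1 == 168;
}
```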
*/ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; import java.util.List; diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/CuratorZookeeperClient.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/CuratorZookeeperClient.java index a9368bbe02..3728df17aa 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/CuratorZookeeperClient.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/CuratorZookeeperClient.java @@ -14,10 +14,9 @@ * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; -import com.osx.core.config.MetaInfo; import org.apache.commons.lang3.StringUtils; import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.CuratorFrameworkFactory; @@ -36,6 +35,7 @@ import org.apache.zookeeper.data.ACL; import org.apache.zookeeper.data.Id; import org.apache.zookeeper.server.auth.DigestAuthenticationProvider; +import org.fedai.osx.core.config.MetaInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -199,14 +199,9 @@ public void createEphemeral(String path, String data) throws NodeExistsException public void delete(String path) { try { if (aclEnable) { -// Stat stat = client.checkExists().forPath(path); -// client.delete().withVersion(stat.getAversion()).forPath(path); this.clearAcl(path); } - logger.info("xxxxxxxxxxxxx"); - client.delete().forPath(path); - logger.info("pppppppppppppppp"); } catch (NoNodeException e) { e.printStackTrace(); } catch (Exception e) { diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/DataListener.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/DataListener.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/zk/DataListener.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/DataListener.java index fc91455532..b817ef4b49 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/DataListener.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/DataListener.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; public interface DataListener { diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/EventType.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/EventType.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/zk/EventType.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/EventType.java index 4d414db2e0..69c83e35cc 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/EventType.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/EventType.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; import org.apache.zookeeper.Watcher; diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/StateListener.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/StateListener.java similarity index 95% rename from java/osx/broker/src/main/java/com/osx/broker/zk/StateListener.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/StateListener.java index 7806084e19..06b9d03674 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/StateListener.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/StateListener.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; public interface StateListener { diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ZkConfig.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZkConfig.java similarity index 97% rename from java/osx/broker/src/main/java/com/osx/broker/zk/ZkConfig.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZkConfig.java index 0f317aec00..04e48e5b29 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/ZkConfig.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZkConfig.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; import com.google.common.collect.Lists; diff --git a/java/osx/broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZookeeperClient.java similarity index 98% rename from java/osx/broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java rename to java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZookeeperClient.java index dd6c1f2053..bec9a6c4c4 100644 --- a/java/osx/broker/src/main/java/com/osx/broker/zk/ZookeeperClient.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/zk/ZookeeperClient.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.broker.zk; +package org.fedai.osx.broker.zk; import org.apache.zookeeper.KeeperException; diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/FateTechProvider.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/FateTechProvider.java new file mode 100644 index 0000000000..1cbf604b5f --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/FateTechProvider.java @@ -0,0 +1,236 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.tech.provider; + + +import com.google.common.collect.Sets; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.broker.interceptor.PcpHandleInterceptor; +import org.fedai.osx.broker.interceptor.RouterInterceptor; +import org.fedai.osx.broker.interceptor.TokenValidatorInterceptor; +import org.fedai.osx.broker.ptp.*; +import org.fedai.osx.broker.util.ContextUtil; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.PtpHttpHeader; +import org.fedai.osx.core.exceptions.ErrorMessageUtil; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.exceptions.ParameterException; +import org.fedai.osx.core.provider.TechProvider; +import org.fedai.osx.core.ptp.TargetMethod; +import org.fedai.osx.core.service.InboundPackage; +import org.fedai.osx.core.service.OutboundPackage; +import org.fedai.osx.core.service.ServiceAdaptor; +import org.fedai.osx.core.utils.FlowLogUtil; +import org.ppc.ptp.Osx; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * FATE 相关实现 + */ + +public class FateTechProvider implements TechProvider { + + Logger logger = LoggerFactory.getLogger(FateTechProvider.class); + ConcurrentMap serviceAdaptorConcurrentMap = new ConcurrentHashMap<>(); + PcpHandleInterceptor requestHandleInterceptor; + TokenValidatorInterceptor tokenValidatorInterceptor; + RouterInterceptor routerInterceptor; + private Set httpAllowedMethod = Sets.newHashSet(TargetMethod.PRODUCE_MSG.name(), TargetMethod.UNARY_CALL.name()); + + public FateTechProvider() { + requestHandleInterceptor = new PcpHandleInterceptor(); + tokenValidatorInterceptor = new TokenValidatorInterceptor(); + routerInterceptor = new RouterInterceptor(); + registerServiceAdaptor(); + } + + private void checkHttpAllowedMethod(String targetMethod) { + if (!httpAllowedMethod.contains(targetMethod)) { + throw new ParameterException("target method :" + targetMethod + "is not allowed"); + } + } + + @Override + public void processHttpInvoke(HttpServletRequest request, HttpServletResponse response) { + Context context = ContextUtil.buildFateContext(Protocol.http); + context.putData(Dict.HTTP_SERVLET_RESPONSE, response); + Osx.Inbound.Builder inboundBuilder; + ServiceAdaptor serviceAdaptor = null; + try { + inboundBuilder = TransferUtil.buildPbFromHttpRequest(context, request); + String targetMethod = inboundBuilder.getMetadataMap().get(Osx.Metadata.TargetMethod.name()); + if (StringUtils.isEmpty(targetMethod)) { + throw new ParameterException("target method is null"); + } + checkHttpAllowedMethod(targetMethod); + serviceAdaptor = this.getServiceAdaptor(targetMethod); + + byte[] buffer = new byte[MetaInfo.PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE]; + int length = IOUtils.read(request.getInputStream(), buffer); + byte[] data = new byte[length]; + System.arraycopy(buffer, 0, data, 0, length); + inboundBuilder.setPayload(ByteString.copyFrom(data)); + } catch (Exception e) { + 
logger.error("processHttpInvoke error :" , e); + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + this.writeHttpRespose(response, exceptionInfo.getCode(), exceptionInfo.getMessage(), null); + context.setReturnCode(exceptionInfo.getCode()); + context.setReturnMsg(exceptionInfo.getMessage()); + FlowLogUtil.printFlowLog(context); + return; + } + InboundPackage inboundPackage = new InboundPackage(); + inboundPackage.setBody(inboundBuilder.build()); + OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); + Osx.Outbound outbound = outboundPackage.getData(); + response.setContentType(Dict.CONTENT_TYPE_JSON_UTF8); + TransferUtil.writeHttpRespose(response, outbound.getCode(), outbound.getMessage(), outbound.getPayload().toByteArray()); + } + + private void writeHttpRespose(HttpServletResponse response, String code, + String msg, + byte[] content) { + try { + response.setHeader(PtpHttpHeader.ReturnCode, code); + response.setHeader(PtpHttpHeader.MessageCode, msg); + OutputStream outputStream = response.getOutputStream(); + if (content != null) { + outputStream.write(content); + } + outputStream.flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + + @Override + public void processGrpcInvoke(Osx.Inbound request, StreamObserver responseObserver) { + Context context = ContextUtil.buildFateContext(Protocol.grpc); + context.putData(Dict.RESPONSE_STREAM_OBSERVER, responseObserver); + Osx.Outbound result = null; + try { + Map metaDataMap = request.getMetadataMap(); + String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); + ServiceAdaptor serviceAdaptor = this.getServiceAdaptor(targetMethod); + if (serviceAdaptor == null) { + throw new ParameterException("invalid target method " + targetMethod); + } + InboundPackage inboundPackage = new InboundPackage(); + inboundPackage.setBody(request); + OutboundPackage outboundPackage = serviceAdaptor.service(context, inboundPackage); + if (outboundPackage.getData() != null) { + result = outboundPackage.getData(); + } + } catch (Exception e) { + ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); + //this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null); + context.setReturnCode(exceptionInfo.getCode()); + context.setReturnMsg(exceptionInfo.getMessage()); + FlowLogUtil.printFlowLog(context); + result = Osx.Outbound.newBuilder().setCode(exceptionInfo.getCode()).setMessage(exceptionInfo.getMessage()).build(); + } + if (result != null) { + responseObserver.onNext(result); + responseObserver.onCompleted(); + } + + } + + + @Override + public StreamObserver processGrpcTransport(Osx.Inbound fristPackage, StreamObserver responseObserver) { + Map metaDataMap = fristPackage.getMetadataMap(); + String targetMethod = metaDataMap.get(Osx.Metadata.TargetMethod.name()); + ServiceAdaptor serviceAdaptor = this.getServiceAdaptor(targetMethod); + if (serviceAdaptor == null) { + throw new ParameterException("invalid target method " + targetMethod); + } + Context context = ContextUtil.buildFateContext(Protocol.grpc); + InboundPackage inboundPackage = new InboundPackage(); + inboundPackage.setBody(responseObserver); + OutboundPackage> outboundPackage = serviceAdaptor.service(context, inboundPackage); + if (outboundPackage != null && outboundPackage.getData() != null) { + return (StreamObserver) outboundPackage.getData(); + } else { + return null; + } + } + + @Override + public void processGrpcPeek(Osx.PeekInbound 
inbound, StreamObserver responseObserver) { + + } + + @Override + public void processGrpcPush(Osx.PushInbound inbound, StreamObserver responseObserver) { + + } + + @Override + public void processGrpcPop(Osx.PopInbound inbound, StreamObserver responseObserver) { + + } + + @Override + public void processGrpcRelease(Osx.ReleaseInbound inbound, StreamObserver responseObserver) { + + } + + + public ServiceAdaptor getServiceAdaptor(String name) { + return this.serviceAdaptorConcurrentMap.get(name); + } + + private void registerServiceAdaptor() { + this.serviceAdaptorConcurrentMap.put(TargetMethod.UNARY_CALL.name(), new PtpUnaryCallService() + .addPreProcessor(requestHandleInterceptor) + .addPreProcessor(tokenValidatorInterceptor) + .addPreProcessor(routerInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.PRODUCE_MSG.name(), new PtpProduceService() + .addPreProcessor(requestHandleInterceptor) + .addPreProcessor(routerInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.ACK_MSG.name(), new PtpAckService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.CONSUME_MSG.name(), new PtpConsumeService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.QUERY_TOPIC.name(), new PtpQueryTransferQueueService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.CANCEL_TOPIC.name(), new PtpCancelTransferService() + .addPreProcessor(requestHandleInterceptor)); + this.serviceAdaptorConcurrentMap.put(TargetMethod.PUSH.name(), new PtpPushService()); + this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOKEN.name(), new PtpClusterTokenApplyService()); + this.serviceAdaptorConcurrentMap.put(TargetMethod.APPLY_TOPIC.name(), new PtpClusterTopicApplyService()); + // this.serviceAdaptorConcurrentMap.put(TargetMethod.TEST_STREAM.name(), new PtpStreamTestService()); + } +} diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/TechProviderRegister.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/TechProviderRegister.java new file mode 100644 index 0000000000..808f757861 --- /dev/null +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/tech/provider/TechProviderRegister.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.tech.provider; + +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.exceptions.ParameterException; +import org.fedai.osx.core.frame.Lifecycle; +import org.fedai.osx.core.provider.TechProvider; +import org.fedai.osx.core.utils.ClassUtils; +import org.fedai.osx.core.utils.PropertiesUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * Tech provider (vendor) selection + */ +public class TechProviderRegister implements Lifecycle { + + Logger logger = LoggerFactory.getLogger(TechProviderRegister.class); + ConcurrentMap registerMap = new ConcurrentHashMap<>(); + final String configFileName = "components/provider.properties"; + + public final TechProvider select(String techProviderCode) { + if(StringUtils.isEmpty(techProviderCode)){ + throw new ParameterException("techProviderCode is null"); + } + return this.registerMap.get(techProviderCode); + } + public void init() { + Properties properties = PropertiesUtil.getProperties(MetaInfo.PROPERTY_CONFIG_DIR+Dict.SLASH+configFileName); + properties.forEach((k,v)->{ + try { + this.registerMap.put(k.toString(), (TechProvider) ClassUtils.newInstance(v.toString())); + }catch(Exception e){ + logger.error("provider {} class {} init error",k,v,e); + } + }); + logger.info("tech provider register : {}",this.registerMap); + } + + @Override + public void start() { + init(); + } + @Override + public void destroy() { + this.registerMap.clear(); + } + +} + + + diff --git a/java/osx/osx-broker/src/main/resources/broker/broker.properties b/java/osx/osx-broker/src/main/resources/broker/broker.properties new file mode 100644 index 0000000000..70d9445f31 --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/broker/broker.properties @@ -0,0 +1,61 @@ +grpc.port= 9370 +# HTTP switch for the server. +# If set to true, the server will open the http port. +# http port configuration can be set through http.port +open.http.server=false +# port of http +http.port=8087 +https.port=8088 +# whether the http server uses TLS +#http.use.tls=false +# whether the grpc server uses TLS. +# If true, a dedicated grpc port will be opened to listen for TLS requests +# grpc tls port configuration can be set through grpc.tls.port +open.grpc.tls.server=false +grpc.tls.port=9883 +# the partyId of self; multiple partyIds can be set.
+# eg: 9999,10000,10001 +self.party=9999 +# deployment mode, including cluster/standalone, +# respectively representing cluster mode and standalone mode , +# and standalone is used by default +deploy.mode=standalone +# the zookeeper address needs to be configured when the deployment mode is cluster +zk.url=127.0.0.1:2181 +stream.limit.mode=LOCAL + +# the IP of the cluster manager component of eggroll +eggroll.cluster.manager.ip = localhost +# the port of the cluster manager component of eggroll +eggroll.cluster.manager.port = 4670 +# maximum number of message retries +produce.msg.max.try.time =3 + +http.client.method.config = {"UNARY_CALL":{"reqTimeout":0,"connectionTimeout":0,"socketTimeout":0}} + +http.use.tls=false + +http.ssl.trust.store.type=PKCS12 + +http.ssl.key.store.alias=22 + +http.ssl.key.store.password=123456 + + +mapped.file.size=134217728 + +#http.ssl.trust.store.path=D:\\44\\127.0.0.1.pfx + +server.ca.file= +server.cert.chain.file= +server.private.key.file= + + + + + + + + + + diff --git a/java/osx/broker/src/main/resources/flowRule.json b/java/osx/osx-broker/src/main/resources/broker/flowRule.json similarity index 85% rename from java/osx/broker/src/main/resources/flowRule.json rename to java/osx/osx-broker/src/main/resources/broker/flowRule.json index 50382af6bb..d18f9c43f9 100644 --- a/java/osx/broker/src/main/resources/flowRule.json +++ b/java/osx/osx-broker/src/main/resources/broker/flowRule.json @@ -1,6 +1,6 @@ [ { - "resource": "-10000", + "resource": "9999_9999", "grade": 1, "count": 2000, "strategy": 0, diff --git a/java/osx/broker/src/main/resources/route_table.json b/java/osx/osx-broker/src/main/resources/broker/route_table.json old mode 100755 new mode 100644 similarity index 55% rename from java/osx/broker/src/main/resources/route_table.json rename to java/osx/osx-broker/src/main/resources/broker/route_table.json index b34487d85d..abe60b8c56 --- a/java/osx/broker/src/main/resources/route_table.json +++ b/java/osx/osx-broker/src/main/resources/broker/route_table.json @@ -3,29 +3,24 @@ { "9999": { - "default":[ - { - "port": 9371, - "ip": "localhost" - } - ], "fateflow":[ { "port": 9360, - "ip": "localhost" + "ip": "127.0.0.1" } ] }, "10000":{ "default":[{ - "port": 9889, - "ip": "localhost" + "protocol":"http", + "url": "http://127.0.0.1:8087/osx/inbound", + "ip": "127.0.0.1", + "port": 9370 }] - } }, "permission": { "default_allow": true } -} +} \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/resources/components/provider.properties b/java/osx/osx-broker/src/main/resources/components/provider.properties new file mode 100644 index 0000000000..2dd3cacb8e --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/provider.properties @@ -0,0 +1,2 @@ +FATE=org.fedai.osx.tech.provider.FateTechProvider + diff --git a/java/osx/osx-broker/src/main/resources/components/router.properties b/java/osx/osx-broker/src/main/resources/components/router.properties new file mode 100644 index 0000000000..aacb53668b --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/router.properties @@ -0,0 +1 @@ +FATE=org.fedai.osx.broker.router.DefaultFateRouterServiceImpl \ No newline at end of file diff --git a/java/osx/osx-broker/src/main/resources/components/translator.properties b/java/osx/osx-broker/src/main/resources/components/translator.properties new file mode 100644 index 0000000000..cbd3bc5e17 --- /dev/null +++ b/java/osx/osx-broker/src/main/resources/components/translator.properties @@ -0,0 +1,2 @@ +9999-10000= 
org.fedai.osx.broker.demo.DemoTranslator +10000-9999= org.fedai.osx.broker.demo.DemoTranslator \ No newline at end of file diff --git a/java/osx/broker/src/main/resources/log4j2.xml b/java/osx/osx-broker/src/main/resources/log4j2.xml similarity index 100% rename from java/osx/broker/src/main/resources/log4j2.xml rename to java/osx/osx-broker/src/main/resources/log4j2.xml diff --git a/java/osx/broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/cluster/ClusterClientEndpointTest.java similarity index 100% rename from java/osx/broker/src/test/java/com/osx/broker/cluster/ClusterClientEndpointTest.java rename to java/osx/osx-broker/src/test/java/org/fedai/osx/broker/cluster/ClusterClientEndpointTest.java diff --git a/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockHttpServer.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockHttpServer.java new file mode 100644 index 0000000000..847ab28043 --- /dev/null +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockHttpServer.java @@ -0,0 +1,26 @@ +package org.fedai.osx.broker.mock; + +import org.fedai.osx.broker.ServiceContainer; +import org.fedai.osx.core.config.MetaInfo; + +import java.util.HashSet; + + +public class MockHttpServer { + + + + + + public static void main(String[] args){ + HashSet selfPartyIds = new HashSet(); + selfPartyIds.add("10001"); + MetaInfo.PROPERTY_SELF_PARTY= selfPartyIds; + MetaInfo.PROPERTY_GRPC_PORT=9372; + MetaInfo.PROPERTY_HTTP_PORT=8222; + MetaInfo.PROPERTY_OPEN_HTTP_SERVER = Boolean.TRUE; + ServiceContainer.init(); + } + + +} diff --git a/java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockServer.java similarity index 69% rename from java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java rename to java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockServer.java index 424aa14ed6..fd19385d08 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/mock/MockServer.java +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/mock/MockServer.java @@ -1,31 +1,22 @@ -package com.osx.broker.mock; +package org.fedai.osx.broker.mock; //import com.firework.cluster.rpc .FireworkQueueServiceGrpc; //import com.firework.cluster.rpc.FireworkTransfer; import com.google.protobuf.ByteString; -import com.osx.broker.grpc.ContextPrepareInterceptor; -import com.osx.broker.grpc.ServiceExceptionHandler; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; import com.webank.ai.eggroll.api.networking.proxy.Proxy; import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.ServerInterceptors; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; import io.grpc.stub.StreamObserver; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; import org.ppc.ptp.Osx; - import org.ppc.ptp.PrivateTransferProtocolGrpc; import java.io.IOException; import java.net.InetSocketAddress; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import static com.osx.broker.ServiceContainer.proxyGrpcService; public class MockServer { @@ -66,34 +57,34 @@ public static void main(String[] args) { } - private static Server buildServer() { - NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) 
ServerBuilder.forPort(9375); - nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); - // nettyServerBuilder.addService(ServerInterceptors.intercept(queueGrpcservice, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); - // nettyServerBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); - - nettyServerBuilder - .executor(Executors.newCachedThreadPool()) - .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) - .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) - .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) - .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) - nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) - nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) - nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) - nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) - nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) - nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); - if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) - nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); - return nettyServerBuilder.build(); - } +// private static Server buildServer() { +// NettyServerBuilder nettyServerBuilder = (NettyServerBuilder) ServerBuilder.forPort(9375); +// nettyServerBuilder.addService(ServerInterceptors.intercept(proxyGrpcService, new ServiceExceptionHandler(), new ContextPrepareInterceptor())); +// // nettyServerBuilder.addService(ServerInterceptors.intercept(queueGrpcservice, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); +// // nettyServerBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(),new ContextPrepareInterceptor())); +// +// nettyServerBuilder +// .executor(Executors.newCachedThreadPool()) +// .maxConcurrentCallsPerConnection(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION) +// .maxInboundMessageSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE) +// .maxInboundMetadataSize(MetaInfo.PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE) +// .flowControlWindow(MetaInfo.PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC > 0) +// nettyServerBuilder.keepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC > 0) +// nettyServerBuilder.keepAliveTimeout(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC, TimeUnit.SECONDS); +// if 
(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC > 0) +// nettyServerBuilder.permitKeepAliveTime(MetaInfo.PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED) +// nettyServerBuilder.permitKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC > 0) +// nettyServerBuilder.maxConnectionIdle(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC > 0) +// nettyServerBuilder.maxConnectionAge(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC, TimeUnit.SECONDS); +// if (MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC > 0) +// nettyServerBuilder.maxConnectionAgeGrace(MetaInfo.PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC, TimeUnit.SECONDS); +// return nettyServerBuilder.build(); +// } private static class PtpService extends PrivateTransferProtocolGrpc.PrivateTransferProtocolImplBase { diff --git a/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/EggrollTest.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/EggrollTest.java new file mode 100644 index 0000000000..18135b4b57 --- /dev/null +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/EggrollTest.java @@ -0,0 +1,4 @@ +package org.fedai.osx.broker.test.grpc; + +public class EggrollTest { +} diff --git a/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/Grpc_UC.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/Grpc_UC.java new file mode 100644 index 0000000000..1dad43e9af --- /dev/null +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/Grpc_UC.java @@ -0,0 +1,55 @@ +package org.fedai.osx.broker.test.grpc; + +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.RemoteRpcException; +import org.fedai.osx.core.frame.GrpcConnectionFactory; +import org.fedai.osx.core.utils.JsonUtil; +import org.junit.Test; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; + +public class Grpc_UC { + + String contextStr = "{\"actionType\":\"unary-call-new\",\"protocol\":\"grpc\",\"techProviderCode\":\"FATE\",\"needCheckRouterInfo\":true,\"costTime\":0,\"resourceName\":\"I_unary-call-new\",\"timeStamp\":1685499290484,\"downstreamCost\":0,\"downstreamBegin\":0,\"destination\":false,\"sourceIp\":\"127.0.0.1\",\"desPartyId\":\"20008\",\"srcPartyId\":\"\",\"returnCode\":\"0\",\"desComponent\":\"fateflow\",\"routerInfo\":{\"protocol\":\"grpc\",\"sourcePartyId\":\"\",\"desPartyId\":\"20008\",\"desRole\":\"fateflow\",\"url\":\"\",\"host\":\"127.0.0.1\",\"port\":9360,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\",\"resource\":\"-20008\",\"cycle\":false},\"selfPartyId\":\"10008\"}"; + String routerJson = "{\n" + + " \"protocol\": \"grpc\",\n" + + " \"sourcePartyId\": \"\",\n" + + " \"desPartyId\": \"10008\",\n" + + " \"desRole\": \"fateflow\",\n" + + " \"url\": \"http://127.0.0.1:8087/osx/inbound\",\n" + + " \"host\": \"127.0.0.1\",\n" + + " \"port\": 9883,\n" + + " \"useSSL\": true,\n" + + " \"negotiationType\": \"TLS\",\n" + + " \"certChainFile\": \"D:/33/127.0.0.1.crt\",\n" 
+ + " \"privateKeyFile\": \"D:/33/127.0.0.1.key\",\n" + + " \"caFile\": \"D:/33/testRoot.crt\",\n" + + " \"resource\": \"-10008\",\n" + + " \"cycle\": false\n" + + "}"; + + @Test + public void run(){ + FateContext context = JsonUtil.json2Object(contextStr,FateContext.class); + RouterInfo routerInfo = JsonUtil.json2Object(routerJson,RouterInfo.class); + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub = null; + if (context.getData(Dict.BLOCKING_STUB) == null) { + ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(routerInfo, true); + stub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + } else { + stub = (PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub) context.getData(Dict.BLOCKING_STUB); + } + try { + // logger.info("===========send data {}",produceRequest); + Osx.Outbound invoke = stub.invoke(null); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + throw new RemoteRpcException(StatusCode.NET_ERROR, "send to " + routerInfo.toKey() + " error : " + e.getMessage()); + } + } +} diff --git a/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/NewFateTest.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/NewFateTest.java new file mode 100644 index 0000000000..1ee40fe133 --- /dev/null +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/NewFateTest.java @@ -0,0 +1,144 @@ +package org.fedai.osx.broker.test.grpc; + +import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.StreamObserver; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.ptp.TargetMethod; +import org.junit.Before; +import org.junit.Test; +import org.ppc.ptp.Osx; +import org.ppc.ptp.PrivateTransferProtocolGrpc; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class NewFateTest { + + String ip = "localhost"; + //int port = 8250;//nginx + int port = 9370;//nginx + String desPartyId = "10000"; + String desRole = ""; + String srcPartyId = "9999"; + String srcRole = ""; + String transferId = "testTransferId"; + String sessionId = "testSessionId"; + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; + PrivateTransferProtocolGrpc.PrivateTransferProtocolStub stub; + @Before + public void init() { + ManagedChannel managedChannel = createManagedChannel(ip, port); + // stub = PrivateTransferProtocolGrpc.newBlockingStub(); + // ManagedChannel managedChannel2 = createManagedChannel(ip, port); + blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel); + stub = PrivateTransferProtocolGrpc.newStub(managedChannel); + + + } + + public static ManagedChannel createManagedChannel(String ip, int port) { + try { + NettyChannelBuilder channelBuilder = NettyChannelBuilder + .forAddress(ip, port) + .keepAliveTime(60, TimeUnit.SECONDS) + .keepAliveTimeout(60, TimeUnit.SECONDS) + .keepAliveWithoutCalls(true) + .idleTimeout(60, TimeUnit.SECONDS) + .perRpcBufferLimit(128 << 20) + .flowControlWindow(32 << 20) + .maxInboundMessageSize(32 << 20) + .enableRetry() + .retryBufferSize(16 << 20) + .maxRetryAttempts(20); + channelBuilder.usePlaintext(); + + return channelBuilder.build(); + } catch (Exception e) { + e.printStackTrace(); + // logger.error("create channel error : " ,e); + //e.printStackTrace(); + } + return null; + } + + @Test + public void testUnaryCall(byte[] 
data){ + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + inboundBuilder.putMetadata(Osx.Header.Version.name(), "123"); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), "FATE");// + inboundBuilder.putMetadata(Osx.Header.Token.name(), "testToken"); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), "9999"); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), "10000"); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), "9999"); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), "10000"); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), "testSessionID"); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), "UNARY_CALL"); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), "fateflow"); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); + inboundBuilder.putMetadata(Osx.Header.TraceID.name(), "28938999993"); + inboundBuilder.setPayload(ByteString.copyFrom(data)); + Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build()); + System.err.println("response : "+outbound); + } + + + @Test + public void testStream(){ + + io.grpc.stub.StreamObserver reqSb = stub.transport(new StreamObserver() { + @Override + public void onNext(Osx.Outbound outbound) { + System.err.println(outbound); + } + @Override + public void onError(Throwable throwable) { + throwable.printStackTrace(); + } + @Override + public void onCompleted() { + System.err.println("completed"); + } + }); + for(int i=0;i<3;i++){ + Osx.Inbound.Builder inboundBuilder = Osx.Inbound.newBuilder(); + inboundBuilder.putMetadata(Osx.Header.Version.name(), "123"); + inboundBuilder.putMetadata(Osx.Header.TechProviderCode.name(), MetaInfo.PROPERTY_FATE_TECH_PROVIDER); + inboundBuilder.putMetadata(Osx.Header.Token.name(), "testToken"); + inboundBuilder.putMetadata(Osx.Header.SourceNodeID.name(), "9999"); + inboundBuilder.putMetadata(Osx.Header.TargetNodeID.name(), "10000"); + inboundBuilder.putMetadata(Osx.Header.SourceInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Header.TargetInstID.name(), ""); + inboundBuilder.putMetadata(Osx.Header.SessionID.name(), "testSessionID"); + inboundBuilder.putMetadata(Osx.Metadata.TargetMethod.name(), TargetMethod.TEST_STREAM.name()); + inboundBuilder.putMetadata(Osx.Metadata.TargetComponentName.name(), ""); + inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); + // inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId); + + inboundBuilder.setPayload(ByteString.copyFrom(("test "+i).getBytes(StandardCharsets.UTF_8))); + reqSb.onNext(inboundBuilder.build()); + } + + System.err.println("=========================="); + + } + + + public static void main(String[] args) { + System.err.println("==============="); + NewFateTest newFateTest = new NewFateTest(); + newFateTest.init(); + newFateTest.testStream(); + + CountDownLatch countDownLatch = new CountDownLatch(1); + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + } + +} diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/OldFateTest.java similarity index 88% rename from java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java rename to java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/OldFateTest.java index 61a86498fd..bfb6881b89 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/OldFateTest.java +++ 
b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/OldFateTest.java @@ -1,4 +1,4 @@ -package com.osx.broker.test.grpc; +package org.fedai.osx.broker.test.grpc; import com.google.protobuf.ByteString; import com.webank.ai.eggroll.api.networking.proxy.DataTransferServiceGrpc; @@ -19,8 +19,8 @@ public class OldFateTest { static int port = 9370;//9371 - // static String ip = "localhost"; - static String ip = "10.42.0.85"; + static String ip = "localhost"; + static Logger logger = LoggerFactory.getLogger(OldFateTest.class); @@ -129,7 +129,7 @@ public void onNext(Proxy.Metadata value) { @Override public void onError(Throwable t) { - logger.error("on Error", t); + logger.error("on Error {}", t.getMessage()); t.printStackTrace(); } @@ -148,12 +148,25 @@ public void onCompleted() { // } // for (int t = 0; t < 1; t++) { + String srcPartyId = "9999"; + String desPartyId = "10000"; // new Thread(() -> { StreamObserver requestOb = stub.push(responseOb); for (int i = 0; i < 3; i++) { + +// Proxy.Metadata metadata = packet.getHeader(); +// ByteString encodedRollSiteHeader = metadata.getExt(); + Transfer.RollSiteHeader.Builder rollSiteHeader = Transfer.RollSiteHeader.newBuilder(); + rollSiteHeader.setDstRole("desRole"); + rollSiteHeader.setDstPartyId(desPartyId); + rollSiteHeader.setSrcPartyId(srcPartyId); + rollSiteHeader.setSrcRole("srcRole"); Proxy.Packet.Builder packetBuilder = Proxy.Packet.newBuilder(); - packetBuilder.setHeader(Proxy.Metadata.newBuilder().setSrc(Proxy.Topic.newBuilder().setPartyId("9999")).setDst(Proxy.Topic.newBuilder().setPartyId("10000").setName("kaidengTestTopic").build()).build()); + packetBuilder.setHeader(Proxy.Metadata.newBuilder().setSrc(Proxy.Topic.newBuilder().setPartyId("10000")) + .setDst(Proxy.Topic.newBuilder().setPartyId("9999").setName("kaidengTestTopic").build()) + .setExt(rollSiteHeader.build().toByteString()) + .build()); // Transfer.RollSiteHeader.Builder headerBuilder = Transfer.RollSiteHeader.newBuilder(); // headerBuilder.setDstPartyId("10000"); // packetBuilder.setHeader(Proxy.Metadata.newBuilder().setExt(headerBuilder.build().toByteString())); @@ -180,7 +193,8 @@ public void onCompleted() { public static void main(String[] args) { System.err.println("==============="); - testUnaryCall(); + testPush(); + //testUnaryCall(); CountDownLatch countDownLatch = new CountDownLatch(1); try { countDownLatch.await(); diff --git a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/QueueTest.java similarity index 65% rename from java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java rename to java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/QueueTest.java index cb87acd659..81c011eb74 100644 --- a/java/osx/broker/src/test/java/com/osx/broker/test/grpc/QueueTest.java +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/grpc/QueueTest.java @@ -1,17 +1,16 @@ -package com.osx.broker.test.grpc; +package org.fedai.osx.broker.test.grpc; //import com.firework.cluster.rpc.FireworkQueueServiceGrpc; //import com.firework.cluster.rpc.FireworkTransfer; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; -import com.osx.core.config.MetaInfo; -import com.osx.core.constant.Dict; -import com.osx.core.frame.GrpcConnectionFactory; -import com.osx.core.ptp.TargetMethod; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.JsonUtil; import io.grpc.ManagedChannel; import 
io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.broker.util.TransferUtil; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.ptp.TargetMethod; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; @@ -21,21 +20,35 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; @FixMethodOrder(MethodSorters.NAME_ASCENDING) public class QueueTest { Logger logger = LoggerFactory.getLogger(QueueTest.class); - String ip = "localhost"; + static String ip = "localhost"; //int port = 8250;//nginx - int port = 9370;//nginx - String desPartyId = "10000"; - String desRole = ""; - String srcPartyId = "9999"; - String srcRole = ""; - String transferId = "testTransferId"; - String sessionId = "testSessionId"; + static int port = 9370;//nginx + static String desPartyId = "9999"; + static String desRole = ""; + static String srcPartyId = "10000"; + static String srcRole = ""; + static String transferId = "testTransferId"; + static String sessionId = "testSessionId"; + static FateContext fateContext= new FateContext(); + static RouterInfo routerInfo= new RouterInfo(); + static { + routerInfo.setHost(ip); + routerInfo.setPort(port); + } + + + + + //4359615 + + + + PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub blockingStub; // FireworkQueueServiceGrpc.FireworkQueueServiceBlockingStub blockingStub; @@ -46,7 +59,7 @@ public static ManagedChannel createManagedChannel(String ip, int port) { .keepAliveTime(12, TimeUnit.MINUTES) .keepAliveTimeout(12, TimeUnit.MINUTES) .keepAliveWithoutCalls(true) - //.idleTimeout(60, TimeUnit.SECONDS) + .idleTimeout(60, TimeUnit.SECONDS) .perRpcBufferLimit(128 << 20) .flowControlWindow(32 << 20) .maxInboundMessageSize(32 << 20) @@ -64,10 +77,10 @@ public static ManagedChannel createManagedChannel(String ip, int port) { @Before public void init() { - ManagedChannel managedChannel = createManagedChannel(ip, port); - // stub = PrivateTransferProtocolGrpc.newBlockingStub(); - ManagedChannel managedChannel2 = createManagedChannel(ip, port); - blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel2); +// ManagedChannel managedChannel = createManagedChannel(ip, port); +// // stub = PrivateTransferProtocolGrpc.newBlockingStub(); +// ManagedChannel managedChannel2 = createManagedChannel(ip, port); +// blockingStub = PrivateTransferProtocolGrpc.newBlockingStub(managedChannel2); } @@ -89,8 +102,8 @@ public void test02Query() { inboundBuilder.putMetadata(Osx.Metadata.SourceComponentName.name(), ""); inboundBuilder.putMetadata(Osx.Metadata.MessageTopic.name(), transferId); - - Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build()); + Osx.Outbound outbound =TransferUtil.redirect(fateContext,inboundBuilder.build(),routerInfo,false); + // Osx.Outbound outbound = blockingStub.invoke(inboundBuilder.build()); Osx.TopicInfo topicInfo = null; try { topicInfo = Osx.TopicInfo.parseFrom(outbound.getPayload()); @@ -107,36 +120,65 @@ public void test02Query() { } + public void testUnaryConsume(){ + + } + + + private byte[] createBigArray(int size){ + byte[] result = new byte[size]; + for(int i=0;i body = new HashMap<>(); + body.put("uri", "/v2/partner/job/resource/apply"); + body.put("json_body", "{role=host, party_id=10008, job_id=202305251708508595320}"); + body.put("headers", "{}"); + body.put("method", "POST"); 
+ body.put("MessageCode", "111"); + body.put("RetryCount", "111"); + return JsonUtil.object2Json(body); + } + + public Map buildHead() { + Map head = new HashMap<>(); +// CONSUME_MSG -> org.fedai.osx.broker.ptp.PtpConsumeService +// APPLY_TOPIC -> org.fedai.osx.broker.ptp.PtpClusterTopicApplyService +// APPLY_TOKEN -> org.fedai.osx.broker.ptp.PtpClusterTokenApplyService +// QUERY_TOPIC -> org.fedai.osx.broker.ptp.PtpQueryTransferQueueService +// PRODUCE_MSG -> org.fedai.osx.broker.ptp.PtpProduceService +// ACK_MSG -> org.fedai.osx.broker.ptp.PtpAckService +// UNARY_CALL -> org.fedai.osx.broker.ptp.PtpUnaryCallService +// CANCEL_TOPIC -> org.fedai.osx.broker.ptp.PtpCancelTransferService +// PUSH -> org.fedai.osx.broker.ptp.PtpPushService + head.put("x-ptp-target-method", "PRODUCE_MSG"); + head.put("x-ptp-job-id", "202305251708508595320"); + head.put("x-ptp-tech-provider-code", "FATE"); + head.put("x-ptp-message-offset", ""); + head.put("x-ptp-source-inst-id", ""); + head.put("x-ptp-timestamp", ""); + head.put("x-ptp-target-component-name", "fateflow"); + head.put("x-ptp-message-topic", ""); + head.put("x-ptp-trace-id", ""); + head.put("x-ptp-source-node-id", ""); + head.put("x-ptp-source-method", ""); + head.put("x-ptp-token", ""); + head.put("x-ptp-message-flag", ""); + head.put("x-ptp-version", ""); + head.put("x-ptp-source-component-name", ""); + head.put("x-ptp-session-id", ""); + head.put("x-ptp-instance-id", ""); + head.put("x-ptp-target-node-id", "10008"); + head.put("x-ptp-target-inst-id", ""); + + head.put(PtpHttpHeader.SessionID, "111"); + head.put(PtpHttpHeader.MessageTopic, "111"); + head.put(Osx.Metadata.MessageCode.name(), "111"); + head.put(Osx.Metadata.RetryCount.name(), "111"); + return head; + } +} diff --git a/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/utils/JsonToMapCode.java b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/utils/JsonToMapCode.java new file mode 100644 index 0000000000..ae42796140 --- /dev/null +++ b/java/osx/osx-broker/src/test/java/org/fedai/osx/broker/test/utils/JsonToMapCode.java @@ -0,0 +1,27 @@ +package org.fedai.osx.broker.test.utils; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.fedai.osx.core.utils.JsonUtil; +import org.junit.Test; + +import java.util.Map; + +/** + * @date 2023/5/29 + * @remark + */ +public class JsonToMapCode { + + String json = "{\"uri\": \"/v2/partner/job/resource/apply\", \"json_body\": {\"role\": \"host\", \"party_id\": \"10008\", \"job_id\": \"202305251708508595320\"}, \"headers\": {}, \"method\": \"POST\"}"; + + @Test + public void run(){ + Map head = JsonUtil.json2Object(json, new TypeReference>() { + }); + StringBuffer sb = new StringBuffer(); + head.forEach((k,v)->{ + sb.append("head.put(\"").append(k).append("\",\"").append(v).append("\");").append("\n"); + }); + System.out.println("sb = " + sb); + } +} diff --git a/java/osx/core/pom.xml b/java/osx/osx-core/pom.xml similarity index 93% rename from java/osx/core/pom.xml rename to java/osx/osx-core/pom.xml index 75cfc948d1..9506fbfe57 100644 --- a/java/osx/core/pom.xml +++ b/java/osx/osx-core/pom.xml @@ -9,7 +9,7 @@ 4.0.0 - core + osx-core 8 @@ -17,6 +17,11 @@ + + osx + osx-api + ${osx.version} + org.slf4j slf4j-api @@ -82,6 +87,7 @@ commons-io commons-io + diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/Config.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/Config.java new file mode 100644 index 0000000000..0f23c0869d --- /dev/null +++ 
b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/Config.java @@ -0,0 +1,20 @@ +package org.fedai.osx.core.config; + +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; + +@Target({FIELD}) +@Retention(RetentionPolicy.RUNTIME) +@Inherited +public @interface Config { + + String pattern() default ""; +// String defaultValue() default ""; + String confKey(); + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/config/GrpcChannelInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/GrpcChannelInfo.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/config/GrpcChannelInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/config/GrpcChannelInfo.java index a11396b7a0..85d3aea193 100644 --- a/java/osx/core/src/main/java/com/osx/core/config/GrpcChannelInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/GrpcChannelInfo.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.config; +package org.fedai.osx.core.config; import lombok.Data; diff --git a/java/osx/core/src/main/java/com/osx/core/config/MasterInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MasterInfo.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/config/MasterInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MasterInfo.java index c8ecf98ecd..2938d99c09 100644 --- a/java/osx/core/src/main/java/com/osx/core/config/MasterInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MasterInfo.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.config; +package org.fedai.osx.core.config; import lombok.Data; diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java new file mode 100644 index 0000000000..b297d3fcb8 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java @@ -0,0 +1,342 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.fedai.osx.core.config; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.constant.DeployMode; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StreamLimitMode; +import org.fedai.osx.core.exceptions.ConfigErrorException; +import org.fedai.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.NetUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Field; +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class MetaInfo { + + static Logger logger = LoggerFactory.getLogger(MetaInfo.class); + + @Config(confKey = "user.home") + public static String PROPERTY_USER_HOME = System.getProperty("user.home"); + @Config(confKey = "user.dir") + public static String PROPERTY_USER_DIR = System.getProperty("user.dir"); + + public static String CURRENT_VERSION = "100"; + @Config(confKey = "fate.tech.provider") + public static String PROPERTY_FATE_TECH_PROVIDER = "FATE"; + @Config(confKey = "default.client.version") + public static String PROPERTY_DEFAULT_CLIENT_VERSION = "2.X.X"; + public static volatile MasterInfo masterInfo; + @Config(confKey = "grpc.server.max.concurrent.call.per.connection", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = 1000; + @Config(confKey = "grpc.server.max.inbound.metadata.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = 128 << 20; + @Config(confKey = "grpc.server.max.inbound.message.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; + @Config(confKey = "grpc.server.flow.control.window", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = 128 << 20; + @Config(confKey = "grpc.server.keepalive.time.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = 7200; + @Config(confKey = "grpc.server.keepalive.timeout.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = 3600; + @Config(confKey = "grpc.server.permit.keepalive.time.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = 10; + @Config(confKey = "grpc.server.keepalive.without.calls.enabled", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; + @Config(confKey = "grpc.server.max.connection.idle.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = 86400; + @Config(confKey = "grpc.server.max.connection.age.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = 86400; + @Config(confKey = "grpc.server.max.connection.age.grace.sec", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = 86400; + @Config(confKey = "grpc.oncompleted.wait.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = 600; + @Config(confKey = 
"grpc.client.max.inbound.message.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE = (2 << 30) - 1; + @Config(confKey = "grpc.client.flow.control.window", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW = 128 << 20; + @Config(confKey = "grpc.client.keepalive.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC = 7200; + @Config(confKey = "grpc.client.keepalive.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC = 3600; + @Config(confKey = "grpc.client.keepalive.without.calls.enabled", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED = true; + @Config(confKey = "grpc.client.max.connection.idle", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC = 86400; + @Config(confKey = "grpc.client.per.rpc.buffer.limit", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT = (2 << 30) - 1; + @Config(confKey = "grpc.client.retry.buffer.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE = 86400; + @Config(confKey = "transfer.cached.msgid.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_TRANSFER_CACHED_MSGID_SIZE = 10; + @Config(confKey = "grpc.ssl.session.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_SSL_SESSION_TIME_OUT = 3600 << 4; + @Config(confKey = "grpc.ssl.session.cache.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE = 65536; + + @Config(confKey = "mapped.file.expire.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_MAPPED_FILE_EXPIRE_TIME = 3600 * 1000 * 36; + @Config(confKey = "mapped.file.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer MAP_FILE_SIZE = 1 << 28; + @Config(confKey = "mapped.file.dir") + public static String PROPERTY_TRANSFER_FILE_PATH_PRE = "mapped/.fate/transfer_file"; + + @Config(confKey = "index.mapped.file.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_INDEX_MAP_FILE_SIZE = 1 << 21; + @Config(confKey = "server.cert.chain.file") + public static String PROPERTY_SERVER_CERT_CHAIN_FILE; + @Config(confKey = "server.private.key.file") + public static String PROPERTY_SERVER_PRIVATE_KEY_FILE; + @Config(confKey = "server.ca.file") + public static String PROPERTY_SERVER_CA_FILE; + @Config(confKey = "custom.local.host") + public static String PROPERTY_CUSTOMER_LOCAL_HOST; + @Config(confKey = "bind.host") + public static String PROPERTY_BIND_HOST = "0.0.0.0"; + @Config(confKey = "open.grpc.tls.server", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false; + @Config(confKey = "grpc.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_PORT = 9370; + @Config(confKey = "grpc.tls.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_GRPC_TLS_PORT; + @Config(confKey = "use.remote.health.check", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_USE_REMOTE_HEALTH_CHECK = true; + @Config(confKey = "http.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer 
PROPERTY_HTTP_PORT; + @Config(confKey = "https.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTPS_PORT; + @Config(confKey = "open.http.server", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_HTTP_SERVER = false; + @Config(confKey = "http.use.tls", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_HTTP_USE_TLS = false; + @Config(confKey = "http.server.acceptor.num", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SERVER_ACCEPTOR_NUM = 10; + @Config(confKey = "http.server.selector.num", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SERVER_SELECTOR_NUM = 1; + @Config(confKey = "http.ssl.trust.store.type") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_TYPE = "PKCS12"; + @Config(confKey = "http.ssl.trust.store.provider") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER = "SUN"; + @Config(confKey = "http.ssl.key.store.alias") + public static String PROPERTY_HTTP_SSL_KEY_STORE_ALIAS = ""; + @Config(confKey = "http.ssl.key.store.password") + public static String PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD = ""; + @Config(confKey = "http.ssl.trust.store.password") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD = ""; + @Config(confKey = "http.ssl.trust.store.path") + public static String PROPERTY_HTTP_SSL_TRUST_STORE_PATH = ""; + @Config(confKey = "http.ssl.hostname.verify") + public static Boolean PROPERTY_HTTP_SSL_HOSTNAME_VERIFY = false; + + @Config(confKey = "http.request.body.max.size") + public static Integer PROPERTY_HTTP_REQUEST_BODY_MAX_SIZE = 32 * 1024 * 1024; + @Config(confKey = "http.context.path") + public static String PROPERTY_HTTP_CONTEXT_PATH = "/osx"; + @Config(confKey = "http.servlet.path") + public static String PROPERTY_HTTP_SERVLET_PATH = "/inbound"; + @Config(confKey = "http.receive.queue.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_RECEIVE_QUEUE_SIZE = 36; + @Config(confKey = "http.accept.receive.buffer.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_ACCEPT_RECEIVE_BUFFER_SIZE = 4096; + @Config(confKey = "zk.url") + public static String PROPERTY_ZK_URL; + @Config(confKey = "stream.limit.max.try.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = 20; + @Config(confKey = "produce.msg.max.try.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_PRODUCE_MSG_MAX_TRY_TIME = 3; + @Config(confKey = "produce.msg.max.try.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_PRODUCE_MSG_RETRY_INTERVAL = 100; + + @Config(confKey = "produce.msg.cache.max.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PRODUCE_MSG_CACHE_MAX_SIZE = 1000; + @Config(confKey = "produce.msg.cache.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PRODUCE_MSG_CACHE_TIMEOUT; + + + @Config(confKey = "flow.control.sample.count", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_FLOW_CONTROL_SAMPLE_COUNT = 10; + @Config(confKey = "flow.control.sample.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL = 1000; + @Config(confKey = "stream.limit.mode") + public static String PROPERTY_STREAM_LIMIT_MODE = StreamLimitMode.NOLIMIT.name(); + @Config(confKey = "deploy.mode") + public static String PROPERTY_DEPLOY_MODE = 
DeployMode.standalone.name(); + @Config(confKey = "self.party") + public static Set PROPERTY_SELF_PARTY = Sets.newHashSet();// + @Config(confKey = "flow.rule") + public static String PROPERTY_FLOW_RULE_TABLE = "broker/flowRule.json"; + @Config(confKey = "use.zookeeper", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_USE_ZOOKEEPER = true; + @Config(confKey = "open.route.cycle.checker", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_ROUTE_CYCLE_CHECKER = false; + + @Config(confKey = "zookeeper.acl.enable", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_ACL_ENABLE = false; + @Config(confKey = "zookeeper.acl.username") + public static String PROPERTY_ACL_USERNAME; + @Config(confKey = "zookeeper.acl.password") + public static String PROPERTY_ACL_PASSWORD; + @Config(confKey = "queue.max.free.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_QUEUE_MAX_FREE_TIME = 60000000; + @Config(confKey = "queue.check.interval", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_TRANSFER_QUEUE_CHECK_INTERVAL = 60 * 1000 * 10; + public static String INSTANCE_ID = NetUtils.getLocalHost() + ":" + MetaInfo.PROPERTY_GRPC_PORT; + + + + + @Config(confKey = "eggroll.cluster.manager.ip") + public static String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP; + @Config(confKey = "eggroll.cluster.manager.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT; + + + /** + * Timeout for obtaining a connection from the connection pool + */ + @Config(confKey = "http.client.method.config") + public static Map> PROPERTY_HTTP_CLIENT_METHOD_CONFIG_MAP =new HashMap<>(); + + @Config(confKey = "http.client.con.req.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = 500; + /** + * Timeout for establishing a connection + */ + @Config(confKey = "http.client.connection.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_CONN_TIME_OUT = 10000; + + @Config(confKey = "http.client.max.idle.time", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_MAX_IDLE_TIME = 5; + /** + * Socket timeout while waiting for data + */ + @Config(confKey = "http.client.socket.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = 300000; + @Config(confKey = "http.ssl.session.timeout", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_SSL_SESSION_TIME_OUT = 3600 << 4; + @Config(confKey = "http.client.pool.max.total", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_INIT_POOL_MAX_TOTAL = 500; + @Config(confKey = "http.client.pool.max.per.router", pattern = Dict.POSITIVE_INTEGER_PATTERN) + public static Integer PROPERTY_HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = 200; + @Config(confKey = "open.token.validator", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_TOKEN_VALIDATOR = false; + @Config(confKey = "open.token.generator", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROPERTY_OPEN_TOKEN_GENERATOR = false; + + public static String PROPERTY_TOKEN_GENERATOR_CONFIG_PATH; + public static String PROPERTY_CONFIG_DIR; + + @Config(confKey = "protocol.params.print", pattern = Dict.BOOLEAN_PATTERN) + public static Boolean PROTOCOL_PARAMS_PRINT = false; + + + public static boolean isCluster() { + return PROPERTY_DEPLOY_MODE.equals(DeployMode.cluster.name()); + } + + + public static boolean
checkPattern(String pattern, String value) { + Pattern p = Pattern.compile(pattern); + Matcher m = p.matcher(value); + if (m.find()) { + return true; + } else { + return false; + } + } + + public static void init(Properties environment) { + Field[] fields = MetaInfo.class.getFields(); + Arrays.stream(fields).forEach(field -> { + try { + Config config = field.getDeclaredAnnotation(Config.class); + if (config != null) { + Class clazz = field.getType(); + String confKey = config.confKey(); + Object value = environment.get(confKey); + System.err.println("key:"+confKey+ " value :"+value); + if (value != null) { + String pattern = config.pattern(); + if (StringUtils.isNotEmpty(pattern) && !checkPattern(pattern, value.toString())) { + logger.error("conf {} has wrong value {},please check config file", confKey, value); + throw new ConfigErrorException("conf " + confKey + " has wrong value : " + value); + } + if (clazz == Integer.class) { + field.set(null, Integer.parseInt(value.toString())); + } else if (clazz == Long.class) { + field.set(null, Long.parseLong(value.toString())); + } else if (clazz == String.class) { + field.set(null, value.toString()); + + } else if (clazz == Boolean.class) { + field.set(null, Boolean.valueOf(value.toString())); + } else if (clazz.isAssignableFrom(Set.class)) { + Set set = new HashSet(); + set.addAll(Lists.newArrayList(value.toString().split(","))); + field.set(null, set); + } else if (clazz.isAssignableFrom(Map.class)) { + + Map> conConfig = JsonUtil.object2Objcet(value, new TypeReference>>() { + }); + field.set(null,conConfig); + } + } + if (StringUtils.isNotEmpty(confKey)) { + logger.info("{}={} ", confKey, field.get(null)); + } + } + } catch (Exception e) { + // e.printStackTrace(); + logger.error("parse config error",e); + //throw new ConfigErrorException("parse config error: "+e.getMessage()); + } + }); + } + + + public static Map toMap() { + Map result = Maps.newHashMap(); + Field[] fields = MetaInfo.class.getFields(); + + for (Field field : fields) { + try { + if (field.get(MetaInfo.class) != null) { + String key = Dict.class.getField(field.getName()) != null ? 
String.valueOf(Dict.class.getField(field.getName()).get(Dict.class)) : field.getName(); + result.put(key, field.get(MetaInfo.class)); + } + } catch (IllegalAccessException | NoSuchFieldException e) { + + } + } + return result; + } + + public static void main(String args){ + + System.err.println( (2 << 30) - 1); + } + +} diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/TransferMeta.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/TransferMeta.java new file mode 100644 index 0000000000..e4afa6cd24 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/TransferMeta.java @@ -0,0 +1,17 @@ +package org.fedai.osx.core.config; + +import lombok.Data; + +@Data +public class TransferMeta { + + String srcPartyId; + String desPartyId; + String srcRole; + String desRole; + String sessionId; + String topic; + + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/constant/ActionType.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/ActionType.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/constant/ActionType.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/ActionType.java index 07310638a4..d5845dc60e 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/ActionType.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/ActionType.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum ActionType { @@ -28,6 +28,7 @@ public enum ActionType { INNER_REDIRECT("inner-redirect"), LONG_PULLING_ANSWER("long-pulling-answer"), MSG_DOWNLOAD("msg-download"), + MSG_REDIRECT("msg-redirect"), REDIRECT_ACK("redirect-ack"), UNARY_CALL("unary-call"), UNARY_CALL_NEW("unary-call-new"), diff --git a/java/osx/core/src/main/java/com/osx/core/constant/DeployMode.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/DeployMode.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/constant/DeployMode.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/DeployMode.java index a2e45afe6e..11ff024ad8 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/DeployMode.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/DeployMode.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum DeployMode { cluster, diff --git a/java/osx/core/src/main/java/com/osx/core/constant/Dict.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java similarity index 55% rename from java/osx/core/src/main/java/com/osx/core/constant/Dict.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java index ae163bc606..e918956d3e 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/Dict.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java @@ -14,12 +14,10 @@ * limitations under the License. 
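The MetaInfo.init(Properties) method shown earlier in this patch binds properties-file entries to @Config-annotated static fields via reflection: it looks up each field's confKey, optionally validates the raw value against a regex, then converts it to the field's declared type. Below is a minimal, self-contained sketch of that pattern; the Config annotation and the DemoConfig holder here are simplified stand-ins for illustration, not the actual osx classes.

    import java.lang.annotation.*;
    import java.lang.reflect.Field;
    import java.util.Properties;

    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    @interface Config {
        String confKey();
        String pattern() default "";
    }

    // Simplified holder mirroring the MetaInfo style: public static fields with defaults.
    class DemoConfig {
        @Config(confKey = "grpc.port", pattern = "^[1-9]\\d*$")
        public static Integer GRPC_PORT = 9370;
        @Config(confKey = "open.http.server", pattern = "^(true)|(false)$")
        public static Boolean OPEN_HTTP_SERVER = false;

        static void init(Properties env) throws Exception {
            for (Field field : DemoConfig.class.getFields()) {
                Config config = field.getDeclaredAnnotation(Config.class);
                if (config == null) continue;
                Object value = env.get(config.confKey());
                if (value == null) continue;                       // keep the compiled-in default
                String s = value.toString();
                if (!config.pattern().isEmpty() && !s.matches(config.pattern())) {
                    throw new IllegalArgumentException("bad value for " + config.confKey() + ": " + s);
                }
                Class<?> type = field.getType();
                if (type == Integer.class)      field.set(null, Integer.parseInt(s));
                else if (type == Boolean.class) field.set(null, Boolean.valueOf(s));
                else                            field.set(null, s);
            }
        }

        public static void main(String[] args) throws Exception {
            Properties p = new Properties();
            p.setProperty("grpc.port", "9371");
            init(p);
            System.out.println(GRPC_PORT + " " + OPEN_HTTP_SERVER);  // 9371 false
        }
    }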
*/ -package com.osx.core.constant; - -import com.osx.core.config.MetaInfo; +package org.fedai.osx.core.constant; public class Dict { - public static final String PROPERTY_FATE_TECH_PROVIDER="fate.tech.provider"; + public static final String ORIGIN_REQUEST = "origin_request"; public static final String CASEID = "caseid"; public static final String SEQNO = "seqno"; @@ -29,11 +27,9 @@ public class Dict { public static final String HTTP_PORT ="http.port"; public static final String INSTANCE_ID = "instanceId"; - public static final String HIT_CACHE = "hitCache"; - - + public static final String POSITIVE_INTEGER_PATTERN = "^[1-9]\\d*$"; + public static final String BOOLEAN_PATTERN="^(true)|(false)$"; public static final String REQUEST_SEQNO = "REQUEST_SEQNO"; - public static final String VERSION = "version"; public static final String GRPC_TYPE = "grpcType"; public static final String ROUTER_INFO = "routerInfo"; @@ -48,67 +44,18 @@ public class Dict { public static final String DOWN_STREAM_BEGIN = "downstreamBegin"; public static final String ROUTE_BASIS = "routeBasis"; public static final String SOURCE_IP = "sourceIp"; - public static final String PROPERTY_SERVING_CORE_POOL_SIZE = "serving.core.pool.size"; - public static final String SERVING_MAX_POOL_ZIE = "serving.max.pool.size"; - public static final String PROPERTY_SERVING_POOL_ALIVE_TIME = "serving.pool.alive.time"; - public static final String PROPERTY_SERVING_POOL_QUEUE_SIZE = "serving.pool.queue.size"; + //HttpServletResponse + public static final String HTTP_SERVLET_RESPONSE = "httpServletResponse"; + - public static final String CACHE_TYPE_REDIS = "redis"; - public static final String DEFAULT_FATE_ROOT = "FATE-SERVICES"; +// public static final String PROPERTY_BIND_HOST_KEY = "bind.host"; + /** * configuration property key */ - public static final String PROPERTY_SELF_PARTY = "self.party"; - - public static final String PROPERTY_CACHE_TYPE = "cache.type"; - - public static final String PROPERTY_REDIS_EXPIRE = "redis.expire"; - public static final String PROPERTY_REDIS_CLUSTER_NODES = "redis.cluster.nodes"; - public static final String PROPERTY_LOCAL_CACHE_MAXSIZE = "local.cache.maxsize"; - public static final String PROPERTY_LOCAL_CACHE_EXPIRE = "local.cache.expire"; - public static final String PROPERTY_LOCAL_CACHE_INTERVAL = "local.cache.interval"; - - public static final String PROPERTY_GRPC_TIMEOUT = "grpc.timeout"; - public static final String PROPERTY_EXTERNAL_INFERENCE_RESULT_CACHE_DB_INDEX = "external.inferenceResultCacheDBIndex"; - public static final String PROPERTY_EXTERNAL_INFERENCE_RESULT_CACHE_TTL = "external.inferenceResultCacheTTL"; - public static final String PROPERTY_EXTERNAL_REMOTE_MODEL_INFERENCE_RESULT_CACHE_DB_INDEX = "external.remoteModelInferenceResultCacheDBIndex"; - public static final String PROPERTY_EXTERNAL_PROCESS_CACHE_DB_INDEX = "external.processCacheDBIndex"; - public static final String PROPERTY_EXTERNAL_REMOTE_MODEL_INFERENCE_RESULT_CACHE_TTL = "external.remoteModelInferenceResultCacheTTL"; - public static final String PROPERTY_CAN_CACHE_RET_CODE = "canCacheRetcode"; - public static final String PROPERTY_SERVICE_ROLE_NAME = "serviceRoleName"; - public static final String PROPERTY_SERVICE_ROLE_NAME_DEFAULT_VALUE = "serving"; - public static final String PROPERTY_ONLINE_DATA_ACCESS_ADAPTER = "OnlineDataAccessAdapter"; - public static final String PROPERTY_ONLINE_DATA_BATCH_ACCESS_ADAPTER = "OnlineDataBatchAccessAdapter"; - public static final String PROPERTY_MODEL_CACHE_ACCESS_TTL = 
"modelCacheAccessTTL"; - public static final String PROPERTY_MODEL_CACHE_MAX_SIZE = "modelCacheMaxSize"; - public static final String PROPERTY_INFERENCE_WORKER_THREAD_NUM = "inferenceWorkerThreadNum"; - public static final String PROPERTY_PROXY_ADDRESS = "proxy"; - public static final String ONLINE_ENVIRONMENT = "online"; - public static final String PROPERTY_ROLL_ADDRESS = "roll"; - public static final String PROPERTY_FLOW_ADDRESS = "flow"; - public static final String PROPERTY_SERVING_ADDRESS = "serving"; - public static final String PROPERTY_USE_ZOOKEEPER = "useZookeeper"; - public static final String PROPERTY_PORT = "port"; - public static final String PROPERTY_GRPC_PORT = "grpc.port"; - public static final String PROPERTY_GRPC_TLS_PORT = "grpc.tls.port"; - public static final String PROPERTY_USER_DIR = "user.dir"; - public static final String PROPERTY_USER_HOME = "user.home"; - public static final String PROPERTY_FILE_SEPARATOR = "file.separator"; - public static final String PROPERTY_ZK_URL = "zk.url"; - public static final String PROPERTY_USE_ZK_ROUTER = "useZkRouter"; - public static final String PROPERTY_USE_REGISTER = "useRegister"; - public static final String PROPERTY_MODEL_TRANSFER_URL = "model.transfer.url"; - public static final String PROPERTY_MODEL_SYNC = "model.synchronize"; - public static final String PROPERTY_TRANSFER_FILE_PATH = "transfer.file.path"; - - public static final String PROPERTY_FEATURE_BATCH_ADAPTOR = "feature.batch.adaptor"; - public static final String PROPERTY_ACL_ENABLE = "acl.enable"; - public static final String PROPERTY_ACL_USERNAME = "acl.username"; - public static final String PROPERTY_ACL_PASSWORD = "acl.password"; - public static final String PROXY_ROUTER_TABLE = "proxy.router.table"; - public static final String PROPERTY_BATCH_INFERENCE_MAX = "batch.inference.max"; + public static final String PROPERTY_SELF_PARTY_KEY = "self.party"; public static final String PROPERTY_PRINT_INPUT_DATA = "print.input.data"; public static final String PROPERTY_PRINT_OUTPUT_DATA = "print.output.data"; public static final String PROPERTY_NEGOTIATIONTYPE = "server.negotiationType"; @@ -120,68 +67,10 @@ public class Dict { public static final String CURRENT_VERSION = "currentVersion"; public static final String PROPERTY_COORDINATOR = "coordinator"; -// public static final String PROPERTY_SERVER_PORT = "server.port"; - - - public static final String PROPERTY_INFERENCE_SERVICE_NAME = "inference.service.name"; - public static final String PROPERTY_ROUTE_TYPE = "routeType"; - public static final String PROPERTY_ROUTE_TABLE = "route.table"; - public static final String PROPERTY_FLOW_RULE_TABLE = "flow.rule"; - public static final String PROPERTY_AUTH_FILE = "auth.file"; - public static final String PROPERTY_AUTH_OPEN = "auth.open"; - public static final String PROPERTY_PROXY_GRPC_INTRA_PORT = "proxy.grpc.intra.port"; - public static final String PROPERTY_PROXY_GRPC_INTER_PORT = "proxy.grpc.inter.port"; - public static final String PROPERTY_PROXY_GRPC_INFERENCE_TIMEOUT = "proxy.grpc.inference.timeout"; - public static final String PROPERTY_PROXY_GRPC_INFERENCE_ASYNC_TIMEOUT = "proxy.grpc.inference.async.timeout"; - public static final String PROPERTY_PROXY_GRPC_UNARYCALL_TIMEOUT = "proxy.grpc.unaryCall.timeout"; - public static final String PROPERTY_PROXY_GRPC_THREADPOOL_CORESIZE = "proxy.grpc.threadpool.coresize"; - public static final String PROPERTY_PROXY_GRPC_THREADPOOL_MAXSIZE = "proxy.grpc.threadpool.maxsize"; - public static final String 
PROPERTY_PROXY_GRPC_THREADPOOL_QUEUESIZE = "proxy.grpc.threadpool.queuesize"; - public static final String PROPERTY_PROXY_ASYNC_TIMEOUT = "proxy.async.timeout"; - public static final String PROPERTY_PROXY_ASYNC_CORESIZE = "proxy.async.coresize"; - public static final String PROPERTY_PROXY_ASYNC_MAXSIZE = "proxy.async.maxsize"; - public static final String PROPERTY_PROXY_GRPC_BATCH_INFERENCE_TIMEOUT = "proxy.grpc.batch.inference.timeout"; - public static final String PROPERTY_MODEL_CACHE_PATH = "model.cache.path"; - public static final String PROPERTY_LR_USE_PARALLEL = "lr.use.parallel"; - public static final String PROPERTY_ALLOW_HEALTH_CHECK = "health.check.allow"; - public static final String PROPERTY_TRANSFER_FILE_CACHE_SIZE = "transfer.file.cache.size"; - public static final String PROPERTY_MAX_TRANSFER_CACHE_SIZE = "max.transfer.cache.size"; - public static final String PROPERTY_USE_DIRECT_CACHE = "use.direct.cache"; - public static final String PROPERTY_GRPC_ONCOMPLETED_WAIT_TIMEOUT = "grpc.oncompleted.wait.timeout"; -// public static final String PROPERTY_USE_QUEUE_MODEL = "use.queue.model"; - public static final String PROPERTY_STREAM_LIMIT_MODE = "stream.limit.mode"; - public static final String PROPERTY_STREAM_LIMIT_MAX_TRY_TIME = "stream.limit.max.try.time"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONCURRENT_CALL_PER_CONNECTION = "grpc.server.max.concurrent.call.per.connection"; - public static final String PROPERTY_GRPC_SERVER_MAX_INBOUND_MESSAGE_SIZE = "grpc.server.max.inbound.message.size"; - public static final String PROPERTY_GRPC_SERVER_MAX_INBOUND_METADATA_SIZE = "grpc.server.max.inbound.metadata.size"; - public static final String PROPERTY_GRPC_SERVER_FLOW_CONTROL_WINDOW = "grpc.server.flow.control.window"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_TIME_SEC = "grpc.server.keepalive.time.sec"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_TIMEOUT_SEC = "grpc.server.keepalive.timeout.sec"; - public static final String PROPERTY_GRPC_SERVER_PERMIT_KEEPALIVE_TIME_SEC = "grpc.server.permit.keepalive.time.sec"; - public static final String PROPERTY_GRPC_SERVER_KEEPALIVE_WITHOUT_CALLS_ENABLED = "grpc.server.keepalive.without.calls.enabled"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_IDLE_SEC = "grpc.server.max.connection.idle.sec"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_SEC = "grpc.server.max.connection.age.sec"; - public static final String PROPERTY_GRPC_SERVER_MAX_CONNECTION_AGE_GRACE_SEC = "grpc.server.max.connection.age.grace.sec"; - public static final String PROPERTY_INTERVAL_MS = "interval.ms"; - public static final String PROPERTY_SAMPLE_COUNT = "sample.count"; - public static final String PRPPERTY_QUEUE_MAX_FREE_TIME = "queue.max.free.time"; - - public static String PROPERTY_OPEN_HTTP_SERVER = "open.http.server"; - public static String PROPERTY_OPEN_GRPC_TLS_SERVER = "open.grpc.tls.server"; - public static String PROPERTY_DEFAULT_CLIENT_VERSION="default.client.version"; - - - public static final String HTTP_CLIENT_CONFIG_CONN_REQ_TIME_OUT = "httpclinet.config.connection.req.timeout"; - public static final String HTTP_CLIENT_CONFIG_CONN_TIME_OUT = "httpclient.config.connection.timeout"; - public static final String HTTP_CLIENT_CONFIG_SOCK_TIME_OUT = "httpclient.config.sockect.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_MAX_TOTAL = "httpclient.init.pool.maxtotal"; - public static final String HTTP_CLIENT_INIT_POOL_DEF_MAX_PER_ROUTE = 
"httpclient.init.pool.def.max.pre.route"; - public static final String HTTP_CLIENT_INIT_POOL_SOCK_TIME_OUT = "httpclient.init.pool.sockect.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_CONN_TIME_OUT = "httpclient.init.pool.connection.timeout"; - public static final String HTTP_CLIENT_INIT_POOL_CONN_REQ_TIME_OUT = "httpclient.init.pool.connection.req.timeout"; - public static final String HTTP_CLIENT_TRAN_CONN_REQ_TIME_OUT = "httpclient.tran.connection.req.timeout"; - public static final String HTTP_CLIENT_TRAN_CONN_TIME_OUT = "httpclient.tran.connection.timeout"; - public static final String HTTP_CLIENT_TRAN_SOCK_TIME_OUT = "httpclient.tran.sockect.timeout"; + + + + public static final String ACTION_TYPE_ASYNC_EXECUTE = "ASYNC_EXECUTE"; @@ -191,6 +80,8 @@ public class Dict { public static final String DATA = "data"; public static final String STATUS = "status"; public static final String SUCCESS = "success"; + public static final String DUP_MSG = "dup_msg"; + public static final String PROCESSED_MSG = "Processed messages"; public static final String PROB = "prob"; public static final String ACCESS = "access"; @@ -233,6 +124,8 @@ public class Dict { public static final String CASE_ID = "caseid"; public static final String CODE = "code"; public static final String MESSAGE = "message"; + public static final String MESSAGE_FLAG = "message_flag"; + public static final String MESSAGE_CODE = "message_code"; public static final String MODEL_ID = "modelId"; public static final String MODEL_VERSION = "modelVersion"; public static final String TIMESTAMP = "timestamp"; @@ -248,7 +141,10 @@ public class Dict { public static final String SELF_ENVIRONMENT = "online"; public static final String HEAD = "head"; public static final String BODY = "body"; - + public static final String SESSION_ID = "sessionId"; + public static final String METHOD_CONFIG_REQ_TIMEOUT = "reqTimeout"; + public static final String METHOD_CONFIG_CONNECTION_TIMEOUT = "connectionTimeout"; + public static final String METHOD_CONFIG_SOCKET_TIMEOUT = "socketTimeout"; public static final String SBT_TREE_NODE_ID_ARRAY = "sbtTreeNodeIdArray"; @@ -274,8 +170,6 @@ public class Dict { public static final String ERROR_LIST = "errorList"; public static final String HEALTH_INFO = "healthInfo"; public static final String PROPERTY_ADMIN_HEALTH_CHECK_TIME = "health.check.time"; - - public static final String ROLLSITE_ROUTE_TABLE_KEY = "rollsite.route.table.key"; public static final String ROLLSITE_ROUTE_TABLE_WHITE_LIST = "rollsite.route.table.whitList"; public static final String ROLLSITE_ROUTE_TABLE_PARTY_ID = "rollsite.route.table.party.id"; @@ -299,15 +193,16 @@ public class Dict { public static final String TOPIC = "topic"; - public static final String PROPERTY_DEPLOY_MODE = "deploy.model"; - public static final String PROPERTY_CLUSTER_MANAGER_ADDRESS = "cluster.manager.address"; + public static final String PROPERTY_DEPLOY_MODE_KEY = "deploy.model"; +// public static final String PROPERTY_CLUSTER_MANAGER_ADDRESS = "cluster.manager.address"; - public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP = "eggroll.cluster.manager.ip"; - public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT = "eggroll.cluster.manager.port"; + public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_IP_KEY = "eggroll.cluster.manager.ip"; + public static final String PROPERTY_EGGROLL_CLUSTER_MANANGER_PORT_KEY = "eggroll.cluster.manager.port"; public final static String UNKNOWN = "UNKNOWN"; public final static String PROTOBUF = 
"PROTOBUF"; public final static String SLASH = "/"; + public final static String COMPONENTS_DIR = "components"; public final static String GRPC_PARSE_FROM = "parseFrom"; public final static String AT = "@"; public final static String AND = "&"; @@ -361,6 +256,7 @@ public class Dict { public final static String QUEUE = "queue"; public final static String TOTAL = "total"; public final static String LOCALHOST = "localhost"; + public final static String LOCALHOST2 = "127.0.0.1"; public final static String STORE_TYPE = "storeType"; public final static String STORE_TYPE_SNAKECASE = "store_type"; public final static String NAMESPACE = "namespace"; @@ -371,8 +267,21 @@ public class Dict { public final static String PARTITIONER = "partitioner"; public final static String SERDES = "serdes"; public final static String TRANSFER_BROKER_NAME = "transfer_broker_name"; - public static String PROPERTY_DLEDGER_PEER = "dledger.peer"; - public static String PROPERTY_DLEDGER_SELF = "dledger.self"; + public final static String TRANSFER_QUEUE = "transfer_queue"; + public final static String IS_CYCLE="cycle"; +// public final static String EGGROLL_SEND_TOPIC_PREFIX="EGGROLL_SEND_"; +// public final static String EGGROLL_BACK_TOPIC_PREFIX="EGGROLL_BACK_"; + public final static String STREAM_SEND_TOPIC_PREFIX = "STREAM_SEND_"; + public final static String STREAM_BACK_TOPIC_PREFIX = "STREAM_BACK_"; + public final static String BLOCKING_STUB = "BLOCKING_STUB"; + public final static String PROTOCOL = "protocol"; + public final static String URL="url"; + + public final static String USE_SSL="useSSL"; + public final static String CA_FILE="caFile"; + public final static String CERT_CHAIN_FILE="certChainFile"; + public final static String PRIVATE_KEY_FILE="privateKeyFile"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/constant/EncryptMethod.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/EncryptMethod.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/constant/EncryptMethod.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/EncryptMethod.java index 8300e998a4..b4123f119e 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/EncryptMethod.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/EncryptMethod.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum EncryptMethod { /** diff --git a/java/osx/core/src/main/java/com/osx/core/constant/NegotiationType.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/NegotiationType.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/constant/NegotiationType.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/NegotiationType.java index eba9c8846e..b3ae87325f 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/NegotiationType.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/NegotiationType.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum NegotiationType { TLS,PLAINTEXT diff --git a/java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/PtpHttpHeader.java similarity index 93% rename from java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/PtpHttpHeader.java index 756298010b..33f5ba20a4 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/PtpHttpHeader.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/PtpHttpHeader.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public class PtpHttpHeader { @@ -51,10 +51,15 @@ public class PtpHttpHeader { static public final String TargetComponentName = "x-ptp-target-component-name"; static public final String TargetMethod = "x-ptp-target-method"; + static public final String SourceMethod = "x-ptp-source-method"; + static public final String MessageOffSet = "x-ptp-message-offset"; static public final String InstanceId = "x-ptp-instance-id"; static public final String Timestamp = "x-ptp-timestamp"; + + static public final String MessageFlag = "x-ptp-message-flag"; static public final String ReturnCode = "x-ptp-code"; static public final String ReturnMessage = "x-ptp-message"; + static public final String JobId = "x-ptp-job-id"; } diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Role.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Role.java new file mode 100644 index 0000000000..57ceb57152 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Role.java @@ -0,0 +1,5 @@ +package org.fedai.osx.core.constant; + +public enum Role { + fateflow +} diff --git a/java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StatusCode.java similarity index 86% rename from java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StatusCode.java index a72244d60b..b0a211494f 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/StatusCode.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StatusCode.java @@ -14,7 +14,7 @@ * limitations under the License. 
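PtpHttpHeader, renamed above, also gains several transport headers (x-ptp-source-method, x-ptp-message-offset, x-ptp-message-flag, x-ptp-job-id). A hedged sketch of how a caller might stamp them onto an outbound request follows; the header names are taken from the diff, while the URL, port, and header values are placeholders and the plain java.net.http client is only an illustration, not the osx transport code:

    import java.net.URI;
    import java.net.http.HttpRequest;

    class PtpHeaderDemo {
        public static void main(String[] args) {
            // "/osx/inbound" mirrors the default http.context.path and http.servlet.path in MetaInfo;
            // host, port and header values are made up for this example.
            HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:9370/osx/inbound"))
                    .header("x-ptp-target-method", "PRODUCE_MSG")
                    .header("x-ptp-source-method", "POST")
                    .header("x-ptp-message-offset", "1")
                    .header("x-ptp-message-flag", "MSG")
                    .header("x-ptp-job-id", "demo-job")
                    .header("x-ptp-timestamp", String.valueOf(System.currentTimeMillis()))
                    .POST(HttpRequest.BodyPublishers.ofString("{}"))
                    .build();
            System.out.println(request.headers().map());
        }
    }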
*/ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public class StatusCode { public static final String SUCCESS = "0"; @@ -40,6 +40,12 @@ public class StatusCode { public static final String INVALID_REDIRECT_INFO = "142"; public static final String INVALID_INDEXFILE_DETAIL = "143"; public static final String CREATE_TOPIC_ERROR = "144"; + public static final String CYCLE_ROUTE_ERROR = "145"; + public static final String CONSUME_MSG_TIMEOUT= "146"; + + public static final String SESSION_INIT_ERROR= "147"; + public static final String TRANSFER_QUEUE_REDIRECT = "148"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/constant/StreamLimitMode.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StreamLimitMode.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/constant/StreamLimitMode.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StreamLimitMode.java index 60eae0b587..8b7286592d 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/StreamLimitMode.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/StreamLimitMode.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum StreamLimitMode { //不使用限流 diff --git a/java/osx/core/src/main/java/com/osx/core/constant/TransferStatus.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/TransferStatus.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/constant/TransferStatus.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/TransferStatus.java index e71ea49bd7..55067c7fa9 100644 --- a/java/osx/core/src/main/java/com/osx/core/constant/TransferStatus.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/TransferStatus.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.constant; +package org.fedai.osx.core.constant; public enum TransferStatus { INIT, TRANSFERING, ERROR, FINISH, DESTROY diff --git a/java/osx/core/src/main/java/com/osx/core/context/Context.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/context/FateContext.java similarity index 54% rename from java/osx/core/src/main/java/com/osx/core/context/Context.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/context/FateContext.java index 434761ece3..35b9bcda11 100644 --- a/java/osx/core/src/main/java/com/osx/core/context/Context.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/context/FateContext.java @@ -13,38 +13,83 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.context; +package org.fedai.osx.core.context; + import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListenableFuture; -import com.osx.core.constant.Dict; -import com.osx.core.router.RouterInfo; -import com.osx.core.utils.FlowLogPrinter; -import com.osx.core.utils.FlowLogUtil; import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.fedai.osx.api.constants.Protocol; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.constant.Dict; + import java.util.Map; -public class Context { - static final String LOGGER_NAME = "flow"; - private static final Logger logger = LoggerFactory.getLogger(LOGGER_NAME); +public class FateContext implements Context{ protected long timestamp = System.currentTimeMillis(); protected boolean needAssembleException = false; protected String actionType; protected String sessionId; + protected Protocol protocol; + protected String traceId; + protected String token; + protected String sourceInstId; + protected String desInstId; + protected String techProviderCode; protected boolean needPrintFlowLog = true; + protected boolean needCheckRouterInfo = true; protected Long dataSize; + + public Integer getSleepTime() { + return sleepTime; + } + + public void setSleepTime(Integer sleepTime) { + this.sleepTime = sleepTime; + } + + protected Integer sleepTime; + + public Integer getRetryTime() { + return retryTime; + } + + public void setRetryTime(Integer retryTime) { + this.retryTime = retryTime; + } + + protected Integer retryTime =1; protected Map dataMap = Maps.newHashMap(); long costTime; String resourceName; - Throwable t; - FlowLogPrinter flowLogPrinter = FlowLogUtil::printFlowLog; + String messageFlag; + String messageCode; - public Context(){ + public String getJobId() { + return jobId; } - public Context(long timestamp, Map dataMap){ + + @Override + public void setJobId(String jobId) { + this.jobId = jobId; + } + + String jobId; + + Throwable t; + public FateContext(){ + } + public FateContext(long timestamp, Map dataMap){ timestamp = timestamp; this.dataMap = dataMap; } + + public boolean isDestination(){ + if(StringUtils.isNotEmpty(this.getDesPartyId())) + return MetaInfo.PROPERTY_SELF_PARTY.contains(this.getDesPartyId()); + else + return false; + } public Long getDataSize() { return dataSize; } @@ -56,53 +101,84 @@ public String getTopic() { return dataMap.get(Dict.TOPIC).toString(); return null; } + public String getTechProviderCode() { + return techProviderCode; + } + public void setTechProviderCode(String techProviderCode) { + this.techProviderCode = techProviderCode; + } + public Protocol getProtocol() { + return protocol; + } + public void setProtocol(Protocol protocol) { + this.protocol = protocol; + } + public String getMessageFlag() { + return messageFlag; + } + public String getMessageCode() { + return messageCode; + } + public void setMessageCode(String messageCode) { + this.messageCode = messageCode; + } + public void setMessageFlag(String messageFlag) { + this.messageFlag = messageFlag; + } public void setTopic(String topic) { this.dataMap.put(Dict.TOPIC, topic); } - public String getInstanceId() { return (String) dataMap.get(Dict.INSTANCE_ID); } - public void setInstanceId(String instanceId) { this.dataMap.put(Dict.INSTANCE_ID, instanceId); } - public Throwable getException() { return t; } - public void setException(Throwable t) { this.t = 
t; } - public String getSessionId() { return this.sessionId; } - public void setSessionId(String sessionId) { this.sessionId = sessionId; } - public String getActionType() { return actionType; } - public void setActionType(String actionType) { this.actionType = actionType; } - public Object getData(Object key) { return dataMap.get(key); } - public Object getDataOrDefault(Object key, Object defaultValue) { return dataMap.getOrDefault(key, defaultValue); } - public void putData(Object key, Object data) { dataMap.put(key, data); } + public String getTraceId() { + return traceId; + } + public void setTraceId(String traceId) { + this.traceId = traceId; + } + public String getToken() { + return token; + } + public void setToken(String token) { + this.token = token; + } + public boolean isNeedCheckRouterInfo() { + return needCheckRouterInfo; + } + public void setNeedCheckRouterInfo(boolean needCheckRouterInfo) { + this.needCheckRouterInfo = needCheckRouterInfo; + } public String getCaseId() { if (dataMap.get(Dict.CASEID) != null) { @@ -114,109 +190,89 @@ public String getCaseId() { public void setCaseId(String caseId) { dataMap.put(Dict.CASEID, caseId); } - public long getTimeStamp() { return timestamp; } - public Context subContext() { + public FateContext subContext() { Map newDataMap = Maps.newHashMap(dataMap); - return new Context(this.timestamp, newDataMap); + return new FateContext(this.timestamp, newDataMap); } - public boolean needPrintFlowLog() { return needPrintFlowLog; } - public void setNeedPrintFlowLog(boolean needPrintFlowLog) { this.needPrintFlowLog = needPrintFlowLog; } - public Long getRequestMsgIndex() { return (Long) this.dataMap.get(Dict.REQUEST_INDEX); } - public void setRequestMsgIndex(Long index) { this.dataMap.put(Dict.REQUEST_INDEX, index); } - public Long getCurrentMsgIndex() { return (Long) this.dataMap.get(Dict.CURRENT_INDEX); } - public void setCurrentMsgIndex(Long index) { this.dataMap.put(Dict.CURRENT_INDEX, index); } - public long getCostTime() { return costTime; } - public String getSrcPartyId() { return (String) dataMap.get(Dict.SOURCE_PARTY_ID); } - public void setSrcPartyId(String guestAppId) { dataMap.put(Dict.SOURCE_PARTY_ID, guestAppId); } - public String getDesPartyId() { return (String) dataMap.get(Dict.DES_PARTY_ID); } - public void setDesPartyId(String hostAppid) { dataMap.put(Dict.DES_PARTY_ID, hostAppid); } - public void setSrcComponent(String srcComponent){ dataMap.put(Dict.SOURCE_COMPONENT,srcComponent); } - public String getSrcComponent(){ return (String)dataMap.get(Dict.SOURCE_COMPONENT); } - public void setDesComponent(String desComponent){ dataMap.put(Dict.DES_COMPONENT,desComponent); } - public String getDesComponent(){ return (String)dataMap.get(Dict.DES_COMPONENT); } - public RouterInfo getRouterInfo() { return (RouterInfo) dataMap.get(Dict.ROUTER_INFO); } - public void setRouterInfo(RouterInfo routerInfo) { dataMap.put(Dict.ROUTER_INFO, routerInfo); } - public Object getResultData() { return dataMap.get(Dict.RESULT_DATA); } - public void setResultData(Object resultData) { dataMap.put(Dict.RESULT_DATA, resultData); } - public String getReturnCode() { return (String) dataMap.get(Dict.RETURN_CODE); } - public void setReturnCode(String returnCode) { dataMap.put(Dict.RETURN_CODE, returnCode); } - - public String getReturnMsg() { return (String) dataMap.get(Dict.RET_MSG); } - public void setReturnMsg(String returnMsg) { dataMap.put(Dict.RET_MSG, returnMsg); } - + public String getSelfPartyId(){ + return (String) 
dataMap.get(Dict.PROPERTY_SELF_PARTY_KEY); + } + public void setSelfPartyId(String partyId){ + dataMap.put(Dict.PROPERTY_SELF_PARTY_KEY,partyId); + } public long getDownstreamCost() { if (dataMap.get(Dict.DOWN_STREAM_COST) != null) { @@ -229,43 +285,33 @@ public long getDownstreamCost() { public void setDownstreamCost(long downstreamCost) { dataMap.put(Dict.DOWN_STREAM_COST, downstreamCost); } - public long getDownstreamBegin() { return dataMap.get(Dict.DOWN_STREAM_BEGIN) != null ? (long) dataMap.get(Dict.DOWN_STREAM_BEGIN) : 0; } - public void setDownstreamBegin(long downstreamBegin) { dataMap.put(Dict.DOWN_STREAM_BEGIN, downstreamBegin); } - public String getSourceIp() { return (String) dataMap.get(Dict.SOURCE_IP); } - public void setSourceIp(String sourceIp) { dataMap.put(Dict.SOURCE_IP, sourceIp); } - public String getServiceName() { return (String) dataMap.get(Dict.SERVICE_NAME); } - public void setServiceName(String serviceName) { dataMap.put(Dict.SERVICE_NAME, serviceName); } - public String getCallName() { return (String) dataMap.get(Dict.CALL_NAME); } - public void setCallName(String callName) { dataMap.put(Dict.CALL_NAME, callName); } - public void setRemoteFuture(ListenableFuture future) { this.dataMap.put(Dict.FUTURE, future); } - public String getResourceName() { if (StringUtils.isNotEmpty(resourceName)) { return resourceName; @@ -274,22 +320,75 @@ public String getResourceName() { } return resourceName; } - public boolean needAssembleException() { return needAssembleException; } + public String toString(){ + StringBuffer stringBuffer = new StringBuffer(); + if (this.getProtocol() != null) { + stringBuffer.append(this.getProtocol()).append(SPLIT); + } + if (this.getActionType() != null) { + stringBuffer.append(this.getActionType()).append(SPLIT); + } +// if(context.getSessionId()!=null){ +// stringBuffer.append("session:").append(context.getSessionId()).append(SPLIT); +// } + if (this.getTopic() != null) { + stringBuffer.append("topic:").append(this.getTopic()).append(SPLIT); + } - public FlowLogPrinter getFlowLogPrinter() { - return flowLogPrinter; - } - - public Context setFlowLogPrinter(FlowLogPrinter flowLogPrinter) { - this.flowLogPrinter = flowLogPrinter; - return this; - } - public void printFlowLog() { - if (needPrintFlowLog) { - flowLogPrinter.print(this); + if (this.getMessageFlag() != null) { + stringBuffer.append(this.getMessageFlag()).append(SPLIT); + } + if (this.getRequestMsgIndex() != null) { + stringBuffer.append("req-offset:").append(this.getRequestMsgIndex()).append(SPLIT); + } + if (this.getData(Dict.CURRENT_INDEX) != null) { + stringBuffer.append("offset-in-queue:").append(this.getData(Dict.CURRENT_INDEX)).append(SPLIT); } + if(StringUtils.isNotEmpty(this.messageCode)){ + stringBuffer.append("msg-code:").append(this.getMessageCode()).append(SPLIT); + } + if(this.jobId!=null){ + stringBuffer.append("job-id:").append(this.getJobId()).append(SPLIT); + } + if (this.getSrcPartyId() != null) { + stringBuffer.append("src:").append(this.getSrcPartyId()).append(SPLIT); + } + if (this.getDesPartyId() != null) { + stringBuffer.append("des:").append(this.getDesPartyId()).append(SPLIT); + } + if (this.getReturnCode() != null) { + stringBuffer.append("code:").append(this.getReturnCode()).append(SPLIT); + } + stringBuffer.append("cost:").append(System.currentTimeMillis() - this.getTimeStamp()).append(SPLIT); + if (this.getRouterInfo() != null) { + Protocol protocol = this.getRouterInfo().getProtocol(); + if (protocol != null) { + if (protocol.equals(Protocol.grpc)) { + 
stringBuffer.append(this.getRouterInfo().getHost() + ":" + this.getRouterInfo().getPort()).append(SPLIT); + } else if (protocol.equals(Protocol.http)) { + stringBuffer.append(this.getRouterInfo().getUrl()).append(SPLIT); + } + } + } + if (this.getDataSize() != null) { + stringBuffer.append("size:").append(this.getDataSize()).append(SPLIT); + } + if(this.getSleepTime()!=null&&this.getSleepTime()>0){ + stringBuffer.append("sleep:").append(this.getSleepTime()).append(SPLIT); + } + if(this.retryTime>1){ + stringBuffer.append("retry:").append(this.retryTime).append(SPLIT); + } + if (this.getReturnMsg() != null) { + stringBuffer.append("msg:").append(this.getReturnMsg()); + } + + + return stringBuffer.toString(); } + static final String SPLIT= "|"; + } diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/AbstractDataSource.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AbstractDataSource.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/datasource/AbstractDataSource.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AbstractDataSource.java index 74b2edaf50..beb3773423 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/AbstractDataSource.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AbstractDataSource.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.datasource; +package org.fedai.osx.core.datasource; -import com.osx.core.flow.DynamicProperty; -import com.osx.core.flow.Property; +import org.fedai.osx.core.flow.DynamicProperty; +import org.fedai.osx.core.flow.Property; public abstract class AbstractDataSource implements ReadableDataSource { diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AutoRefreshDataSource.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AutoRefreshDataSource.java index 8d6a14b8f5..67e1735da5 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/AutoRefreshDataSource.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/AutoRefreshDataSource.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.datasource; +package org.fedai.osx.core.datasource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/Converter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/Converter.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/datasource/Converter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/Converter.java index a82d2a785a..3c75a5a211 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/Converter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/Converter.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
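The new FateContext.toString() above assembles the per-request flow-log line as pipe-separated segments (protocol, action type, topic, offsets, src/des party, return code, cost, route target, size, retries, message). A stripped-down sketch of that formatting style, with made-up field values purely for illustration:

    class FlowLogDemo {
        static final String SPLIT = "|";   // same separator as FateContext

        public static void main(String[] args) {
            long start = System.currentTimeMillis() - 42;   // pretend the call took ~42 ms
            StringBuilder line = new StringBuilder();
            line.append("grpc").append(SPLIT)                      // protocol
                .append("unary-call").append(SPLIT)                // action type (see ActionType)
                .append("topic:").append("demo_topic").append(SPLIT)
                .append("src:").append("9999").append(SPLIT)
                .append("des:").append("10000").append(SPLIT)
                .append("code:").append("0").append(SPLIT)
                .append("cost:").append(System.currentTimeMillis() - start);
            System.out.println(line);
            // e.g. grpc|unary-call|topic:demo_topic|src:9999|des:10000|code:0|cost:42
        }
    }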
*/ -package com.osx.core.datasource; +package org.fedai.osx.core.datasource; public interface Converter { diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/FileRefreshableDataSource.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/FileRefreshableDataSource.java index 70b8f8fa41..c89abcb3dd 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/FileRefreshableDataSource.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/FileRefreshableDataSource.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.datasource; +package org.fedai.osx.core.datasource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/NamedThreadFactory.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/NamedThreadFactory.java index 9b9693709f..4179117d23 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/NamedThreadFactory.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/NamedThreadFactory.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.datasource; +package org.fedai.osx.core.datasource; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; diff --git a/java/osx/core/src/main/java/com/osx/core/datasource/ReadableDataSource.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/ReadableDataSource.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/datasource/ReadableDataSource.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/ReadableDataSource.java index 10f661f6cb..5845831b2b 100644 --- a/java/osx/core/src/main/java/com/osx/core/datasource/ReadableDataSource.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/datasource/ReadableDataSource.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.datasource; -import com.osx.core.flow.Property; +package org.fedai.osx.core.datasource; +import org.fedai.osx.core.flow.Property; public interface ReadableDataSource { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/AckIndexException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/AckIndexException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/AckIndexException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/AckIndexException.java index 672f3f8da4..5b71dd92d5 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/AckIndexException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/AckIndexException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class AckIndexException extends BaseException { public AckIndexException() { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/BaseException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/BaseException.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/exceptions/BaseException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/BaseException.java index 3ffe287d02..710c4ea2a3 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/BaseException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/BaseException.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; public class BaseException extends RuntimeException { protected String retcode; diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConfigErrorException.java similarity index 64% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConfigErrorException.java index db17d78a51..82160106d6 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ConfigErrorException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConfigErrorException.java @@ -1,6 +1,6 @@ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class ConfigErrorException extends BaseException{ diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumeNoMessageException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumeNoMessageException.java index 293a41a11d..e236fc0425 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumeNoMessageException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumeNoMessageException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class ConsumeNoMessageException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumerNotExistException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumerNotExistException.java index 90cf882d93..480208423b 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ConsumerNotExistException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ConsumerNotExistException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class ConsumerNotExistException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CreateTopicErrorException.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CreateTopicErrorException.java index 9ea27cc57b..803d4046b9 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/CreateTopicErrorException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CreateTopicErrorException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class CreateTopicErrorException extends BaseException{ public CreateTopicErrorException(String msg){ diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CycleRouteInfoException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CycleRouteInfoException.java new file mode 100644 index 0000000000..a8a4484ac4 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/CycleRouteInfoException.java @@ -0,0 +1,9 @@ +package org.fedai.osx.core.exceptions; + +import org.fedai.osx.core.constant.StatusCode; + +public class CycleRouteInfoException extends BaseException{ + public CycleRouteInfoException(String msg){ + super(StatusCode.CYCLE_ROUTE_ERROR,msg); + } +} diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ErrorMessageUtil.java similarity index 62% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ErrorMessageUtil.java index dcaed77a25..1b03ccb2e7 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ErrorMessageUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ErrorMessageUtil.java @@ -14,13 +14,15 @@ * limitations under the License. 
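The new CycleRouteInfoException above follows the same convention as the other osx exceptions: a BaseException subclass bound to a StatusCode string (CYCLE_ROUTE_ERROR = "145" in this patch). A compact sketch of that convention, using stand-in class names rather than the real osx types:

    // Stand-ins for org.fedai.osx.core.exceptions.BaseException / StatusCode, for illustration only.
    class DemoBaseException extends RuntimeException {
        private final String retcode;
        DemoBaseException(String retcode, String message) {
            super(message);
            this.retcode = retcode;
        }
        String getRetcode() { return retcode; }
    }

    class DemoCycleRouteInfoException extends DemoBaseException {
        static final String CYCLE_ROUTE_ERROR = "145";   // value from StatusCode in the diff
        DemoCycleRouteInfoException(String msg) {
            super(CYCLE_ROUTE_ERROR, msg);
        }
    }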
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +37,8 @@ public class ErrorMessageUtil { static Logger logger = LoggerFactory.getLogger(ErrorMessageUtil.class); + static final String MESSAGE_PREFIX = "PARTY_"; + public static String buildRemoteRpcErrorMsg(int code, String msg) { return new StringBuilder().append("host return code ").append(code) .append(" host msg :").append(msg).toString(); @@ -64,21 +68,46 @@ public static StatusRuntimeException throwableToException(Context context, Throw return status.asRuntimeException(); } + public static StatusRuntimeException toGrpcRuntimeException(Throwable throwable) { + StatusRuntimeException result = null; + + if (throwable instanceof StatusRuntimeException) { + result = (StatusRuntimeException) throwable; + } else { + result = Status.INTERNAL + .withCause(throwable) + // .withDescription(throwable.getMessage()) + .withDescription(throwable.getMessage()+ ": " + ExceptionUtils.getStackTrace(throwable)) + .asRuntimeException(); + } + + return result; + } public static ExceptionInfo handleExceptionExceptionInfo(Context context, Throwable e) { ExceptionInfo exceptionInfo = new ExceptionInfo(); + String selfPartyId = context.getSelfPartyId(); + String oriMessage = e.getMessage(); + String message = ""; + if(StringUtils.isNotEmpty(selfPartyId)){ + message = MESSAGE_PREFIX+selfPartyId+":"+oriMessage; + }else{ + message = oriMessage; + } if (e instanceof BaseException) { BaseException baseException = (BaseException) e; exceptionInfo.setCode(baseException.getRetcode()); - exceptionInfo.setMessage(baseException.getMessage()); } else { + logger.error("SYSTEM_ERROR ==> " ,e); exceptionInfo.setCode(StatusCode.SYSTEM_ERROR); - exceptionInfo.setMessage(e.getMessage()); } + exceptionInfo.setMessage(message); exceptionInfo.setThrowable(e); - if (context.needAssembleException()) { - exceptionInfo.setThrowable(throwableToException(context, e)); - } +// if (context.needAssembleException()) { +// exceptionInfo.setThrowable(throwableToException(context, e)); +// } + + return exceptionInfo; } @@ -95,35 +124,6 @@ public static Map handleExceptionToMap(Throwable e) { } public static Map handleException(Map result, Throwable e) { -// if (e instanceof IllegalArgumentException) { -// result.put(Dict.CODE, StatusCode.PARAM_ERROR); -// result.put(Dict.MESSAGE, "PARAM_ERROR"); -// } else if (e instanceof NoRouterInfoException) { -// result.put(Dict.CODE, StatusCode.GUEST_ROUTER_ERROR); -// result.put(Dict.MESSAGE, "ROUTER_ERROR"); -// } else if (e instanceof SysException) { -// result.put(Dict.CODE, StatusCode.SYSTEM_ERROR); -// result.put(Dict.MESSAGE, "SYSTEM_ERROR"); -// } else if (e instanceof OverLoadException) { -// result.put(Dict.CODE, StatusCode.OVER_LOAD_ERROR); -// result.put(Dict.MESSAGE, "OVER_LOAD"); -// } else if (e instanceof InvalidRoleInfoException) { -// result.put(Dict.CODE, StatusCode.INVALID_ROLE_ERROR); -// result.put(Dict.MESSAGE, "ROLE_ERROR"); -// } else if (e instanceof ShowDownRejectException) { -// result.put(Dict.CODE, 
StatusCode.SHUTDOWN_ERROR); -// result.put(Dict.MESSAGE, "SHUTDOWN_ERROR"); -// -// } else if (e instanceof NoResultException) { -// logger.error("NET_ERROR ", e); -// result.put(Dict.CODE, StatusCode.NET_ERROR); -// result.put(Dict.MESSAGE, "NET_ERROR"); -// } else { -// logger.error("SYSTEM_ERROR ", e); -// result.put(Dict.CODE, StatusCode.SYSTEM_ERROR); -// result.put(Dict.MESSAGE, "SYSTEM_ERROR"); -// } - return result; } } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ExceptionInfo.java similarity index 92% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ExceptionInfo.java index 0d3187f6e2..f2a6586740 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ExceptionInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ExceptionInfo.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; import com.google.common.collect.Maps; -import com.osx.core.constant.Dict; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.utils.JsonUtil; import java.util.Map; diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRedirectInfoException.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRedirectInfoException.java index f0f4e8edfb..e9a083fbf5 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/InvalidRedirectInfoException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRedirectInfoException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
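handleExceptionExceptionInfo, shown above, now prefixes the outgoing error message with the local party id ("PARTY_<selfPartyId>:<original message>") so the failing side is identifiable across a federated exchange. A minimal sketch of that prefixing rule; the class and method names here are illustrative, only the prefix and the fallback behaviour come from the diff:

    class ErrorPrefixDemo {
        static final String MESSAGE_PREFIX = "PARTY_";   // same prefix as ErrorMessageUtil

        static String prefixMessage(String selfPartyId, String originalMessage) {
            if (selfPartyId == null || selfPartyId.isEmpty()) {
                return originalMessage;                   // no local party id: message passes through unchanged
            }
            return MESSAGE_PREFIX + selfPartyId + ":" + originalMessage;
        }

        public static void main(String[] args) {
            System.out.println(prefixMessage("9999", "consume message timeout"));
            // PARTY_9999:consume message timeout
        }
    }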
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class InvalidRedirectInfoException extends BaseException { public InvalidRedirectInfoException() { diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRouteInfoException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRouteInfoException.java new file mode 100644 index 0000000000..d4c8060704 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/InvalidRouteInfoException.java @@ -0,0 +1,6 @@ +package org.fedai.osx.core.exceptions; + +public class InvalidRouteInfoException extends BaseException{ + + +} diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/MappedFileException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MappedFileException.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/exceptions/MappedFileException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MappedFileException.java index 4b25c25ff6..db96b66ac7 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/MappedFileException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MappedFileException.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; public class MappedFileException extends BaseException { } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/MessageParseException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MessageParseException.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/exceptions/MessageParseException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MessageParseException.java index b8dc2a17cf..196337d33f 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/MessageParseException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/MessageParseException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class MessageParseException extends BaseException { public MessageParseException(String msg) { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/NoRouterInfoException.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/NoRouterInfoException.java index 633b5fd9d0..9f956dbc54 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/NoRouterInfoException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/NoRouterInfoException.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +package org.fedai.osx.core.exceptions; +import org.fedai.osx.core.constant.StatusCode; public class NoRouterInfoException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ParameterException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ParameterException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ParameterException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ParameterException.java index 5434cf52fb..3ef03b93e8 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ParameterException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ParameterException.java @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +package org.fedai.osx.core.exceptions; +import org.fedai.osx.core.constant.StatusCode; public class ParameterException extends BaseException { public ParameterException(String retCode, String message) { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ProduceMsgExcption.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ProduceMsgExcption.java index 1c9825fe17..5e321f3da1 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/ProduceMsgExcption.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/ProduceMsgExcption.java @@ -13,6 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; public class ProduceMsgExcption extends BaseException { } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/PutMessageException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/PutMessageException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/PutMessageException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/PutMessageException.java index 3031a0e8fd..f7c8c78e19 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/PutMessageException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/PutMessageException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class PutMessageException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RemoteRpcException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RemoteRpcException.java index a15c8a2c12..77342aa3a5 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/RemoteRpcException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RemoteRpcException.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class RemoteRpcException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RouterInfoOperateException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RouterInfoOperateException.java index 37d6f51acc..df3ec2c6a5 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/RouterInfoOperateException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/RouterInfoOperateException.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class RouterInfoOperateException extends BaseException { diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SessionInitException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SessionInitException.java new file mode 100644 index 0000000000..0a0d7088a4 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SessionInitException.java @@ -0,0 +1,12 @@ +package org.fedai.osx.core.exceptions; + +import org.fedai.osx.core.constant.StatusCode; + +public class SessionInitException extends BaseException{ + public SessionInitException(String retCode, String message) { + super(retCode, message); + } + public SessionInitException(String message) { + super(StatusCode.SESSION_INIT_ERROR, message); + } +} diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/SysException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SysException.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/exceptions/SysException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SysException.java index a1b738fed8..8fab9f8038 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/SysException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/SysException.java @@ -14,10 +14,10 @@ * limitations under the License. 
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class SysException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueAlreadyExistException.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueAlreadyExistException.java index 69429713ef..19856f9b67 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueAlreadyExistException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueAlreadyExistException.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; public class TransferQueueAlreadyExistException extends BaseException { } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueInvalidStatusException.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueInvalidStatusException.java index 7ba7d5a614..3e6899be0c 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueInvalidStatusException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueInvalidStatusException.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class TransferQueueInvalidStatusException extends BaseException { diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueNotExistException.java similarity index 80% rename from java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueNotExistException.java index 2956292fbb..890e6bdd20 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/TransferQueueNotExistException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/TransferQueueNotExistException.java @@ -13,16 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; -import com.osx.core.constant.StatusCode; +import org.fedai.osx.core.constant.StatusCode; public class TransferQueueNotExistException extends BaseException { public TransferQueueNotExistException() { super(StatusCode.TRANSFER_QUEUE_NOT_FIND, "TRANSFER_QUEUE_NOT_FIND"); } - public TransferQueueNotExistException(String code, String msg) { - super(code, msg); + public TransferQueueNotExistException( String msg) { + super(StatusCode.TRANSFER_QUEUE_NOT_FIND, msg); } } diff --git a/java/osx/core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/UnSupportMethodException.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/UnSupportMethodException.java index 7c4e4d598a..51bb921ac2 100644 --- a/java/osx/core/src/main/java/com/osx/core/exceptions/UnSupportMethodException.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/exceptions/UnSupportMethodException.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.exceptions; +package org.fedai.osx.core.exceptions; public class UnSupportMethodException extends RuntimeException { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/AbstractRule.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/AbstractRule.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/AbstractRule.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/AbstractRule.java index 0f953fc323..67ff636131 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/AbstractRule.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/AbstractRule.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public abstract class AbstractRule implements Rule { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/BucketLeapArray.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/BucketLeapArray.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/BucketLeapArray.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/BucketLeapArray.java index 52f12e4a35..b16a7ecc24 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/BucketLeapArray.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/BucketLeapArray.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowChecker.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowChecker.java index f7373a88b6..095c9fe1c8 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowChecker.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowChecker.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.token.TokenResult; -import com.osx.core.token.TokenResultStatus; +import org.fedai.osx.core.token.TokenResult; +import org.fedai.osx.core.token.TokenResultStatus; final public class ClusterFlowChecker { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowConfig.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowConfig.java index aafc2952fa..62446ac9f8 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowConfig.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowConfig.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.Objects; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowEvent.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowEvent.java index 63d5a12106..f1e1184e01 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowEvent.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowEvent.java @@ -1,4 +1,4 @@ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public enum ClusterFlowEvent { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowRuleManager.java similarity index 89% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowRuleManager.java index 9e202f33fc..2c36ee4aa6 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterFlowRuleManager.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterFlowRuleManager.java @@ -13,25 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; + import com.fasterxml.jackson.core.type.TypeReference; -import com.osx.core.config.MetaInfo; -import com.osx.core.datasource.FileRefreshableDataSource; -import com.osx.core.utils.AssertUtil; -import com.osx.core.utils.JsonUtil; import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.datasource.FileRefreshableDataSource; +import org.fedai.osx.core.utils.AssertUtil; +import org.fedai.osx.core.utils.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.FileNotFoundException; -import java.net.URL; import java.util.*; import java.util.concurrent.ConcurrentHashMap; -import static com.osx.core.config.MetaInfo.PROPERTY_INTERVAL_MS; -import static com.osx.core.config.MetaInfo.PROPERTY_SAMPLE_COUNT; - public final class ClusterFlowRuleManager { public static final Function>> DEFAULT_PROPERTY_SUPPLIER = @@ -70,34 +67,23 @@ private static void initDefaultProperty() { PropertyListener> listener = new FlowRulePropertyListener(defaultNamespace); registerPropertyInternal(defaultNamespace, defaultProperty, listener); String currentPath = null; - if (MetaInfo.PROPERTY_FLOW_RULE_TABLE != null) { - currentPath = MetaInfo.PROPERTY_FLOW_RULE_TABLE; - } else { - URL url = Thread.currentThread().getContextClassLoader().getResource("flowRule.json"); - - if (url != null) { - currentPath = url.getPath(); - } else { - logger.error("file flowRule.json not found"); - } - } + //check the local development config path first; if it is not a local setup, fall back to the server-side config directory + currentPath = MetaInfo.PROPERTY_CONFIG_DIR + "/" + MetaInfo.PROPERTY_FLOW_RULE_TABLE; logger.info("load flow rule {}", currentPath); - if (currentPath != null) { - File confFile = new File(currentPath); - FileRefreshableDataSource fileRefreshableDataSource = null; - try { - fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { - - List content = JsonUtil.json2List((String) source, new TypeReference>() { - }); - logger.info("load flow rule content {}", content); - return content; + File confFile = new File(currentPath); + FileRefreshableDataSource fileRefreshableDataSource; + try { + fileRefreshableDataSource = new FileRefreshableDataSource(confFile, (source) -> { + + List content = JsonUtil.json2List((String) source, new TypeReference>() { }); - fileRefreshableDataSource.getProperty().addListener(listener); - } catch (FileNotFoundException e) { - e.printStackTrace(); - logger.error("flow rule file not exist"); - } + logger.info("load flow rule content {}", content); + return content; + }); + fileRefreshableDataSource.getProperty().addListener(listener); + } catch (FileNotFoundException e) { + e.printStackTrace(); + logger.error("flow rule file not exist"); } } @@ -177,6 +163,7 @@ private static void registerPropertyInternal(/*@NonNull*/ String namespace, /*@V resetNamespaceFlowIdMapFor(namespace); } } + public static void removeProperty(String namespace) { AssertUtil.notEmpty(namespace, "namespace cannot be empty"); synchronized (UPDATE_LOCK) { @@ -338,7 +325,6 @@ private static void applyClusterFlowRule(List list, /*@Valid*/ String Set flowIdSet = new HashSet<>(); for (FlowRule rule : list) { - System.err.println("===================" + rule); if (!rule.isClusterMode()) { continue; } @@ -368,7 +354,7 @@ private static void applyClusterFlowRule(List list, /*@Valid*/ String // Prepare cluster metric from valid flow ID.
ClusterMetricStatistics.putMetricIfAbsent(rule.getResource(), - new ClusterMetric(PROPERTY_SAMPLE_COUNT, PROPERTY_INTERVAL_MS)); + new ClusterMetric(MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_COUNT, MetaInfo.PROPERTY_FLOW_CONTROL_SAMPLE_INTERVAL)); } // Cleanup unused cluster metrics. diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetric.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetric.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetric.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetric.java index 6b26220562..c8f8b033e2 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetric.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetric.java @@ -13,8 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; -import com.osx.core.utils.AssertUtil; +package org.fedai.osx.core.flow; +import org.fedai.osx.core.utils.AssertUtil; + import java.util.List; import java.util.concurrent.atomic.AtomicLong; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricBucket.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricBucket.java index e658960237..78d61cadbe 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricBucket.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricBucket.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.concurrent.atomic.LongAdder; public class ClusterMetricBucket { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricLeapArray.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricLeapArray.java index e4527893a7..6ab8da0305 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricLeapArray.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricLeapArray.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.concurrent.atomic.LongAdder; public class ClusterMetricLeapArray extends LeapArray { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricStatistics.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricStatistics.java index 44e2c78e53..2308649dad 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterMetricStatistics.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterMetricStatistics.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.utils.AssertUtil; +import org.fedai.osx.core.utils.AssertUtil; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleConstant.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleConstant.java index adb1a5dc6c..fc5a89c301 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleConstant.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleConstant.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public final class ClusterRuleConstant { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleUtil.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleUtil.java index 03836f4e73..c9a98b9fdc 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/ClusterRuleUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/ClusterRuleUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public final class ClusterRuleUtil { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/CurrentConcurrencyManager.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/CurrentConcurrencyManager.java index 8615529b9a..c0a0b7197d 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/CurrentConcurrencyManager.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/CurrentConcurrencyManager.java @@ -1,7 +1,7 @@ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.datasource.NamedThreadFactory; +import org.fedai.osx.core.datasource.NamedThreadFactory; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/DebugSupport.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DebugSupport.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/DebugSupport.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DebugSupport.java index 434ec01fac..2ea77c9393 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/DebugSupport.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DebugSupport.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface DebugSupport { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/DynamicProperty.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DynamicProperty.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/DynamicProperty.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DynamicProperty.java index 4a5afd1f52..14fd0d1169 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/DynamicProperty.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/DynamicProperty.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FileMetricReport.java similarity index 92% rename from java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FileMetricReport.java index 37da21a12c..af98cbd44f 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/FileMetricReport.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FileMetricReport.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,10 +40,10 @@ public void rmAllFile() throws Exception { @Override public void report(List data) { try { - // logger.info("report {}",data); + // logger.info("report {}",data); metricWriter.write(TimeUtil.currentTimeMillis(), data); } catch (Exception e) { - e.printStackTrace(); + // e.printStackTrace(); } } } diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounter.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowCounter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounter.java index 18adf4ef0e..122f1061d0 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.List; import java.util.concurrent.atomic.LongAdder; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounterManager.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounterManager.java similarity index 76% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowCounterManager.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounterManager.java index 5a5d7b1b52..ea363d8765 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/FlowCounterManager.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowCounterManager.java @@ -14,13 +14,13 @@ * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.osx.core.utils.GetSystemInfo; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.GetSystemInfo; +import org.fedai.osx.core.utils.JsonUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,41 +64,9 @@ public FlowCounterManager(String appName, Boolean countModelRequest) { } metricSearcher = new MetricSearcher(MetricWriter.METRIC_BASE_DIR, baseFileName); metricReport = new FileMetricReport(appName); -// if (countModelRequest) { -// modelMetricReport = new FileMetricReport("model"); -// // modelMetricSearcher = new MetricSearcher(MetricWriter.METRIC_BASE_DIR, modelFileName); -// } } -// public static void main(String[] args) throws IOException { -// MetaInfo.PROPERTY_ROOT_PATH = new File("").getCanonicalPath(); -// -// FlowCounterManager flowCounterManager = new FlowCounterManager("test"); -// flowCounterManager.setMetricReport(new FileMetricReport("Test")); -// flowCounterManager.setMetricSearcher(new MetricSearcher(MetricWriter.METRIC_BASE_DIR, "Test" + "-metrics.log.pid" + GetSystemInfo.getPid())); -// flowCounterManager.startReport(); -// flowCounterManager.file = new File(MetaInfo.PROPERTY_ROOT_PATH + File.separator + ".fate" + File.separator + "flowRules.json"); -// flowCounterManager.initialize(); -// -// int i = 0; -// while (true) { -// flowCounterManager.setAllowQps("source-" + i, i); -// i++; -// try { -// Thread.sleep(5000); -// } catch (InterruptedException e) { -// e.printStackTrace(); -// } -// } -// /*while (true) { -// flowCounterManager.pass("M_test"); -// try { -// Thread.sleep(100); -// } catch (InterruptedException e) { -// e.printStackTrace(); -// } -// }*/ -// } + public MetricSearcher getMetricSearcher() { return metricSearcher; @@ -247,38 +215,38 @@ public void rmAllFiles() { } } - /** - * init rules - */ - private void initialize() throws IOException { - file = new File(DEFAULT_CONFIG_FILE); - logger.info("try to load flow counter rules, {}", file.getAbsolutePath()); - - if (file.exists()) { - String result = ""; - try ( - FileInputStream fileInputStream = new FileInputStream(file); - InputStreamReader inputStreamReader = new InputStreamReader(fileInputStream, "UTF-8"); - BufferedReader reader = new BufferedReader(inputStreamReader) - ) { - String tempString; - while ((tempString = reader.readLine()) != null) { - result += tempString; - } - - List list = JsonUtil.json2Object(result, new TypeReference>() { - }); - if (list != null) { - list.forEach(map -> { - sourceQpsAllowMap.put((String) map.get("source"), Double.valueOf(String.valueOf(map.get("allow_qps")))); - }); - } - } catch (IOException e) { - logger.error("load flow counter rules failed, use default setting, cause by: {}", e.getMessage()); - } - logger.info("load flow counter rules success"); - } - } +// /** +// * init rules +// */ +// private void initialize() throws IOException { +// file = new File(DEFAULT_CONFIG_FILE); +// logger.info("try to load flow counter rules, {}", file.getAbsolutePath()); +// +// if (file.exists()) { +// String result = ""; +// try ( +// FileInputStream fileInputStream = new FileInputStream(file); +// InputStreamReader inputStreamReader = new InputStreamReader(fileInputStream, "UTF-8"); +// BufferedReader reader = new BufferedReader(inputStreamReader) +// ) { +// String tempString; +// while ((tempString = reader.readLine()) != null) { +// 
result += tempString; +// } +// +// List list = JsonUtil.json2Object(result, new TypeReference>() { +// }); +// if (list != null) { +// list.forEach(map -> { +// sourceQpsAllowMap.put((String) map.get("source"), Double.valueOf(String.valueOf(map.get("allow_qps")))); +// }); +// } +// } catch (IOException e) { +// logger.error("load flow counter rules failed, use default setting, cause by: {}", e.getMessage()); +// } +// logger.info("load flow counter rules success"); +// } +// } private void store(File file, byte[] data) { try { @@ -332,8 +300,4 @@ public void destroy() { } } -// @Override -// public void onApplicationEvent(ApplicationReadyEvent applicationReadyEvent) { -// startReport(); -// } } diff --git a/java/osx/core/src/main/java/com/osx/core/flow/FlowRule.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowRule.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/FlowRule.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowRule.java index 9bc1b30c14..0baf2a6d3a 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/FlowRule.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/FlowRule.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import com.google.common.collect.Lists; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; import java.util.List; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Function.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Function.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/Function.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Function.java index aebf9d7cd6..f8b9d4fd4e 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/Function.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Function.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface Function { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/GlobalRequestLimiter.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/GlobalRequestLimiter.java index 4c35323359..6080a5810e 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/GlobalRequestLimiter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/GlobalRequestLimiter.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.utils.AssertUtil; +import org.fedai.osx.core.utils.AssertUtil; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/LeapArray.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LeapArray.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/LeapArray.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LeapArray.java index bbc5e946b9..4ec6413b0e 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/LeapArray.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LeapArray.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.ArrayList; import java.util.List; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/LimitQueue.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LimitQueue.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/LimitQueue.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LimitQueue.java index 2cb4082941..945ce3515d 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/LimitQueue.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/LimitQueue.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.Collection; import java.util.Iterator; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Metric.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Metric.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/Metric.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Metric.java index 582e662ab6..faa1f87f54 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/Metric.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Metric.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricBucket.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricBucket.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricBucket.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricBucket.java index 18b250d8d0..fcdf69ea2b 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricBucket.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricBucket.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.concurrent.atomic.LongAdder; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricEvent.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricEvent.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricEvent.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricEvent.java index 537abd12db..63dde9805d 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricEvent.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricEvent.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public enum MetricEvent { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricNode.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricNode.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricNode.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricNode.java index 14c67fdfca..384e43fc2d 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricNode.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricNode.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.text.DateFormat; import java.text.SimpleDateFormat; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricReport.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricReport.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricReport.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricReport.java index fdaf4bf541..65112d57f8 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricReport.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricReport.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.List; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricSearcher.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricSearcher.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricSearcher.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricSearcher.java index fec65aeef1..229721eff9 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricSearcher.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricSearcher.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.io.DataInputStream; import java.io.EOFException; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricWriter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricWriter.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricWriter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricWriter.java index 721e2ee844..2edc80bd15 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricWriter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricWriter.java @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.config.MetaInfo; -import com.osx.core.utils.GetSystemInfo; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.utils.GetSystemInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/MetricsReader.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricsReader.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/flow/MetricsReader.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricsReader.java index fbe9d97f80..d0c24b4e7e 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/MetricsReader.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/MetricsReader.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.io.BufferedReader; import java.io.FileInputStream; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/NamespaceFlowProperty.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/NamespaceFlowProperty.java index 7a8eb5409b..3d82ba5873 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/NamespaceFlowProperty.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/NamespaceFlowProperty.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.List; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/OccupySupport.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/OccupySupport.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/OccupySupport.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/OccupySupport.java index 991f700582..5c7e17fe23 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/OccupySupport.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/OccupySupport.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface OccupySupport { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Property.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Property.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/Property.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Property.java index c3c2421f6a..19220c793b 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/Property.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Property.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface Property { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/PropertyListener.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/PropertyListener.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/PropertyListener.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/PropertyListener.java index 57afcb9e48..2a114b5bd3 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/PropertyListener.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/PropertyListener.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface PropertyListener { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/RequestLimiter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RequestLimiter.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/flow/RequestLimiter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RequestLimiter.java index e7785b01f6..aa3c1a0eae 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/RequestLimiter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RequestLimiter.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.utils.AssertUtil; +import org.fedai.osx.core.utils.AssertUtil; import java.util.List; import java.util.concurrent.atomic.LongAdder; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/Rule.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Rule.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/flow/Rule.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Rule.java index 7b8b2c70f7..286b406502 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/Rule.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/Rule.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public interface Rule { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/RuleConstant.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RuleConstant.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/RuleConstant.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RuleConstant.java index 712e6b4226..b1089043b1 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/RuleConstant.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/RuleConstant.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public final class RuleConstant { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/TimeUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TimeUtil.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/TimeUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TimeUtil.java index ead5e4b6c5..d1d22ceee6 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/TimeUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TimeUtil.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.concurrent.TimeUnit; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/TokenService.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TokenService.java similarity index 91% rename from java/osx/core/src/main/java/com/osx/core/flow/TokenService.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TokenService.java index 72d0da0b6d..8983fac2e0 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/TokenService.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/TokenService.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; -import com.osx.core.token.TokenResult; +import org.fedai.osx.core.token.TokenResult; public interface TokenService { diff --git a/java/osx/core/src/main/java/com/osx/core/flow/UnaryLeapArray.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/UnaryLeapArray.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/flow/UnaryLeapArray.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/UnaryLeapArray.java index 7832e74c89..a10f3a4d4e 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/UnaryLeapArray.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/UnaryLeapArray.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.flow; +package org.fedai.osx.core.flow; import java.util.concurrent.atomic.LongAdder; diff --git a/java/osx/core/src/main/java/com/osx/core/flow/WindowWrap.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/WindowWrap.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/flow/WindowWrap.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/WindowWrap.java index e713d8e104..a55b4899e5 100644 --- a/java/osx/core/src/main/java/com/osx/core/flow/WindowWrap.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/flow/WindowWrap.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.core.flow; +package org.fedai.osx.core.flow; public class WindowWrap { diff --git a/java/osx/core/src/main/java/com/osx/core/frame/CountDownLatch.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/CountDownLatch.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/frame/CountDownLatch.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/CountDownLatch.java index c35d2ecf96..6871e65939 100644 --- a/java/osx/core/src/main/java/com/osx/core/frame/CountDownLatch.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/CountDownLatch.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.frame; +package org.fedai.osx.core.frame; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.AbstractQueuedSynchronizer; diff --git a/java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/GrpcConnectionFactory.java similarity index 67% rename from java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/GrpcConnectionFactory.java index c1b12a9a39..9bea044ef1 100644 --- a/java/osx/core/src/main/java/com/osx/core/frame/GrpcConnectionFactory.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/GrpcConnectionFactory.java @@ -14,18 +14,19 @@ * limitations under the License. */ -package com.osx.core.frame; +package org.fedai.osx.core.frame; -import com.osx.core.config.GrpcChannelInfo; -import com.osx.core.exceptions.NoRouterInfoException; -import com.osx.core.exceptions.SysException; -import com.osx.core.router.RouterInfo; import io.grpc.ManagedChannel; import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.NegotiationType; import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; import org.apache.commons.lang3.StringUtils; +import org.fedai.osx.api.router.RouterInfo; +import org.fedai.osx.core.config.GrpcChannelInfo; +import org.fedai.osx.core.config.MetaInfo; +import org.fedai.osx.core.exceptions.NoRouterInfoException; +import org.fedai.osx.core.exceptions.SysException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,8 +34,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; -import static com.osx.core.config.MetaInfo.*; - public class GrpcConnectionFactory { private static final Logger logger = LoggerFactory.getLogger(GrpcConnectionFactory.class); @@ -51,10 +50,14 @@ public static synchronized ManagedChannel createManagedChannel(RouterInfo router } if(usePooled) { if (managedChannelPool.get(routerInfo.toKey()) != null) { + ManagedChannel targetChannel = managedChannelPool.get(routerInfo.toKey()); + // logger.info("channel is shutdown : {} isTerminated {}",targetChannel.isShutdown() ,targetChannel.isTerminated() ,targetChannel.getState(true)); return managedChannelPool.get(routerInfo.toKey()); } else { ManagedChannel managedChannel = createManagedChannel(routerInfo, buildDefaultGrpcChannelInfo()); - managedChannelPool.put(routerInfo.toKey(), managedChannel); + if(managedChannel!=null) { + managedChannelPool.put(routerInfo.toKey(), managedChannel); + } return managedChannel; } }else{ @@ -66,14 +69,14 @@ public static synchronized ManagedChannel 
createManagedChannel(RouterInfo router private static GrpcChannelInfo buildDefaultGrpcChannelInfo(){ GrpcChannelInfo grpcChannelInfo = new GrpcChannelInfo(); - grpcChannelInfo.setKeepAliveTime(PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC); - grpcChannelInfo.setKeepAliveTimeout(PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC); - grpcChannelInfo.setKeepAliveWithoutCalls(PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED); - grpcChannelInfo.setFlowControlWindow(PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW); - grpcChannelInfo.setMaxInboundMessageSize(PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE); - grpcChannelInfo.setRetryBufferSize(PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE); - grpcChannelInfo.setIdelTimeOut(PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC); - grpcChannelInfo.setPerRpcBufferLimit(PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT); + grpcChannelInfo.setKeepAliveTime(MetaInfo.PROPERTY_GRPC_CLIENT_KEEPALIVE_TIME_SEC); + grpcChannelInfo.setKeepAliveTimeout(MetaInfo.PROPERTY_GRPC_CLIENT_KEEPALIVE_TIMEOUT_SEC); + grpcChannelInfo.setKeepAliveWithoutCalls(MetaInfo.PROPERTY_GRPC_CLIENT_KEEPALIVE_WITHOUT_CALLS_ENABLED); + grpcChannelInfo.setFlowControlWindow(MetaInfo.PROPERTY_GRPC_CLIENT_FLOW_CONTROL_WINDOW); + grpcChannelInfo.setMaxInboundMessageSize(MetaInfo.PROPERTY_GRPC_CLIENT_MAX_INBOUND_MESSAGE_SIZE); + grpcChannelInfo.setRetryBufferSize(MetaInfo.PROPERTY_GRPC_CLIENT_RETRY_BUFFER_SIZE); + grpcChannelInfo.setIdelTimeOut(MetaInfo.PROPERTY_GRPC_CLIENT_MAX_CONNECTION_IDLE_SEC); + grpcChannelInfo.setPerRpcBufferLimit(MetaInfo.PROPERTY_GRPC_CLIENT_PER_RPC_BUFFER_LIMIT); return grpcChannelInfo; } @@ -96,25 +99,20 @@ public static synchronized ManagedChannel createManagedChannel(RouterInfo router .enableRetry() .retryBufferSize(channelInfo.getRetryBufferSize()) .maxRetryAttempts(channelInfo.getMaxRetryAttemps()); - - if (routerInfo != null && NegotiationType.TLS.name().equals(routerInfo.getNegotiationType()) - && StringUtils.isNotBlank(routerInfo.getCertChainFile()) - && StringUtils.isNotBlank(routerInfo.getPrivateKeyFile()) - && StringUtils.isNotBlank(routerInfo.getTrustCertCollectionFile())) { + if (routerInfo.isUseSSL() && NegotiationType.TLS.name().equals(routerInfo.getNegotiationType()) && StringUtils.isNotBlank(routerInfo.getCertChainFile()) && StringUtils.isNotBlank(routerInfo.getPrivateKeyFile()) && StringUtils.isNotBlank(routerInfo.getCaFile())) { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient() .keyManager(new File(routerInfo.getCertChainFile()), new File(routerInfo.getPrivateKeyFile())) - .trustManager(new File(routerInfo.getTrustCertCollectionFile())) + .trustManager(new File(routerInfo.getCaFile())) .sessionTimeout(3600 << 4) .sessionCacheSize(65536); - channelBuilder.sslContext(sslContextBuilder.build()).useTransportSecurity(); - + channelBuilder.sslContext(sslContextBuilder.build()).useTransportSecurity().overrideAuthority(routerInfo.getHost()); } else { channelBuilder.usePlaintext(); } return channelBuilder.build(); } catch (Exception e) { - logger.error("create channel error : ", e); + logger.error("create channel to {} error : ",routerInfo, e); //e.printStackTrace(); } return null; diff --git a/java/osx/core/src/main/java/com/osx/core/frame/Lifecycle.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/Lifecycle.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/frame/Lifecycle.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/Lifecycle.java index addef63075..facff5a9c4 100644 --- 
a/java/osx/core/src/main/java/com/osx/core/frame/Lifecycle.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/Lifecycle.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.frame; +package org.fedai.osx.core.frame; public interface Lifecycle { public void init(); diff --git a/java/osx/core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceDataWrapper.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceDataWrapper.java index 817f191d5a..96ca0bec36 100644 --- a/java/osx/core/src/main/java/com/osx/core/frame/ServiceDataWrapper.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceDataWrapper.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package com.osx.core.frame; +package org.fedai.osx.core.frame; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; public class ServiceDataWrapper { diff --git a/java/osx/core/src/main/java/com/osx/core/frame/ServiceThread.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceThread.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/frame/ServiceThread.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceThread.java index aeb03319d0..e4c716c4fe 100644 --- a/java/osx/core/src/main/java/com/osx/core/frame/ServiceThread.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/frame/ServiceThread.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.frame; +package org.fedai.osx.core.frame; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMGCUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMGCUtils.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMGCUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMGCUtils.java index 44ed817fd6..1bb5dc1238 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JVMGCUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMGCUtils.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; import java.lang.management.GarbageCollectorMXBean; import java.lang.management.ManagementFactory; diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMMemoryUtils.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMMemoryUtils.java index b8b93d58da..23d400556f 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JVMMemoryUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMMemoryUtils.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMThreadUtils.java similarity index 77% rename from java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMThreadUtils.java index ad8762dc6a..8989a265e0 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JVMThreadUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JVMThreadUtils.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; @@ -95,19 +95,6 @@ static public int getDeadLockedThreadCount() { } } - public static void main(String[] args) { - for (; ; ) { -// System.out.println("======================================================================="); -// System.out.println("getDaemonThreadCount: " + JVMThreadUtils.getDaemonThreadCount()); -// System.out.println("getNonHeapMemoryUsage: " + JVMThreadUtils.getThreadCount()); -// System.out.println("getPeakThreadCountAndReset: " + JVMThreadUtils.getAndResetPeakThreadCount()); -// System.out.println("getDeadLockedThreadCount: " + JVMThreadUtils.getDeadLockedThreadCount()); - try { - Thread.sleep(5000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } + } diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfo.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfo.java index 844e22b0d5..4ebc22083b 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfo.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; public class JvmInfo { diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoCounter.java similarity index 86% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoCounter.java index ebe7cef686..93db8c1960 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoCounter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoCounter.java @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; import com.google.common.collect.Lists; -import com.osx.core.flow.LeapArray; -import com.osx.core.flow.TimeUtil; +import org.fedai.osx.core.flow.LeapArray; +import org.fedai.osx.core.flow.TimeUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -67,16 +67,6 @@ public void run() { } } - public static void main(String[] args) { - JvmInfoCounter.start(); - while (true) { - System.err.println(JvmInfoCounter.getMemInfos()); - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } + } diff --git a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoLeapArray.java similarity index 90% rename from java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoLeapArray.java index f21c8da024..e2c08f365b 100644 --- a/java/osx/core/src/main/java/com/osx/core/jvm/JvmInfoLeapArray.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/jvm/JvmInfoLeapArray.java @@ -14,10 +14,10 @@ * limitations under the License. */ -package com.osx.core.jvm; +package org.fedai.osx.core.jvm; -import com.osx.core.flow.LeapArray; -import com.osx.core.flow.WindowWrap; +import org.fedai.osx.core.flow.LeapArray; +import org.fedai.osx.core.flow.WindowWrap; public class JvmInfoLeapArray extends LeapArray { diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/provider/TechProvider.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/provider/TechProvider.java new file mode 100644 index 0000000000..fc45e9dcc9 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/provider/TechProvider.java @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.fedai.osx.core.provider; + +import io.grpc.stub.StreamObserver; +import org.ppc.ptp.Osx; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +public interface TechProvider { + //used to handle HTTP/1.X requests + void processHttpInvoke(HttpServletRequest httpServletRequest,HttpServletResponse httpServletResponse); + //used to handle non-streaming gRPC requests + void processGrpcInvoke(Osx.Inbound request, + io.grpc.stub.StreamObserver responseObserver); + +// rpc peek (PeekInbound) returns (TransportOutbound); +// rpc pop (PopInbound) returns (TransportOutbound); +// rpc push (PushInbound) returns (TransportOutbound); +// rpc release (ReleaseInbound) returns (TransportOutbound); + + //used to handle streaming gRPC requests + public StreamObserver processGrpcTransport(Osx.Inbound inbound, StreamObserver responseObserver); + +// + void processGrpcPeek(Osx.PeekInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcPush(Osx.PushInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcPop(Osx.PopInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + void processGrpcRelease(Osx.ReleaseInbound inbound, io.grpc.stub.StreamObserver responseObserver); + + + + + +} diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/SourceMethod.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/SourceMethod.java new file mode 100644 index 0000000000..1c7c4c79f5 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/SourceMethod.java @@ -0,0 +1,6 @@ +package org.fedai.osx.core.ptp; + +public enum SourceMethod { + UNARY_CALL, OLDUNARY_CALL,PUSH + +} diff --git a/java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/TargetMethod.java similarity index 92% rename from java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/TargetMethod.java index 56680b18af..da8889bad9 100644 --- a/java/osx/core/src/main/java/com/osx/core/ptp/TargetMethod.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/ptp/TargetMethod.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.ptp; +package org.fedai.osx.core.ptp; public enum TargetMethod { @@ -25,7 +25,8 @@ public enum TargetMethod { CANCEL_TOPIC, PUSH, APPLY_TOKEN, - APPLY_TOPIC + APPLY_TOPIC, + TEST_STREAM diff --git a/java/osx/core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/ClusterTransferQueueInfo.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/ClusterTransferQueueInfo.java index 9d36b4a6c8..9257caeff5 100644 --- a/java/osx/core/src/main/java/com/osx/core/queue/ClusterTransferQueueInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/ClusterTransferQueueInfo.java @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -package com.osx.core.queue; +package org.fedai.osx.core.queue; -import com.osx.core.utils.JsonUtil; +import org.fedai.osx.core.utils.JsonUtil; public class ClusterTransferQueueInfo { diff --git a/java/osx/core/src/main/java/com/osx/core/queue/TranferQueueInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/TranferQueueInfo.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/queue/TranferQueueInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/TranferQueueInfo.java index 6dd5420676..c067f50fab 100644 --- a/java/osx/core/src/main/java/com/osx/core/queue/TranferQueueInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/queue/TranferQueueInfo.java @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.queue; +package org.fedai.osx.core.queue; -import com.osx.core.constant.TransferStatus; +import org.fedai.osx.core.constant.TransferStatus; public class TranferQueueInfo { diff --git a/java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/AbstractServiceAdaptor.java similarity index 66% rename from java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/AbstractServiceAdaptor.java index 4a81e442e7..1634855cdf 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/AbstractServiceAdaptor.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/AbstractServiceAdaptor.java @@ -14,17 +14,19 @@ * limitations under the License. */ -package com.osx.core.service; +package org.fedai.osx.core.service; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.osx.core.constant.Dict; -import com.osx.core.constant.StatusCode; -import com.osx.core.context.Context; -import com.osx.core.exceptions.ErrorMessageUtil; -import com.osx.core.exceptions.ExceptionInfo; -import com.osx.core.utils.JsonUtil; import io.grpc.stub.AbstractStub; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.fedai.osx.api.context.Context; +import org.fedai.osx.core.constant.Dict; +import org.fedai.osx.core.constant.StatusCode; +import org.fedai.osx.core.context.FateContext; +import org.fedai.osx.core.exceptions.ErrorMessageUtil; +import org.fedai.osx.core.exceptions.ExceptionInfo; +import org.fedai.osx.core.utils.FlowLogUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,17 +40,17 @@ * @Author **/ -public abstract class AbstractServiceAdaptor implements ServiceAdaptor { +public abstract class AbstractServiceAdaptor implements ServiceAdaptor { static public AtomicInteger requestInHandle = new AtomicInteger(0); public static boolean isOpen = true; - protected Logger flowLogger = LoggerFactory.getLogger("flow"); +// protected Logger flowLogger = LoggerFactory.getLogger("flow"); protected String serviceName; Logger logger = LoggerFactory.getLogger(this.getClass().getName()); - ServiceAdaptor serviceAdaptor; - InterceptorChain preChain = new DefaultInterceptorChain(); - InterceptorChain postChain = new DefaultInterceptorChain(); + ServiceAdaptor serviceAdaptor; + InterceptorChain preChain = new DefaultInterceptorChain<>(); + InterceptorChain postChain = new DefaultInterceptorChain<>(); private Map methodMap = Maps.newHashMap(); private AbstractStub serviceStub; @@ -69,7 +71,7 @@ public void setMethodMap(Map 
methodMap) { this.methodMap = methodMap; } - public AbstractServiceAdaptor addPreProcessor(Interceptor interceptor) { + public AbstractServiceAdaptor addPreProcessor(Interceptor interceptor) { preChain.addInterceptor(interceptor); return this; } @@ -78,7 +80,7 @@ public void addPostProcessor(Interceptor interceptor) { postChain.addInterceptor(interceptor); } - public ServiceAdaptor getServiceAdaptor() { + public ServiceAdaptor getServiceAdaptor() { return serviceAdaptor; } @@ -102,7 +104,7 @@ public void setServiceName(String serviceName) { this.serviceName = serviceName; } - protected abstract resp doService(Context context, InboundPackage data); + protected abstract resp doService(ctx context, InboundPackage data); /** * @param context @@ -111,7 +113,7 @@ public void setServiceName(String serviceName) { * @throws Exception */ @Override - public OutboundPackage service(Context context, InboundPackage data) throws RuntimeException { + public OutboundPackage service(ctx context, InboundPackage data) throws RuntimeException { OutboundPackage outboundPackage = new OutboundPackage(); // context.preProcess(); @@ -129,19 +131,14 @@ public OutboundPackage service(Context context, InboundPackage data) resp result = null; context.setServiceName(this.serviceName); try { - preChain.doPreProcess(context, data); + preChain.doProcess(context, data, outboundPackage); result = doService(context, data); - if (logger.isDebugEnabled()) { - logger.debug("do service, router info: {}, service name: {}, result: {}", JsonUtil.object2Json(context.getRouterInfo()), serviceName, result); - } } catch (Throwable e) { exceptions.add(e); e.printStackTrace(); - logger.error("do service fail, cause by: {}", e.getMessage()); + logger.error("do service fail, {} ", ExceptionUtils.getStackTrace(e)); } outboundPackage.setData(result); - //postChain.doPostProcess(context, data, outboundPackage); - } catch (Throwable e) { exceptions.add(e); logger.error("service error", e); @@ -152,7 +149,16 @@ public OutboundPackage service(Context context, InboundPackage data) outboundPackage = this.serviceFail(context, data, exceptions); } } finally { - printFlowLog(context); + if(context instanceof FateContext ) + { + FateContext fateContext =(FateContext )context; + if(fateContext.needPrintFlowLog()){ + FlowLogUtil.printFlowLog(context); + } + }else { + + FlowLogUtil.printFlowLog(context); + } } // int returnCode = context.getReturnCode(); @@ -162,24 +168,20 @@ public OutboundPackage service(Context context, InboundPackage data) // context.postProcess(data, outboundPackage); } + try { + postChain.doProcess(context, data, outboundPackage); + } catch (Exception e) { + logger.error("service PostDoProcess error", e); + } return outboundPackage; } - protected void printFlowLog(Context context) { - - context.printFlowLog(); - -// flowLogger.info("{}|{}|{}|{}|" + -// "{}|{}|{}|{}|" + -// "{}|{}", -// context.getSourceIp(), context.getSrcPartyId(), -// context.getDesPartyId(), context.getReturnCode(), context.getCostTime(), -// context.getDownstreamCost(), serviceName, context.getRouterInfo() != null ? 
context.getRouterInfo() : "", -// MetaInfo.PROPERTY_PRINT_INPUT_DATA?context.getData(Dict.INPUT_DATA):"", -// MetaInfo.PROPERTY_PRINT_OUTPUT_DATA?context.getData(Dict.OUTPUT_DATA):""); - } +// protected void printFlowLog(ctx context) { +//// context.printFlowLog(); +// FlowLogUtil.printFlowLog(context); +// } - protected OutboundPackage serviceFailInner(Context context, InboundPackage data, Throwable e) { + protected OutboundPackage serviceFailInner(ctx context, InboundPackage data, Throwable e) { OutboundPackage outboundPackage = new OutboundPackage(); ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); context.setReturnCode(exceptionInfo.getCode()); @@ -191,13 +193,13 @@ protected OutboundPackage serviceFailInner(Context context, InboundPackage } @Override - public OutboundPackage serviceFail(Context context, InboundPackage data, List errors) throws RuntimeException { + public OutboundPackage serviceFail(ctx context, InboundPackage data, List errors) throws RuntimeException { Throwable e = errors.get(0); logger.error("service fail ", e); return serviceFailInner(context, data, e); } - protected abstract resp transformExceptionInfo(Context context, ExceptionInfo exceptionInfo); + protected abstract resp transformExceptionInfo(ctx context, ExceptionInfo exceptionInfo); } \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/DefaultInterceptorChain.java similarity index 63% rename from java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/DefaultInterceptorChain.java index adfd18a0fd..71a4b527f8 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/DefaultInterceptorChain.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/DefaultInterceptorChain.java @@ -14,10 +14,10 @@ * limitations under the License. 
*/ -package com.osx.core.service; +package org.fedai.osx.core.service; import com.google.common.collect.Lists; -import com.osx.core.context.Context; +import org.fedai.osx.api.context.Context; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,14 +27,14 @@ * @Description TODO * @Author **/ -public class DefaultInterceptorChain implements InterceptorChain { +public class DefaultInterceptorChain implements InterceptorChain { Logger logger = LoggerFactory.getLogger(DefaultInterceptorChain.class); - List> chain = Lists.newArrayList(); + List> chain = Lists.newArrayList(); @Override - public void addInterceptor(Interceptor interceptor) { + public void addInterceptor(Interceptor interceptor) { chain.add(interceptor); } @@ -46,12 +46,11 @@ public void addInterceptor(Interceptor interceptor) { * @throws Exception */ @Override - public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { - for (Interceptor interceptor : chain) { - logger.info("====== {}",interceptor); - interceptor.doPreProcess(context, inboundPackage); - + public void doProcess(ctx context, InboundPackage inboundPackage,OutboundPackage outboundPackage) throws Exception { + for (Interceptor interceptor : chain) { + if (interceptor != null) { + interceptor.doProcess(context, inboundPackage,outboundPackage); + } } } - } diff --git a/java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InboundPackage.java similarity index 85% rename from java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InboundPackage.java index 652b47bd13..7712dae934 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/InboundPackage.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InboundPackage.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package com.osx.core.service; +package org.fedai.osx.core.service; + -import com.osx.core.router.RouterInfo; import io.grpc.ManagedChannel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,7 +28,7 @@ public class InboundPackage { static Logger logger = LoggerFactory.getLogger(InboundPackage.class); ManagedChannel managedChannel; - RouterInfo routerInfo; + String source; Map head; @@ -42,14 +42,6 @@ public void setManagedChannel(ManagedChannel managedChannel) { this.managedChannel = managedChannel; } - public RouterInfo getRouterInfo() { - return routerInfo; - } - - public void setRouterInfo(RouterInfo routerInfo) { - this.routerInfo = routerInfo; - } - public String getSource() { return source; } diff --git a/java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/Interceptor.java similarity index 68% rename from java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/Interceptor.java index 627058b569..d3754a812c 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/InterceptorChain.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/Interceptor.java @@ -14,11 +14,15 @@ * limitations under the License. 
*/ -package com.osx.core.service; +package org.fedai.osx.core.service; -public interface InterceptorChain extends Interceptor { +import org.fedai.osx.api.context.Context; - public void addInterceptor(Interceptor interceptor); +public interface Interceptor { + + default public void doProcess(ctx context, InboundPackage inboundPackage, OutboundPackage outboundPackage) throws Exception { + + } } diff --git a/java/osx/core/src/main/java/com/osx/core/service/Interceptor.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InterceptorChain.java similarity index 71% rename from java/osx/core/src/main/java/com/osx/core/service/Interceptor.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InterceptorChain.java index 611c811ea4..b87c8ef66d 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/Interceptor.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/InterceptorChain.java @@ -14,14 +14,13 @@ * limitations under the License. */ -package com.osx.core.service; +package org.fedai.osx.core.service; -import com.osx.core.context.Context; +import org.fedai.osx.api.context.Context; -public interface Interceptor { +public interface InterceptorChain extends Interceptor { - default public void doPreProcess(Context context, InboundPackage inboundPackage) throws Exception { + public void addInterceptor(Interceptor interceptor); - } } diff --git a/java/osx/core/src/main/java/com/osx/core/service/OutboundPackage.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/OutboundPackage.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/service/OutboundPackage.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/OutboundPackage.java index 2487ef2b77..517beb5823 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/OutboundPackage.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/OutboundPackage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.service; +package org.fedai.osx.core.service; /** * @Description TODO diff --git a/java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/ServiceAdaptor.java similarity index 66% rename from java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/service/ServiceAdaptor.java index d526f4c468..dab45d8ecc 100644 --- a/java/osx/core/src/main/java/com/osx/core/service/ServiceAdaptor.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/service/ServiceAdaptor.java @@ -14,17 +14,15 @@ * limitations under the License. 
*/ -package com.osx.core.service; - - -import com.osx.core.context.Context; +package org.fedai.osx.core.service; +import org.fedai.osx.api.context.Context; import java.util.List; -public interface ServiceAdaptor { +public interface ServiceAdaptor { - public OutboundPackage service(Context context, InboundPackage inboundPackage); + public OutboundPackage service(ctx context, InboundPackage inboundPackage); - public OutboundPackage serviceFail(Context context, InboundPackage data, List e); + public OutboundPackage serviceFail(ctx context, InboundPackage data, List e); } diff --git a/java/osx/core/src/main/java/com/osx/core/timer/HashedWheelTimer.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/HashedWheelTimer.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/timer/HashedWheelTimer.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/HashedWheelTimer.java index 023b274b63..0b7c16fe4f 100644 --- a/java/osx/core/src/main/java/com/osx/core/timer/HashedWheelTimer.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/HashedWheelTimer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.timer; +package org.fedai.osx.core.timer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -26,7 +26,7 @@ import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLong; -public class HashedWheelTimer implements Timer { +public class HashedWheelTimer implements org.fedai.osx.core.timer.Timer { public static final String NAME = "hased"; static final String OS_NAME = "os.name"; @@ -206,7 +206,7 @@ public Set stop() { throw new IllegalStateException( HashedWheelTimer.class.getSimpleName() + ".stop() cannot be called from " + - TimerTask.class.getSimpleName()); + org.fedai.osx.core.timer.TimerTask.class.getSimpleName()); } if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) { @@ -244,7 +244,7 @@ public boolean isStop() { } @Override - public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit) { + public Timeout newTimeout(org.fedai.osx.core.timer.TimerTask task, long delay, TimeUnit unit) { if (task == null) { throw new NullPointerException("task"); } @@ -293,7 +293,7 @@ private static final class HashedWheelTimeout implements Timeout { AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state"); private final HashedWheelTimer timer; - private final TimerTask task; + private final org.fedai.osx.core.timer.TimerTask task; private final long deadline; /** * RemainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the @@ -313,7 +313,7 @@ private static final class HashedWheelTimeout implements Timeout { @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization"}) private volatile int state = ST_INIT; - HashedWheelTimeout(HashedWheelTimer timer, TimerTask task, long deadline) { + HashedWheelTimeout(HashedWheelTimer timer, org.fedai.osx.core.timer.TimerTask task, long deadline) { this.timer = timer; this.task = task; this.deadline = deadline; @@ -325,7 +325,7 @@ public Timer timer() { } @Override - public TimerTask task() { + public org.fedai.osx.core.timer.TimerTask task() { return task; } diff --git a/java/osx/core/src/main/java/com/osx/core/timer/Timeout.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timeout.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/timer/Timeout.java rename to 
java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timeout.java index a2dc952f5a..c32c134c6d 100644 --- a/java/osx/core/src/main/java/com/osx/core/timer/Timeout.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timeout.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.timer; +package org.fedai.osx.core.timer; public interface Timeout { diff --git a/java/osx/core/src/main/java/com/osx/core/timer/Timer.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timer.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/timer/Timer.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timer.java index d4ccfbd1a1..266d70124f 100644 --- a/java/osx/core/src/main/java/com/osx/core/timer/Timer.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/Timer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.timer; +package org.fedai.osx.core.timer; import java.util.Set; import java.util.concurrent.TimeUnit; diff --git a/java/osx/core/src/main/java/com/osx/core/timer/TimerTask.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/TimerTask.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/timer/TimerTask.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/TimerTask.java index ff964cc714..22f45e6be2 100644 --- a/java/osx/core/src/main/java/com/osx/core/timer/TimerTask.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/timer/TimerTask.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.timer; +package org.fedai.osx.core.timer; public interface TimerTask { diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenRequest.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenRequest.java similarity index 95% rename from java/osx/core/src/main/java/com/osx/core/token/TokenRequest.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenRequest.java index 46b5a467b2..1fd9866192 100644 --- a/java/osx/core/src/main/java/com/osx/core/token/TokenRequest.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenRequest.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.token; +package org.fedai.osx.core.token; import lombok.Data; diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenResult.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResult.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/token/TokenResult.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResult.java index 40eddc01e7..cb77c5b7a2 100644 --- a/java/osx/core/src/main/java/com/osx/core/token/TokenResult.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResult.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.osx.core.token; +package org.fedai.osx.core.token; import java.util.Map; diff --git a/java/osx/core/src/main/java/com/osx/core/token/TokenResultStatus.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResultStatus.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/token/TokenResultStatus.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResultStatus.java index 5c8dd49037..91dcdad8ad 100644 --- a/java/osx/core/src/main/java/com/osx/core/token/TokenResultStatus.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/token/TokenResultStatus.java @@ -1,4 +1,4 @@ -package com.osx.core.token; +package org.fedai.osx.core.token; public final class TokenResultStatus { diff --git a/java/osx/core/src/main/java/com/osx/core/utils/AssertUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/AssertUtil.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/utils/AssertUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/AssertUtil.java index 37d90f65c3..a61aa15c92 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/AssertUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/AssertUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import org.apache.commons.lang3.StringUtils; diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ClassUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ClassUtils.java new file mode 100644 index 0000000000..9db653e8f5 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ClassUtils.java @@ -0,0 +1,392 @@ + +package org.fedai.osx.core.utils; + + + +import org.apache.commons.lang3.StringUtils; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Array; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + + +public class ClassUtils { + + public static final String CLASS_EXTENSION = ".class"; + + public static final String JAVA_EXTENSION = ".java"; + private static final int JIT_LIMIT = 5 * 1024; + + private ClassUtils() { + } + + public static Object newInstance(String name) { + try { + return forName(name).getDeclaredConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class forName(String[] packages, String className) { + try { + return classForName(className); + } catch (ClassNotFoundException e) { + if (packages != null && packages.length > 0) { + for (String pkg : packages) { + try { + return classForName(pkg + "." 
+ className); + } catch (ClassNotFoundException ignore) { + } + } + } + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class forName(String className) { + try { + return classForName(className); + } catch (ClassNotFoundException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + public static Class classForName(String className) throws ClassNotFoundException { + switch (className) { + case "boolean": + return boolean.class; + case "byte": + return byte.class; + case "char": + return char.class; + case "short": + return short.class; + case "int": + return int.class; + case "long": + return long.class; + case "float": + return float.class; + case "double": + return double.class; + case "boolean[]": + return boolean[].class; + case "byte[]": + return byte[].class; + case "char[]": + return char[].class; + case "short[]": + return short[].class; + case "int[]": + return int[].class; + case "long[]": + return long[].class; + case "float[]": + return float[].class; + case "double[]": + return double[].class; + default: + } + try { + return arrayForName(className); + } catch (ClassNotFoundException e) { + // try to load from java.lang package + if (className.indexOf('.') == -1) { + try { + return arrayForName("java.lang." + className); + } catch (ClassNotFoundException ignore) { + // ignore, let the original exception be thrown + } + } + throw e; + } + } + + private static Class arrayForName(String className) throws ClassNotFoundException { + return Class.forName(className.endsWith("[]") + ? "[L" + className.substring(0, className.length() - 2) + ";" + : className, true, Thread.currentThread().getContextClassLoader()); + } + + public static Class getBoxedClass(Class type) { + if (type == boolean.class) { + return Boolean.class; + } else if (type == char.class) { + return Character.class; + } else if (type == byte.class) { + return Byte.class; + } else if (type == short.class) { + return Short.class; + } else if (type == int.class) { + return Integer.class; + } else if (type == long.class) { + return Long.class; + } else if (type == float.class) { + return Float.class; + } else if (type == double.class) { + return Double.class; + } else { + return type; + } + } + + public static Boolean boxed(boolean v) { + return Boolean.valueOf(v); + } + + public static Character boxed(char v) { + return Character.valueOf(v); + } + + public static Byte boxed(byte v) { + return Byte.valueOf(v); + } + + public static Short boxed(short v) { + return Short.valueOf(v); + } + + public static Integer boxed(int v) { + return Integer.valueOf(v); + } + + public static Long boxed(long v) { + return Long.valueOf(v); + } + + public static Float boxed(float v) { + return Float.valueOf(v); + } + + public static Double boxed(double v) { + return Double.valueOf(v); + } + + public static Object boxed(Object v) { + return v; + } + + public static boolean unboxed(Boolean v) { + return v == null ? false : v.booleanValue(); + } + + public static char unboxed(Character v) { + return v == null ? '\0' : v.charValue(); + } + + public static byte unboxed(Byte v) { + return v == null ? 0 : v.byteValue(); + } + + public static short unboxed(Short v) { + return v == null ? 0 : v.shortValue(); + } + + public static int unboxed(Integer v) { + return v == null ? 0 : v.intValue(); + } + + public static long unboxed(Long v) { + return v == null ? 0 : v.longValue(); + } + + public static float unboxed(Float v) { + return v == null ? 
0 : v.floatValue(); + } + + public static double unboxed(Double v) { + return v == null ? 0 : v.doubleValue(); + } + + public static Object unboxed(Object v) { + return v; + } + + public static boolean isNotEmpty(Object object) { + return getSize(object) > 0; + } + + public static int getSize(Object object) { + if (object == null) { + return 0; + } + if (object instanceof Collection) { + return ((Collection) object).size(); + } else if (object instanceof Map) { + return ((Map) object).size(); + } else if (object.getClass().isArray()) { + return Array.getLength(object); + } else { + return -1; + } + } + + public static URI toURI(String name) { + try { + return new URI(name); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public static boolean isBeforeJava5(String javaVersion) { + return (StringUtils.isEmpty(javaVersion) || "1.0".equals(javaVersion) + || "1.1".equals(javaVersion) || "1.2".equals(javaVersion) + || "1.3".equals(javaVersion) || "1.4".equals(javaVersion)); + } + + public static boolean isBeforeJava6(String javaVersion) { + return isBeforeJava5(javaVersion) || "1.5".equals(javaVersion); + } + + public static String toString(Throwable e) { + StringWriter w = new StringWriter(); + PrintWriter p = new PrintWriter(w); + p.print(e.getClass().getName() + ": "); + if (e.getMessage() != null) { + p.print(e.getMessage() + "\n"); + } + p.println(); + try { + e.printStackTrace(p); + return w.toString(); + } finally { + p.close(); + } + } + + public static void checkBytecode(String name, byte[] bytecode) { + if (bytecode.length > JIT_LIMIT) { + System.err.println("The template bytecode too long, may be affect the JIT compiler. template class: " + name); + } + } + + public static String getSizeMethod(Class cls) { + try { + return cls.getMethod("size", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e) { + try { + return cls.getMethod("length", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e2) { + try { + return cls.getMethod("getSize", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e3) { + try { + return cls.getMethod("getLength", new Class[0]).getName() + "()"; + } catch (NoSuchMethodException e4) { + return null; + } + } + } + } + } + + public static String getMethodName(Method method, Class[] parameterClasses, String rightCode) { + StringBuilder buf = new StringBuilder(rightCode); + if (method.getParameterTypes().length > parameterClasses.length) { + Class[] types = method.getParameterTypes(); + for (int i = parameterClasses.length; i < types.length; i++) { + if (buf.length() > 0) { + buf.append(','); + } + Class type = types[i]; + String def; + if (type == boolean.class) { + def = "false"; + } else if (type == char.class) { + def = "\'\\0\'"; + } else if (type == byte.class + || type == short.class + || type == int.class + || type == long.class + || type == float.class + || type == double.class) { + def = "0"; + } else { + def = "null"; + } + buf.append(def); + } + } + return method.getName() + "(" + buf + ")"; + } + + public static Method searchMethod(Class currentClass, String name, Class[] parameterTypes) throws NoSuchMethodException { + if (currentClass == null) { + throw new NoSuchMethodException("class == null"); + } + try { + return currentClass.getMethod(name, parameterTypes); + } catch (NoSuchMethodException e) { + for (Method method : currentClass.getMethods()) { + if (method.getName().equals(name) + && parameterTypes.length == method.getParameterTypes().length + && 
Modifier.isPublic(method.getModifiers())) { + if (parameterTypes.length > 0) { + Class[] types = method.getParameterTypes(); + boolean match = true; + for (int i = 0; i < parameterTypes.length; i++) { + if (!types[i].isAssignableFrom(parameterTypes[i])) { + match = false; + break; + } + } + if (!match) { + continue; + } + } + return method; + } + } + throw e; + } + } + + public static String getInitCode(Class type) { + if (byte.class.equals(type) + || short.class.equals(type) + || int.class.equals(type) + || long.class.equals(type) + || float.class.equals(type) + || double.class.equals(type)) { + return "0"; + } else if (char.class.equals(type)) { + return "'\\0'"; + } else if (boolean.class.equals(type)) { + return "false"; + } else { + return "null"; + } + } + + public static Map toMap(Map.Entry[] entries) { + Map map = new HashMap(); + if (entries != null && entries.length > 0) { + for (Map.Entry entry : entries) { + map.put(entry.getKey(), entry.getValue()); + } + } + return map; + } + + /** + * get simple class name from qualified class name + */ + public static String getSimpleClassName(String qualifiedName) { + if (null == qualifiedName) { + return null; + } + int i = qualifiedName.lastIndexOf('.'); + return i < 0 ? qualifiedName : qualifiedName.substring(i + 1); + } + +} diff --git a/java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/EncryptUtils.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/EncryptUtils.java index 7c676ad2fc..c586af2333 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/EncryptUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/EncryptUtils.java @@ -14,9 +14,9 @@ * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; -import com.osx.core.constant.EncryptMethod; +import org.fedai.osx.core.constant.EncryptMethod; import javax.crypto.Mac; import javax.crypto.SecretKey; @@ -38,8 +38,8 @@ public static String encrypt(String originString, EncryptMethod encryptMethod) { result += Integer.toHexString((0x000000FF & s[i]) | 0xFFFFFF00).substring(6); } return result; - } catch (Exception e) { - e.printStackTrace(); + } catch (Exception igore) { + } return ""; diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FileUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FileUtils.java new file mode 100644 index 0000000000..a847171ee0 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FileUtils.java @@ -0,0 +1,137 @@ +/* + * Copyright 2019 The FATE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.fedai.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileLock; + +public class FileUtils { + private static final Logger logger = LoggerFactory.getLogger(FileUtils.class); + + public static boolean writeFile(String context, File target) { + BufferedWriter out = null; + try { + if (!target.exists()) { + target.createNewFile(); + } + out = new BufferedWriter(new FileWriter(target)); + out.write(context); + } catch (IOException e) { + logger.error(e.getMessage()); + return false; + } finally { + try { + if (out != null) { + out.flush(); + out.close(); + } + } catch (IOException ex) { + logger.error("write file error", ex); + } + } + return true; + } + + /** + * Write string to file, + * synchronize operation, exclusive lock + */ + public static boolean writeStr2ReplaceFileSync(String str, String pathFile) throws Exception { + File file = new File(pathFile); + try { + if (!file.exists()) { + file.createNewFile(); + } + } catch (IOException e) { + logger.error("Failed to create the file. Check whether the path is valid and the read/write permission is correct"); + throw new IOException("Failed to create the file. Check whether the path is valid and the read/write permission is correct"); + } + FileOutputStream fileOutputStream = null; + FileChannel fileChannel = null; + FileLock fileLock; + try { + + /* + * write file + */ + fileOutputStream = new FileOutputStream(file); + fileChannel = fileOutputStream.getChannel(); + + try { + fileLock = fileChannel.tryLock();// exclusive lock + } catch (Exception e) { + throw new IOException("another thread is writing ,refresh and try again"); + } + if (fileLock != null) { + fileChannel.write(ByteBuffer.wrap(str.getBytes())); + if (fileLock.isValid()) { + fileLock.release(); // release-write-lock + } + if (file.length() != str.getBytes().length) { + throw new IOException("write successfully but the content was lost, reedit and try again"); + } + } + + } catch (IOException e) { + logger.error(e.getMessage()); + throw new IOException(e.getMessage()); + } finally { + close(fileChannel); + close(fileOutputStream); + } + return true; + } + + public static void close(Closeable closeable) { + if (closeable != null) { + try { + closeable.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + public static boolean createNewFile(String filePath) { + return createNewFile(new File(filePath)); + } + + public static boolean createNewFile(File file) { + try { + if (!file.exists()) { + if (!file.getParentFile().exists()) { + if (!file.getParentFile().mkdirs()) { + return false; + } + } + if (!file.createNewFile()) { + return false; + } + } + } catch (IOException e) { + logger.error("create file failed , path = {}", file.getAbsoluteFile()); + return false; + } + return true; + } + +} diff --git a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogPrinter.java similarity index 81% rename from java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogPrinter.java index 897700de63..3bfe6f039f 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/FlowLogPrinter.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogPrinter.java @@ -13,11 +13,11 @@ * See the License for the specific language governing 
permissions and * limitations under the License. */ -package com.osx.core.utils; - -import com.osx.core.context.Context; - -public interface FlowLogPrinter { - - public void print(Context context); -} +//package org.fedai.osx.core.utils; +// +// +// +//public interface FlowLogPrinter { +// +// public void print(Context context); +//} diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogUtil.java new file mode 100644 index 0000000000..20081e3434 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/FlowLogUtil.java @@ -0,0 +1,18 @@ +package org.fedai.osx.core.utils; + +import org.fedai.osx.api.context.Context; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FlowLogUtil { + static Logger logger = LoggerFactory.getLogger("flow"); + + public static void printFlowLog(Context context) { + try { + logger.info(context.toString()); + }catch (Throwable ignore){ + } + + } + +} diff --git a/java/osx/core/src/main/java/com/osx/core/utils/GetSystemInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/GetSystemInfo.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/utils/GetSystemInfo.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/GetSystemInfo.java index 4192462b6f..a4ef5c6029 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/GetSystemInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/GetSystemInfo.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; diff --git a/java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/JsonUtil.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/JsonUtil.java index ef2d441cdc..990eeadeb0 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/JsonUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/JsonUtil.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.core.type.TypeReference; @@ -58,8 +58,8 @@ public static T json2Object(String json, Class c) { T t = null; try { t = mapper.readValue(json, c); - } catch (IOException e) { - e.printStackTrace(); + } catch (IOException igore) { + } return t; } @@ -82,8 +82,7 @@ public static T json2List(String json, TypeReference typeReference) { T result = null; try { result = mapper.readValue(json, typeReference); - } catch (IOException e) { - e.printStackTrace(); + } catch (IOException igore) { } return result; } @@ -123,7 +122,7 @@ public static T object2Objcet(Object source, TypeReference tr) { } public static String formatJson(String jsonStr) { - return formatJson(jsonStr, " "); + return formatJson(jsonStr, "\t"); } /*** @@ -205,7 +204,7 @@ public static String pbToJson(MessageOrBuilder message) { public static void main(String[] args) { - String s = JsonUtil.formatJson("{\"route_table\":{\"default\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":9999,\"useSSL\":false}]},\"10000\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":8889}],\"serving\":[{\"ip\":\"127.0.0.1\",\"port\":8080}]},\"123\":[{\"host\":\"10.35.27.23\",\"port\":8888,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\"}]},\"permission\":{\"default_allow\":true}}"); + String s = JsonUtil.formatJson("{\"route_table\":{\"default\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":9999,\"useSSL\":false}]},\"10000\":{\"default\":[{\"ip\":\"127.0.0.1\",\"port\":8889}],\"serving\":[{\"ip\":\"127.0.0.1\",\"port\":8080}]},\"123\":[{\"host\":\"127.0.0.1\",\"port\":8888,\"useSSL\":false,\"negotiationType\":\"\",\"certChainFile\":\"\",\"privateKeyFile\":\"\",\"caFile\":\"\"}]},\"permission\":{\"default_allow\":true}}"); System.out.println(s); } diff --git a/java/osx/core/src/main/java/com/osx/core/utils/NetUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/NetUtils.java similarity index 99% rename from java/osx/core/src/main/java/com/osx/core/utils/NetUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/NetUtils.java index 8834c29d36..64af87c47a 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/NetUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/NetUtils.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java new file mode 100644 index 0000000000..dbbffee885 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java @@ -0,0 +1,144 @@ +package org.fedai.osx.core.utils; + + +import org.fedai.osx.core.config.MetaInfo; +import sun.misc.BASE64Decoder; +import sun.security.x509.X509CertImpl; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import java.io.*; +import java.security.*; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.concurrent.atomic.AtomicInteger; + +/*** + * certificates type conversion + */ +public class OSXCertUtils { + private static final int I0 = 0; + private static final int I1 = 1; + private static final String type = "PKCS12"; + private static final AtomicInteger keyStoreCount = new AtomicInteger(1); + + /*** + * x509 certificate packaged into p12 certificate + * @param chain cert chain, issue cert +> superior cert +> ... + * @param privateKey issued cert private key + * @param filePath path to save p12 the cert + * @param alias alias + * @throws Exception NoCert, NoSuchAlgorithm , NoKeyStore, io + */ + public static void x509ToPkCS12(Certificate[] chain, Key privateKey, String filePath, String alias) throws Exception { + try (OutputStream os = new FileOutputStream(filePath)) { + KeyStore keyStore = KeyStore.getInstance(type); + keyStore.load(null); + keyStore.setKeyEntry(alias, privateKey, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), chain); + keyStore.store(os, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + } + } + + + /*** + * get x509 cert and private key by p12 cert + * @param filePath p12 cert file + * @param cs p1:storePassword p2:certPassword + * @return x509 certificate and private key + * @throws Exception NoCert, NoSuchAlgorithm , NoKeyStore, io + */ + public static X509AndKey getX509AndKeyByPkCS12(String filePath, String... 
cs) throws Exception { + try (InputStream is = new FileInputStream(filePath)) { + KeyStore keyStore = KeyStore.getInstance(type); + keyStore.load(is, toCharArray(I0, cs)); + String alias = keyStore.aliases().nextElement(); + return new X509AndKey(((X509CertImpl) keyStore.getCertificate(alias)), + ((PrivateKey) keyStore.getKey(alias, toCharArray(I1, cs)))); + } + } + + public static SSLContext getSSLContext(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = getKeyStore(caPath, clientCertPath, clientKeyPath); + // Initialize the ssl context object + SSLContext sslContext = SSLContext.getInstance("SSL"); + TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; + // Load client certificate + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); + return sslContext; + } + + public static KeyStore getKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(null); + keyStore.setKeyEntry(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS, importPrivateKey(clientKeyPath), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), + new Certificate[]{importCert(clientCertPath), importCert(caPath)}); + return keyStore; + } + + public static KeyStore getTrustStore(String keyStorePath, String trustStoreType) throws Exception { + KeyStore keyStore = KeyStore.getInstance(trustStoreType.toUpperCase()); + keyStore.load(new FileInputStream(new File(keyStorePath)), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + return keyStore; + } + + public static String createKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + PrivateKey privateKey = importPrivateKey(clientKeyPath); +// Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; + Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; + String pfxPath = OSXCertUtils.getTempStorePath(); + File pfxFile = new File(pfxPath); + FileUtils.createNewFile(pfxFile); + OSXCertUtils.x509ToPkCS12(certificates, privateKey, pfxPath, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS); + return pfxPath; + } + + public static Certificate importCert(String certFile) throws Exception { + try (FileInputStream certStream = new FileInputStream(certFile)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + return cf.generateCertificate(certStream); + } + } + + // Import private key + public static PrivateKey importPrivateKey(String privateKeyFile) throws Exception { + try (FileInputStream keyStream = new FileInputStream(privateKeyFile)) { + String space = ""; + byte[] bytes = new byte[keyStream.available()]; + int length = keyStream.read(bytes); + String keyString = new String(bytes, 0, length); + if (keyString.startsWith("-----BEGIN PRIVATE KEY-----\n")) { + keyString = keyString.replace("-----BEGIN PRIVATE KEY-----\n", space).replace("-----END PRIVATE KEY-----", space); + } + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(new BASE64Decoder().decodeBuffer(keyString)); + return KeyFactory.getInstance("RSA").generatePrivate(keySpec); + } + } + + + //determine whether the string is null and get the default string character array + private static char[] toCharArray(int index, String... str) { + return str.length <= index || str[index] == null ? 
MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray() : str[index].toCharArray(); + } + + public static String getTempStorePath(){ + return ""; + } + + /*** + * this class pack X509Certificate and privateKey + */ + public static class X509AndKey { + private final X509CertImpl x509Cert; + private final PrivateKey privateKey; + + public X509AndKey(X509CertImpl x509Certificate, PrivateKey privateKey) { + this.x509Cert = x509Certificate; + this.privateKey = privateKey; + } + + } +} diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java new file mode 100644 index 0000000000..1faf330ce1 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java @@ -0,0 +1,247 @@ +package org.fedai.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.*; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +public class OsxX509TrustManager implements X509TrustManager { + private static final Logger logger = LoggerFactory.getLogger(OsxX509TrustManager.class); + public static final String tabs = "%2F", equalSign = "%3D"; + + private final X509TrustManager x509TrustManager; + + public OsxX509TrustManager(X509TrustManager x509TrustManager) { + this.x509TrustManager = x509TrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + try { + if (this.x509TrustManager == null) return; + this.x509TrustManager.checkClientTrusted(chain, authType); + } catch (CertificateException exc) { + logger.error(exc.getMessage()); + } + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + // sunJSSEX509TrustManager.checkServerTrusted(chain, authType); +// if (checkServer) { +// for (X509Certificate x509Certificate : chain) { +// // Use ca certificate verify +// verify(caX509Certificate, x509Certificate); +// +// // Send ocsp request verify +// ocspVerify(x509Certificate); +// } +// } + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + if (this.x509TrustManager == null) return null; + return this.x509TrustManager.getAcceptedIssuers(); + } + + public static OsxX509TrustManager getInstance() { + return new OsxX509TrustManager(null); + } + + public static OsxX509TrustManager getInstance(KeyStore keyStore) throws NoSuchProviderException, NoSuchAlgorithmException, KeyStoreException { + X509TrustManager x509TrustManager = null; + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509", "SunJSSE"); + tmf.init(keyStore); + TrustManager[] tms = tmf.getTrustManagers(); + for (TrustManager tm : tms) { + if (tm instanceof X509TrustManager) { + x509TrustManager = (X509TrustManager) tm; + break; + } + } + return new OsxX509TrustManager(x509TrustManager); + } + + // Verify that the certificate if expired, and is issued for the root certificate +// public static void verify(X509Certificate superiorCert, X509Certificate issueCert) throws CertificateException { +// try { +// issueCert.checkValidity(); +// issueCert.verify(superiorCert.getPublicKey()); +// } catch (Exception e) { +// throw new CertificateException(e); +// } +// } + + // Obtain ocsp service address from the certificate and verify the validity of 
the certificate +// public static void ocspVerify(X509Certificate x509Certificate) throws CertificateException { +// X509CertImpl x509Cert = (X509CertImpl) x509Certificate; +// AuthorityInfoAccessExtension accessExtension = x509Cert.getAuthorityInfoAccessExtension(); +// List accessDescriptions = accessExtension.getAccessDescriptions(); +// for (AccessDescription accessDescription : accessDescriptions) { +// String anObject = accessDescription.getAccessMethod().toString(); +// if ("ocsp".equals(anObject) || "1.3.6.1.5.5.7.48.1".equals(anObject)) { +// GeneralName accessLocation = accessDescription.getAccessLocation(); +// URI ocspUrl = ((URIName) accessLocation.getName()).getURI(); +// goSendOCSP(ocspUrl.toString(), x509Cert); +// } +// } +// } + + // Send ocsp request +// public static void goSendOCSP(String ocspUrl, X509CertImpl x509Certificate) throws CertificateException { +// try { +// URL url = new URL(ocspUrl + "/" + getOcspRequestData(x509Certificate)); +// HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection(); +// urlConnection.setConnectTimeout(5000); +// urlConnection.setReadTimeout(15000); +// urlConnection.setRequestMethod("GET"); +// urlConnection.setDoOutput(true); +// urlConnection.setDoInput(true); +// urlConnection.setRequestProperty("Content-type", "application/ocsp-request"); +// +// try (InputStream br = urlConnection.getInputStream(); +// ByteArrayOutputStream aos = new ByteArrayOutputStream() +// ) { +// int len; +// byte[] bytes = new byte[br.available()]; +// while ((len = br.read(bytes)) != -1) { +// aos.write(bytes, 0, len); +// } +// OCSPResponse ocspResponse = new OCSPResponse(aos.toByteArray()); +// OCSPResponse.ResponseStatus responseStatus = ocspResponse.getResponseStatus(); +// +// if (!responseStatus.equals(OCSPResponse.ResponseStatus.SUCCESSFUL)) { +// throw new CertificateException("ocsp request failed, request state: " + responseStatus); +// } +// +// Set certIds = ocspResponse.getCertIds(); +// for (CertId certId : certIds) { +// // Date nextUpdate = singleResponse.getNextUpdate(); +// // CRLReason revocationReason = singleResponse.getRevocationReason(); +// // Date thisUpdate = singleResponse.getThisUpdate(); +// OCSPResponse.SingleResponse singleResponse = ocspResponse.getSingleResponse(certId); +// OCSP.RevocationStatus.CertStatus certStatus = singleResponse.getCertStatus(); +// System.out.println("server certificate serial number: " + certId.getSerialNumber().toString(16) + ", status: " + certStatus); +// +// if (!OCSP.RevocationStatus.CertStatus.GOOD.equals(certStatus)) { +// // throw new CertificateException("服务器验证失败, 证书状态: " + certStatus); +// } +// } +// +// +// } catch (Exception e) { +// throw new CertificateException(e); +// } +// } catch (IOException e) { +// throw new CertificateException(e); +// } +// } + + // get ocsp request bytes +// private static byte[] getOcspRequestBytesData(X509CertImpl x509Certificate) throws IOException { +// return new OCSPRequest(new CertId(x509Certificate, x509Certificate.getSerialNumberObject())).encodeBytes(); +// } + + // get Base64 encode ocsp request url string parameter +// private static String getOcspRequestData(X509CertImpl certificate) throws IOException { +// CertId certId = new CertId(certificate, certificate.getSerialNumberObject()); +// OCSPRequest request = new OCSPRequest(certId); +// String encodeBuffer = new BASE64Encoder().encodeBuffer(request.encodeBytes()); +// return encodeBuffer.replace("\r\n", "").replace("/", tabs).replace("=", equalSign); +// } + + // OCSPRequest 
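
Taken together, the active (non-commented) paths in OSXCertUtils and OsxX509TrustManager build a TLS context that trusts a configured CA and presents a client certificate from the configured key store. For readers who want the idea without the Java keystore plumbing, here is a minimal Python sketch of the same outcome using only the standard ssl module; the PEM file paths are placeholders rather than files shipped by this patch, and unlike the relaxed checkServerTrusted above, this sketch keeps server-certificate verification enabled.

```python
import ssl

def make_client_ssl_context(ca_path: str, cert_path: str, key_path: str) -> ssl.SSLContext:
    """Sketch only: trust `ca_path` and present a client certificate, the rough
    equivalent of OSXCertUtils.getSSLContext on the Java side (which builds the
    same thing from a PKCS#12 key store instead of PEM files)."""
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_path)
    ctx.load_cert_chain(certfile=cert_path, keyfile=key_path)
    return ctx

# Placeholder usage (paths and hostname are illustrative):
# ctx = make_client_ssl_context("ca.pem", "client.pem", "client.key")
# tls_sock = ctx.wrap_socket(sock, server_hostname="osx-peer.example.org")
```
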
+// private static class OCSPRequest { +// private static final Debug debug = Debug.getInstance("certpath"); +// private static final boolean dump; +// private final List certIds; +// private final List extensions; +// private byte[] nonce; +// +// public OCSPRequest(CertId certId) { +// this(Collections.singletonList(certId)); +// } +// +// public OCSPRequest(List certIdList) { +// this.certIds = certIdList; +// this.extensions = Collections.emptyList(); +// } +// +// public OCSPRequest(List certIdList, List extensionList) { +// this.certIds = certIdList; +// this.extensions = extensionList; +// } +// +// public byte[] encodeBytes() throws IOException { +// DerOutputStream fillDOS = new DerOutputStream(); +// DerOutputStream certIdDOS = new DerOutputStream(); +// +// for (CertId certId : this.certIds) { +// DerOutputStream encodeDos = new DerOutputStream(); +// certId.encode(encodeDos); +// certIdDOS.write((byte) 48, encodeDos); +// } +// +// fillDOS.write((byte) 48, certIdDOS); +// DerOutputStream extensionDos; +// DerOutputStream endDos; +// if (!this.extensions.isEmpty()) { +// extensionDos = new DerOutputStream(); +// +// for (java.security.cert.Extension extension : this.extensions) { +// extension.encode(extensionDos); +// if (extension.getId().equals(PKIXExtensions.OCSPNonce_Id.toString())) { +// this.nonce = extension.getValue(); +// } +// } +// +// endDos = new DerOutputStream(); +// endDos.write((byte) 48, extensionDos); +// fillDOS.write(DerValue.createTag((byte) -128, true, (byte) 2), endDos); +// } +// +// extensionDos = new DerOutputStream(); +// extensionDos.write((byte) 48, fillDOS); +// endDos = new DerOutputStream(); +// endDos.write((byte) 48, extensionDos); +// byte[] bytes = endDos.toByteArray(); +// if (dump) { +// HexDumpEncoder dumpEncoder = new HexDumpEncoder(); +// debug.println("OCSPRequest bytes...\n\n" + dumpEncoder.encode(bytes) + "\n"); +// } +// +// return bytes; +// } +// +// public List getCertIds() { +// return this.certIds; +// } +// +// public byte[] getNonce() { +// return this.nonce; +// } +// +// static { +// dump = debug != null && Debug.isOn("ocsp"); +// } +// } + + public static class HostnameVerifier2 implements HostnameVerifier { + + @Override + public boolean verify(String s, SSLSession sslSession) { + return true; + } + + public static HostnameVerifier2 getInstance() { + return new HostnameVerifier2(); + } + } +} diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/PropertiesUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/PropertiesUtil.java new file mode 100644 index 0000000000..eb879e2006 --- /dev/null +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/PropertiesUtil.java @@ -0,0 +1,131 @@ +package org.fedai.osx.core.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.util.Properties; + + + +public final class PropertiesUtil +{ + public static Logger logger = LoggerFactory.getLogger(PropertiesUtil.class); + + public static Properties getProperties(String path) + { + Properties prop = new Properties(); + + loadProp(prop, path); + + return prop; + } + + private static void loadProp(Properties p, String conf) + { + InputStream is = getInputStream(conf); + + if(null != is) + { + try + { + p.load(is); + } + catch (IOException e) + { + logger.info("file not found!"); + } + finally + { + if(is != null) + { + try + { + is.close(); + } + catch (IOException e) + { + logger.info("stream close fail!"); + } + } + } + } + } + + //获取输入流 + private static 
InputStream getInputStream(String conf) + { + File file = new File(conf); + InputStream is = null; + try { + is = new BufferedInputStream(new FileInputStream(file)); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + return is; + } + + //获取输出流 + private static OutputStream getOutPutStream(String conf) + { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + + OutputStream out = null; + + if(null != classLoader) + { + String filePath = classLoader.getResource(conf).getFile(); + try + { + out = new FileOutputStream(filePath); + } + catch (FileNotFoundException e) + { + logger.info("file not found!!!"); + } + } + return out; + } + + //根据key读取value + public static String getValue(Properties p, String key) + { + String value = p.getProperty(key); + + return value == null?"":value; + } + + //设置key=value + public static void setValue(String conf, String key, String value) + { + Properties p = getProperties(conf); + + OutputStream out = getOutPutStream(conf); + + p.setProperty(key, value); + + try + { + p.store(out, "set:"+key+"="+value); + } + catch (IOException e) + { + logger.info("set properties fail!!!"); + } + finally + { + if(out != null) + { + try + { + out.close(); + } + catch (IOException e) + { + logger.info("stream close fail!"); + } + } + } + } + +} \ No newline at end of file diff --git a/java/osx/core/src/main/java/com/osx/core/utils/RouterUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/RouterUtil.java similarity index 96% rename from java/osx/core/src/main/java/com/osx/core/utils/RouterUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/RouterUtil.java index a3c3040298..cd518d2885 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/RouterUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/RouterUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; public class RouterUtil { diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ServerUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ServerUtil.java similarity index 98% rename from java/osx/core/src/main/java/com/osx/core/utils/ServerUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ServerUtil.java index fe3ad9dd6e..866ff94fb9 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/ServerUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ServerUtil.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import org.apache.commons.cli.*; diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ThreadPoolUtil.java similarity index 97% rename from java/osx/core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ThreadPoolUtil.java index f20cc25f58..08f622843f 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/ThreadPoolUtil.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ThreadPoolUtil.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; diff --git a/java/osx/core/src/main/java/com/osx/core/utils/ToStringUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ToStringUtils.java similarity index 94% rename from java/osx/core/src/main/java/com/osx/core/utils/ToStringUtils.java rename to java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ToStringUtils.java index cb7a98ef68..ca1863e342 100644 --- a/java/osx/core/src/main/java/com/osx/core/utils/ToStringUtils.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/ToStringUtils.java @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.osx.core.utils; +package org.fedai.osx.core.utils; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; -import com.osx.core.constant.Dict; +import org.fedai.osx.core.constant.Dict; public class ToStringUtils { diff --git a/java/osx/pom.xml b/java/osx/pom.xml index 6ca280af42..eabf58e4ed 100644 --- a/java/osx/pom.xml +++ b/java/osx/pom.xml @@ -9,28 +9,21 @@ pom ${osx.version} - core - broker - - - - - - - - + osx-core + osx-broker + osx-api - 1.0.0-alpha + 1.0.0-beta 1.8 UTF-8 UTF-8 - 31.1-jre - 1.51.1 + 32.1.2-jre + 1.55.1 1.18.24 3.21.12 - 5.26 + 0.6.1 1.6.1 1.7.36 @@ -57,15 +50,19 @@ 1.10.0 4.13.2 5.12.1 + 3.8.0 + - - - - - + + + com.lmax + disruptor + ${disruptor.version} + + org.slf4j slf4j-api @@ -121,6 +118,12 @@ ${jetty.version} + + + + + + commons-io commons-io @@ -209,6 +212,12 @@ ${flatbuffers-java.version} + + commons-net + commons-net + ${commons-net.version} + + com.google.flatbuffers flatbuffers-java-grpc @@ -224,19 +233,16 @@ grpc-core ${grpc.version} - io.grpc grpc-netty-shaded ${grpc.version} - io.grpc grpc-protobuf ${grpc.version} - io.grpc grpc-stub @@ -247,54 +253,41 @@ protobuf-java-util ${protobuf.version} - - - com.fasterxml.jackson.core jackson-annotations ${jackson.version} - com.fasterxml.jackson.core jackson-databind ${jackson.version} - io.grpc grpc-core ${grpc.version} - io.grpc grpc-netty-shaded ${grpc.version} - io.grpc grpc-protobuf ${grpc.version} - - io.grpc grpc-stub ${grpc.version} - - com.googlecode.protobuf-java-format protobuf-java-format ${protobuf-java-format.version} - - org.apache.httpcomponents httpclient @@ -384,9 +377,6 @@ - - - org.apache.maven.plugins diff --git a/java/osx/proto/osx.proto b/java/osx/proto/osx.proto index f75453ada1..b8aca166e1 100644 --- a/java/osx/proto/osx.proto +++ b/java/osx/proto/osx.proto @@ -51,9 +51,16 @@ enum Metadata { SourceComponentName = 2; // 源组件名称 TargetComponentName = 3; // 目标组件名称 TargetMethod = 4; // 目标方法 - MessageOffSet = 5; // 消息序列号 - InstanceId = 6; // 实例ID - Timestamp = 7; // 时间戳 + SourceMethod = 5; // 协议标志 + MessageOffSet = 6; // 消息序列号 + InstanceId = 7; // 实例ID + Timestamp = 8; // 时间戳 + MessageFlag = 9; // 消息标志 + MessageTopicBack = 10; // 接受应答消息队列 + RetryCount = 11; // 重试次数 + Timeout = 12; // 超时时间 + JobId = 13; //jobId + } // 通信传输层输入报文编码 @@ -70,6 +77,33 @@ message Outbound { string message = 4; // 状态说明 } +message PeekInbound { + string topic = 1; // optional 会话主题,相同信道具有唯一性,用于同一信道的传输隔离 +} + +message PopInbound { + string topic = 1; // optional 会话主题,相同信道具有唯一性,用于同一信道的传输隔离 + int32 timeout = 2; // optional 阻塞超时时间,默认120s +} + +message PushInbound{ + bytes payload = 1; // 二进制报文 + string topic = 2; // 
optional 会话主题,相同信道具有唯一性,用于同一信道的传输隔离 + map metadata = 3; // optional 保留参数,用于扩展性 +} + +message ReleaseInbound { + string topic = 1; // optional 会话主题,相同信道具有唯一性,用于同一信道的传输隔离 + int32 timeout = 2; // optional 阻塞超时时间, +} + +message TransportOutbound { + bytes payload = 1; // 二进制报文 + string code = 2; // 状态码 + string message = 3; // 状态说明 +} + + // 互联互通如果使用异步传输协议作为标准参考,Header会复用metadata传输互联互通协议报头,且metadata中会传输异步场景下的消息相关属性 // 互联互通如果使用其他协议作为参考标准,Header会复用metadata传输互联互通协议报头 // 互联互通如果使用GRPC作为参考标准,Header会复用HTTP2的报头传输互联互通协议报头 @@ -77,6 +111,12 @@ message Outbound { service PrivateTransferProtocol { rpc transport (stream Inbound) returns (stream Outbound); rpc invoke (Inbound) returns (Outbound); + rpc test(stream Inbound) returns (stream Outbound); + + rpc peek (PeekInbound) returns (TransportOutbound); + rpc pop (PopInbound) returns (TransportOutbound); + rpc push (PushInbound) returns (TransportOutbound); + rpc release (ReleaseInbound) returns (TransportOutbound); } @@ -85,3 +125,7 @@ service PrivateTransferProtocol { + + + + diff --git a/python/fate/_info.py b/python/fate/_info.py index bc58543500..22cd41b09e 100644 --- a/python/fate/_info.py +++ b/python/fate/_info.py @@ -12,5 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.0.0-alpha" +__version__ = "2.0.0-beta" __provider__ = "fate" diff --git a/python/fate/arch/__init__.py b/python/fate/arch/__init__.py index d87314c42f..714de7533b 100644 --- a/python/fate/arch/__init__.py +++ b/python/fate/arch/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # -from .context import Context -from .unify import Backend, device +from .context import CipherKit, Context +from .unify import URI, Backend, device -__all__ = ["Backend", "device", "Context"] +__all__ = ["Backend", "device", "Context", "URI", "CipherKit"] diff --git a/python/fate/arch/_standalone.py b/python/fate/arch/_standalone.py index a5e021af77..4c9fc90168 100644 --- a/python/fate/arch/_standalone.py +++ b/python/fate/arch/_standalone.py @@ -14,13 +14,15 @@ # limitations under the License. 
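
The osx.proto additions above extend PrivateTransferProtocol with queue-style RPCs: push enqueues a payload on a topic, pop blocks (up to an optional timeout) until a message is available, peek inspects without consuming, and release frees the topic. A hedged client sketch follows; the generated module names osx_pb2 / osx_pb2_grpc and the target address are assumptions of the sketch, while the service, message, and field names come from the proto itself.

```python
import grpc

# Assumption: Python stubs generated from osx.proto with the usual grpc_tools naming.
import osx_pb2
import osx_pb2_grpc

def push_then_pop(address: str, topic: str, payload: bytes) -> bytes:
    channel = grpc.insecure_channel(address)
    stub = osx_pb2_grpc.PrivateTransferProtocolStub(channel)

    # push: enqueue a binary payload on a topic-scoped channel
    stub.push(osx_pb2.PushInbound(topic=topic, payload=payload))

    # pop: block until a message arrives (timeout in seconds, per the proto comment)
    out = stub.pop(osx_pb2.PopInbound(topic=topic, timeout=120))

    # TransportOutbound also carries `code` and `message` status fields
    return out.payload
```
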
# -import asyncio import hashlib import itertools import logging +import logging.config import os import pickle as c_pickle import shutil +import signal +import threading import time import uuid from collections.abc import Iterable @@ -30,12 +32,12 @@ from heapq import heapify, heappop, heapreplace from operator import is_not from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Tuple import cloudpickle as f_pickle import lmdb import numpy as np -from fate.interface import PartyMeta +from fate.arch.abc import PartyMeta from .federation import FederationDataType @@ -49,6 +51,7 @@ if (STANDALONE_DATA_PATH := os.getenv("STANDALONE_DATA_PATH")) is not None: _data_dir = Path(STANDALONE_DATA_PATH) + LOGGER.debug(f"env STANDALONE_DATA_PATH is set to {STANDALONE_DATA_PATH}, using {_data_dir} as data dir") else: _data_dir = Path( os.path.abspath( @@ -57,6 +60,74 @@ ) ) ) + LOGGER.debug(f"env STANDALONE_DATA_PATH is not set, using {_data_dir} as data dir") + + +def _watch_thread_react_to_parent_die(ppid, logger_config): + """ + this function is call when a process is created, and it will watch parent process and initialize loggers + Args: + ppid: parent process id + """ + + # watch parent process, if parent process is dead, then kill self + # the trick is to use os.kill(ppid, 0) to check if parent process is alive periodically + # and if parent process is dead, then kill self + # + # Note: this trick is modified from the answer by aaron: https://stackoverflow.com/a/71369760/14697733 + pid = os.getpid() + + def f(): + while True: + try: + os.kill(ppid, 0) + except OSError: + os.kill(pid, signal.SIGTERM) + time.sleep(1) + + thread = threading.Thread(target=f, daemon=True) + thread.start() + + # initialize loggers + if logger_config is not None: + logging.config.dictConfig(logger_config) + # else: + # level = os.getenv("DEBUG_MODE_LOG_LEVEL", "DEBUG") + # try: + # import rich.logging + # + # logging_class = "rich.logging.RichHandler" + # logging_formatters = {} + # handlers = { + # "console": { + # "class": logging_class, + # "level": level, + # "filters": [], + # } + # } + # except ImportError: + # logging_class = "logging.StreamHandler" + # logging_formatters = { + # "console": { + # "format": "[%(levelname)s][%(asctime)-8s][%(process)s][%(module)s.%(funcName)s][line:%(lineno)d]: %(message)s" + # } + # } + # handlers = { + # "console": { + # "class": logging_class, + # "level": level, + # "formatter": "console", + # } + # } + # logging.config.dictConfig(dict( + # version=1, + # formatters=logging_formatters, + # handlers=handlers, + # filters={}, + # loggers={}, + # root=dict(handlers=["console"], level="DEBUG"), + # disable_existing_loggers=False, + # )) # noinspection PyPep8Naming @@ -106,11 +177,7 @@ def destroy(self): db = env.open_db() with env.begin(write=True) as txn: txn.drop(db) - - table_key = f"{self._namespace}.{self._name}" - _get_meta_table().delete(table_key) - path = _get_storage_dir(self._namespace, self._name) - shutil.rmtree(path, ignore_errors=True) + _TableMetaManager.destroy_table(self._namespace, self._name) def take(self, n, **kwargs): if n <= 0: @@ -250,7 +317,9 @@ def union(self, other: "Table", func=lambda v1, v2: v1): # noinspection PyProtectedMember def _map_reduce(self, mapper, reducer): - results = self._session._submit_map_reduce(mapper, reducer, self._partitions, self._name, self._namespace) + results = self._session._submit_map_reduce_in_partition( + mapper, reducer, self._partitions, self._name, self._namespace + ) result = 
results[0] # noinspection PyProtectedMember return _create_table( @@ -358,9 +427,16 @@ def delete(self, k): # noinspection PyMethodMayBeStatic class Session(object): - def __init__(self, session_id, max_workers=None): + def __init__(self, session_id, max_workers=None, logger_config=None): self.session_id = session_id - self._pool = Executor(max_workers=max_workers) + self._pool = Executor( + max_workers=max_workers, + initializer=_watch_thread_react_to_parent_die, + initargs=( + os.getpid(), + logger_config, + ), + ) def __getstate__(self): # session won't be pickled @@ -393,12 +469,11 @@ def parallelize(self, data: Iterable, partition: int, include_key: bool = False, return table def cleanup(self, name, namespace): - data_path = _get_data_dir() - if not data_path.is_dir(): - LOGGER.error(f"illegal data dir: {data_path}") + if not _data_dir.is_dir(): + LOGGER.error(f"illegal data dir: {_data_dir}") return - namespace_dir = data_path.joinpath(namespace) + namespace_dir = _data_dir.joinpath(namespace) if not namespace_dir.is_dir(): return @@ -426,7 +501,9 @@ def _submit_unary(self, func, _do_func, partitions, name, namespace): ) futures = [] for p in range(partitions): - futures.append(self._pool.submit(_do_func, _UnaryProcess(task_info, _Operand(namespace, name, p)))) + futures.append( + self._pool.submit(_do_func, _UnaryProcess(task_info, _Operand(namespace, name, p, partitions))) + ) results = [r.result() for r in futures] return results @@ -442,7 +519,7 @@ def _submit_map_reduce_in_partition(self, mapper, reducer, partitions, name, nam futures.append( self._pool.submit( _do_map_reduce_in_partitions, - _MapReduceProcess(task_info, _Operand(namespace, name, p)), + _MapReduceProcess(task_info, _Operand(namespace, name, p, partitions)), ) ) results = [r.result() for r in futures] @@ -456,10 +533,13 @@ def _submit_binary(self, func, do_func, partitions, name, namespace, other_name, ) futures = [] for p in range(partitions): - left = _Operand(namespace, name, p) - right = _Operand(other_namespace, other_name, p) + left = _Operand(namespace, name, p, partitions) + right = _Operand(other_namespace, other_name, p, partitions) futures.append(self._pool.submit(do_func, _BinaryProcess(task_info, left, right))) - results = [r.result() for r in futures] + results = [] + for f in futures: + r = f.result() + results.append(r) return results @@ -476,109 +556,26 @@ def _get_splits(obj, max_message_size): class Federation(object): - def _federation_object_key(self, name, tag, s_party, d_party): + def _federation_object_key(self, name: str, tag: str, s_party: Tuple[str, str], d_party: Tuple[str, str]): return f"{self._session_id}-{name}-{tag}-{s_party[0]}-{s_party[1]}-{d_party[0]}-{d_party[1]}" - def __init__(self, session: Session, session_id, party: Tuple[str, str]): + def __init__(self, session: Session, session_id: str, party: Tuple[str, str]): self._session_id = session_id self._party = party self._session = session self._max_message_size = DEFAULT_MESSAGE_MAX_SIZE self._other_status_tables = {} self._other_object_tables = {} - self._even_loop = None self._federation_status_table_cache = None self._federation_object_table_cache = None + self._meta = _FederationMetaManager(session_id, party) + def destroy(self): self._session.cleanup(namespace=self._session_id, name="*") - @property - def _federation_status_table(self): - if self._federation_status_table_cache is None: - self._federation_status_table_cache = _create_table( - session=self._session, - name=self._get_status_table_name(self._party), - 
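
One behavioural change in this part of _standalone.py deserves a plain restatement: Session now hands _watch_thread_react_to_parent_die to the process pool as an initializer, so every worker starts a daemon thread that probes the parent with os.kill(ppid, 0) and terminates itself once the parent is gone, and optionally applies a logging config. A stripped-down, FATE-independent sketch of the same pattern follows (function and pool names here are illustrative, not from the patch):

```python
import os
import signal
import threading
import time
from concurrent.futures import ProcessPoolExecutor

def _exit_when_parent_dies(ppid: int) -> None:
    """Worker initializer: poll the parent pid and self-terminate if it vanishes."""
    me = os.getpid()

    def watch() -> None:
        while True:
            try:
                os.kill(ppid, 0)  # signal 0 checks existence without sending anything
            except OSError:
                os.kill(me, signal.SIGTERM)
            time.sleep(1)

    threading.Thread(target=watch, daemon=True).start()

# Illustrative wiring, mirroring the Session constructor above:
# pool = ProcessPoolExecutor(
#     max_workers=4,
#     initializer=_exit_when_parent_dies,
#     initargs=(os.getpid(),),
# )
```
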
namespace=self._session_id, - partitions=1, - need_cleanup=True, - error_if_exist=False, - ) - return self._federation_status_table_cache - - @property - def _federation_object_table(self): - if self._federation_object_table_cache is None: - self._federation_object_table_cache = _create_table( - session=self._session, - name=self._get_object_table_name(self._party), - namespace=self._session_id, - partitions=1, - need_cleanup=True, - error_if_exist=False, - ) - return self._federation_object_table_cache - - @property - def _loop(self): - if self._even_loop is None: - self._even_loop = asyncio.get_event_loop() - return self._even_loop - - @staticmethod - def _get_status_table_name(party): - return f"__federation_status__.{party[0]}_{party[1]}" - - @staticmethod - def _get_object_table_name(party): - return f"__federation_object__.{party[0]}_{party[1]}" - - def _get_other_status_table(self, party): - if party in self._other_status_tables: - return self._other_status_tables[party] - table = _create_table( - self._session, - name=self._get_status_table_name(party), - namespace=self._session_id, - partitions=1, - need_cleanup=False, - error_if_exist=False, - ) - self._other_status_tables[party] = table - return table - - def _get_other_object_table(self, party): - if party in self._other_object_tables: - return self._other_object_tables[party] - table = _create_table( - self._session, - name=self._get_object_table_name(party), - namespace=self._session_id, - partitions=1, - need_cleanup=False, - error_if_exist=False, - ) - self._other_object_tables[party] = table - return table - - # noinspection PyProtectedMember - def _put_status(self, party, _tagged_key, value): - self._get_other_status_table(party).put(_tagged_key, value) - - # noinspection PyProtectedMember - def _put_object(self, party, _tagged_key, value): - self._get_other_object_table(party).put(_tagged_key, value) - - # noinspection PyProtectedMember - def _get_object(self, _tagged_key): - return self._federation_object_table.get(_tagged_key) - - # noinspection PyProtectedMember - def _get_status(self, _tagged_key): - return self._federation_status_table.get(_tagged_key) - # noinspection PyUnusedLocal - def remote(self, v, name: str, tag: str, parties: List[Tuple[str, str]]): + def remote(self, v, name: str, tag: str, parties: List[PartyMeta]): log_str = f"federation.standalone.remote.{name}.{tag}" if v is None: @@ -622,21 +619,20 @@ def remote(self, v, name: str, tag: str, parties: List[Tuple[str, str]]): f"Table(namespace={v.namespace}, name={saved_name}, partitions={v.partitions})" ) _v = v.save_as(name=saved_name, namespace=v.namespace, need_cleanup=False) - self._put_status(party, _tagged_key, (_v.name, _v.namespace, dtype)) + self._meta.set_status(party, _tagged_key, (_v.name, _v.namespace, dtype)) else: - self._put_object(party, _tagged_key, v) - self._put_status(party, _tagged_key, _tagged_key) + self._meta.set_object(party, _tagged_key, v) + self._meta.set_status(party, _tagged_key, _tagged_key) # noinspection PyProtectedMember def get(self, name: str, tag: str, parties: List[PartyMeta]) -> List: log_str = f"federation.standalone.get.{name}.{tag}" LOGGER.debug(f"[{log_str}]") - tasks = [] + results = [] for party in parties: _tagged_key = self._federation_object_key(name, tag, party, self._party) - tasks.append(_check_status_and_get_value(self._get_status, _tagged_key)) - results = self._loop.run_until_complete(asyncio.gather(*tasks)) + results.append(self._meta.wait_status_set(_tagged_key)) rtn = [] for r in results: @@ 
-657,61 +653,16 @@ def get(self, name: str, tag: str, parties: List[PartyMeta]) -> List: else: rtn.append(table) else: - obj = self._get_object(r) + obj = self._meta.get_object(r) if obj is None: raise EnvironmentError(f"federation get None from {parties} with name {name}, tag {tag}") rtn.append(obj) - self._federation_object_table.delete(k=r) + self._meta.ack_object(r) LOGGER.debug(f"[{log_str}] got object with type: {type(obj)}") - self._federation_status_table.delete(r) + self._meta.ack_status(r) return rtn -_meta_table: Optional[Table] = None - -_SESSION = Session(uuid.uuid1().hex) - - -def _get_meta_table(): - global _meta_table - if _meta_table is None: - _meta_table = Table( - _SESSION, - namespace="__META__", - name="fragments", - partitions=10, - need_cleanup=False, - ) - return _meta_table - - -# noinspection PyProtectedMember -def _get_from_meta_table(key): - return _get_meta_table().get(key) - - -# noinspection PyProtectedMember -def _put_to_meta_table(key, value): - _get_meta_table().put(key, value) - - -def _get_data_dir(): - return _data_dir - - -def _get_storage_dir(*args): - return _data_dir.joinpath(*args) - - -async def _check_status_and_get_value(get_func, key): - value = get_func(key) - while value is None: - await asyncio.sleep(0.1) - value = get_func(key) - LOGGER.debug("[GET] Got {} type {}".format(key, "Table" if isinstance(value, tuple) else "Object")) - return value - - def _create_table( session: "Session", name: str, @@ -720,16 +671,15 @@ def _create_table( need_cleanup=True, error_if_exist=False, ): - if isinstance(namespace, int): - raise ValueError(f"{namespace} {name}") - _table_key = ".".join([namespace, name]) - if _get_from_meta_table(_table_key) is not None: + assert isinstance(name, str) + assert isinstance(namespace, str) + assert isinstance(partitions, int) + if (exist_partitions := _TableMetaManager.get_table_meta(namespace, name)) is None: + _TableMetaManager.add_table_meta(namespace, name, partitions) + else: if error_if_exist: raise RuntimeError(f"table already exist: name={name}, namespace={namespace}") - else: - partitions = _get_from_meta_table(_table_key) - else: - _put_to_meta_table(_table_key, partitions) + partitions = exist_partitions return Table( session=session, @@ -740,14 +690,8 @@ def _create_table( ) -def _exist(name: str, namespace: str): - _table_key = ".".join([namespace, name]) - return _get_from_meta_table(_table_key) is not None - - -def _load_table(session, name, namespace, need_cleanup=False): - _table_key = ".".join([namespace, name]) - partitions = _get_from_meta_table(_table_key) +def _load_table(session, name: str, namespace: str, need_cleanup=False): + partitions = _TableMetaManager.get_table_meta(namespace, name) if partitions is None: raise RuntimeError(f"table not exist: name={name}, namespace={namespace}") return Table( @@ -793,10 +737,11 @@ def get_reducer(self): class _Operand: - def __init__(self, namespace, name, partition): + def __init__(self, namespace, name, partition, num_partitions: int): self.namespace = namespace self.name = name self.partition = partition + self.num_partitions = num_partitions def as_env(self, write=False): return _get_env(self.namespace, self.name, str(self.partition), write=write) @@ -808,7 +753,7 @@ def __init__(self, task_info: _TaskInfo, operand: _Operand): self.operand = operand def output_operand(self): - return _Operand(self.info.task_id, self.info.function_id, self.operand.partition) + return _Operand(self.info.task_id, self.info.function_id, self.operand.partition, 
self.operand.num_partitions) def get_func(self): return self.info.get_func() @@ -820,7 +765,7 @@ def __init__(self, task_info: _MapReduceTaskInfo, operand: _Operand): self.operand = operand def output_operand(self): - return _Operand(self.info.task_id, self.info.function_id, self.operand.partition) + return _Operand(self.info.task_id, self.info.function_id, self.operand.partition, self.operand.num_partitions) def get_mapper(self): return self.info.get_mapper() @@ -836,14 +781,14 @@ def __init__(self, task_info: _TaskInfo, left: _Operand, right: _Operand): self.right = right def output_operand(self): - return _Operand(self.info.task_id, self.info.function_id, self.left.partition) + return _Operand(self.info.task_id, self.info.function_id, self.left.partition, self.left.num_partitions) def get_func(self): return self.info.get_func() def _get_env(*args, write=False): - _path = _get_storage_dir(*args) + _path = _data_dir.joinpath(*args) return _open_env(_path, write=write) @@ -890,9 +835,8 @@ def _do_map(p: _UnaryProcess): rtn = p.output_operand() with ExitStack() as s: source_env = s.enter_context(p.operand.as_env()) - partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}") txn_map = {} - for partition in range(partitions): + for partition in range(p.operand.num_partitions): env = s.enter_context(_get_env(rtn.namespace, rtn.name, str(partition), write=True)) txn_map[partition] = s.enter_context(env.begin(write=True)) source_txn = s.enter_context(source_env.begin()) @@ -901,7 +845,7 @@ def _do_map(p: _UnaryProcess): k, v = deserialize(k_bytes), deserialize(v_bytes) k1, v1 = p.get_func()(k, v) k1_bytes, v1_bytes = serialize(k1), serialize(v1) - partition = _hash_key_to_partition(k1_bytes, partitions) + partition = _hash_key_to_partition(k1_bytes, p.operand.num_partitions) txn_map[partition].put(k1_bytes, v1_bytes) return rtn @@ -925,7 +869,7 @@ def _do_apply_partitions(p: _UnaryProcess): if cursor.last(): k_bytes = cursor.key() dst_txn.put(k_bytes, serialize(v)) - return rtn + return rtn def _do_map_partitions(p: _UnaryProcess): @@ -946,7 +890,7 @@ def _do_map_partitions(p: _UnaryProcess): else: k_bytes = cursor.key() dst_txn.put(k_bytes, serialize(v)) - return rtn + return rtn def _do_map_partitions_with_index(p: _UnaryProcess): @@ -967,16 +911,15 @@ def _do_map_partitions_with_index(p: _UnaryProcess): else: k_bytes = cursor.key() dst_txn.put(k_bytes, serialize(v)) - return rtn + return rtn def _do_map_reduce_in_partitions(p: _MapReduceProcess): rtn = p.output_operand() with ExitStack() as s: source_env = s.enter_context(p.operand.as_env()) - partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}") txn_map = {} - for partition in range(partitions): + for partition in range(p.operand.num_partitions): env = s.enter_context(_get_env(rtn.namespace, rtn.name, str(partition), write=True)) txn_map[partition] = s.enter_context(env.begin(write=True)) source_txn = s.enter_context(source_env.begin()) @@ -988,7 +931,7 @@ def _do_map_reduce_in_partitions(p: _MapReduceProcess): for k, v in mapped: k_bytes = serialize(k) - partition = _hash_key_to_partition(k_bytes, partitions) + partition = _hash_key_to_partition(k_bytes, p.operand.num_partitions) # todo: not atomic, fix me pre_v = txn_map[partition].get(k_bytes, None) if pre_v is None: @@ -1040,7 +983,7 @@ def _do_reduce(p: _UnaryProcess): source_env = s.enter_context(p.operand.as_env()) source_txn = s.enter_context(source_env.begin()) cursor = s.enter_context(source_txn.cursor()) - for k_bytes, v_bytes in cursor: + 
for _, v_bytes in cursor: v = deserialize(v_bytes) if value is None: value = v @@ -1146,7 +1089,12 @@ def _do_join(p: _BinaryProcess): continue v1 = deserialize(v1_bytes) v2 = deserialize(v2_bytes) - v3 = p.get_func()(v1, v2) + try: + v3 = p.get_func()(v1, v2) + except Exception as e: + raise RuntimeError( + f"Error when joining:\n" f"left:\n" f"{v1}\n" f"right:\n" f"{v2}\n" f"error: {e}" + ) from e dst_txn.put(k_bytes, serialize(v3)) return rtn @@ -1189,3 +1137,106 @@ def _kv_to_bytes(k, v): def _k_to_bytes(k): return serialize(k) + + +class _FederationMetaManager: + STATUS_TABLE_NAME_PREFIX = "__federation_status__" + OBJECT_TABLE_NAME_PREFIX = "__federation_object__" + + def __init__(self, session_id, party: Tuple[str, str]) -> None: + self.session_id = session_id + self.party = party + self._env = {} + + def wait_status_set(self, key): + value = self.get_status(key) + while value is None: + time.sleep(0.1) + value = self.get_status(key) + LOGGER.debug("[GET] Got {} type {}".format(key, "Table" if isinstance(value, tuple) else "Object")) + return value + + def get_status(self, key): + return self._get(self._get_status_table_name(self.party), key) + + def set_status(self, party: Tuple[str, str], key: str, value): + return self._set(self._get_status_table_name(party), key, value) + + def ack_status(self, key): + return self._ack(self._get_status_table_name(self.party), key) + + def get_object(self, key): + return self._get(self._get_object_table_name(self.party), key) + + def set_object(self, party: Tuple[str, str], key, value): + return self._set(self._get_object_table_name(party), key, value) + + def ack_object(self, key): + return self._ack(self._get_object_table_name(self.party), key) + + def _get_status_table_name(self, party: Tuple[str, str]): + return f"{self.STATUS_TABLE_NAME_PREFIX}.{party[0]}_{party[1]}" + + def _get_object_table_name(self, party: Tuple[str, str]): + return f"{self.OBJECT_TABLE_NAME_PREFIX}.{party[0]}_{party[1]}" + + def _get_env(self, name): + if name not in self._env: + self._env[name] = _get_env(self.session_id, name, str(0), write=True) + return self._env[name] + + def _get(self, name, key): + env = self._get_env(name) + with env.begin(write=False) as txn: + old_value_bytes = txn.get(serialize(key)) + if old_value_bytes is not None: + old_value_bytes = deserialize(old_value_bytes) + return old_value_bytes + + def _set(self, name, key, value): + env = self._get_env(name) + with env.begin(write=True) as txn: + return txn.put(serialize(key), serialize(value)) + + def _ack(self, name, key): + env = self._get_env(name) + with env.begin(write=True) as txn: + txn.delete(serialize(key)) + + +class _TableMetaManager: + namespace = "__META__" + name = "fragments" + num_partitions = 10 + _env = {} + + @classmethod + def _get_meta_env(cls, namespace: str, name: str): + k_bytes = _k_to_bytes(f"{namespace}.{name}") + p = _hash_key_to_partition(k_bytes, cls.num_partitions) + if p not in cls._env: + cls._env[p] = _get_env(cls.namespace, cls.name, str(p), write=True) + return k_bytes, cls._env[p] + + @classmethod + def add_table_meta(cls, namespace: str, name: str, num_partitions: int): + k_bytes, env = cls._get_meta_env(namespace, name) + with env.begin(write=True) as txn: + return txn.put(k_bytes, serialize(num_partitions)) + + @classmethod + def get_table_meta(cls, namespace: str, name: str): + k_bytes, env = cls._get_meta_env(namespace, name) + with env.begin(write=False) as txn: + old_value_bytes = txn.get(k_bytes) + if old_value_bytes is not None: + old_value_bytes = 
deserialize(old_value_bytes) + return old_value_bytes + + @classmethod + def destroy_table(cls, namespace: str, name: str): + k_bytes, env = cls._get_meta_env(namespace, name) + with env.begin(write=True) as txn: + txn.delete(k_bytes) + path = _data_dir.joinpath(namespace, name) + shutil.rmtree(path, ignore_errors=True) diff --git a/python/fate/arch/abc/__init__.py b/python/fate/arch/abc/__init__.py new file mode 100644 index 0000000000..7f42bf287d --- /dev/null +++ b/python/fate/arch/abc/__init__.py @@ -0,0 +1,3 @@ +from ._federation import FederationEngine, GarbageCollector +from ._party import PartyMeta +from ._table import CSessionABC, CTableABC diff --git a/python/fate/interface/_federation.py b/python/fate/arch/abc/_federation.py similarity index 83% rename from python/fate/interface/_federation.py rename to python/fate/arch/abc/_federation.py index 90fc40dd1b..21a1437c64 100644 --- a/python/fate/interface/_federation.py +++ b/python/fate/arch/abc/_federation.py @@ -14,8 +14,15 @@ # limitations under the License. from typing import List, Optional, Protocol -from ._gc import GarbageCollector -from ._party import Parties, Party, PartyMeta +from ._party import PartyMeta + + +class GarbageCollector(Protocol): + def register_clean_action(self, name: str, tag: str, obj, method: str, kwargs): + ... + + def clean(self, name: str, tag: str): + ... class FederationEngine(Protocol): @@ -39,10 +46,3 @@ def push( def destroy(self): ... - - -class FederationWrapper(Protocol): - guest: Party - hosts: Parties - arbiter: Party - parties: Parties diff --git a/python/fate/interface/_computing.py b/python/fate/arch/abc/_party.py similarity index 85% rename from python/fate/interface/_computing.py rename to python/fate/arch/abc/_party.py index fac13e33d3..02ab615046 100644 --- a/python/fate/interface/_computing.py +++ b/python/fate/arch/abc/_party.py @@ -12,9 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol +from typing import Literal, Tuple - -class ComputingEngine(Protocol): - def destroy(self): - ... +PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str] diff --git a/python/fate/arch/computing/_computing.py b/python/fate/arch/abc/_table.py similarity index 94% rename from python/fate/arch/computing/_computing.py rename to python/fate/arch/abc/_table.py index 9ccce683f9..dee3d993cd 100644 --- a/python/fate/arch/computing/_computing.py +++ b/python/fate/arch/abc/_table.py @@ -21,15 +21,18 @@ import abc import typing from abc import ABCMeta -from collections.abc import Iterable - -from ._address import Address +from typing import Iterable __all__ = ["CTableABC", "CSessionABC"] +K = typing.TypeVar("K") +V = typing.TypeVar("V") +K_OUT = typing.TypeVar("K_OUT") +V_OUT = typing.TypeVar("V_OUT") + # noinspection PyPep8Naming -class CTableABC(metaclass=ABCMeta): +class CTableABC(typing.Generic[K, V], metaclass=ABCMeta): """ a table of pair-like data supports distributed processing """ @@ -49,7 +52,7 @@ def engine(self): @property @abc.abstractmethod - def partitions(self): + def partitions(self) -> int: """ get the partitions of table @@ -65,7 +68,7 @@ def copy(self): ... 
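
The _FederationMetaManager introduced above replaces the old asyncio-based polling (and the per-party status/object Tables) with a blocking wait over two LMDB-backed key spaces: the sender writes the object and then its status key, the receiver spins on wait_status_set, reads the object, and acks both. The toy below keeps the same set/wait/ack shape purely for illustration; backing the keys with dicts instead of LMDB environments is an assumption made only to keep the sketch self-contained.

```python
import threading
import time

class ToyFederationMeta:
    """Same set/wait/ack shape as _FederationMetaManager, dict-backed for illustration."""

    def __init__(self) -> None:
        self._status = {}
        self._objects = {}

    # sender side (remote): store the object, then publish its status key
    def set_object(self, key, value):
        self._objects[key] = value

    def set_status(self, key, value):
        self._status[key] = value

    # receiver side (get): poll the status key, fetch the object, ack both
    def wait_status_set(self, key):
        while key not in self._status:
            time.sleep(0.1)
        return self._status[key]

    def get_object(self, key):
        return self._objects.get(key)

    def ack_object(self, key):
        self._objects.pop(key, None)

    def ack_status(self, key):
        self._status.pop(key, None)


if __name__ == "__main__":
    meta = ToyFederationMeta()
    tagged_key = "session-name-tag-guest-9999-host-10000"  # illustrative key layout

    def sender():
        meta.set_object(tagged_key, {"payload": 42})
        meta.set_status(tagged_key, tagged_key)

    threading.Thread(target=sender).start()
    meta.wait_status_set(tagged_key)
    print(meta.get_object(tagged_key))  # {'payload': 42}
    meta.ack_object(tagged_key)
    meta.ack_status(tagged_key)
```
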
@abc.abstractmethod - def save(self, address: Address, partitions: int, schema: dict, **kwargs): + def save(self, uri, schema: dict, options: dict = None): """ save table @@ -73,10 +76,10 @@ def save(self, address: Address, partitions: int, schema: dict, **kwargs): ---------- address: AddressABC address to save table to - partitions: int - number of partitions to save as schema: dict table schema + options: dict + options for saving """ ... @@ -97,7 +100,7 @@ def collect(self, **kwargs) -> typing.Generator: ... @abc.abstractmethod - def take(self, n=1, **kwargs): + def take(self, n=1, **kwargs) -> typing.List[V]: """ take ``n`` data from table @@ -118,7 +121,7 @@ def take(self, n=1, **kwargs): ... @abc.abstractmethod - def first(self, **kwargs): + def first(self, **kwargs) -> V: """ take one data from table @@ -147,7 +150,7 @@ def count(self) -> int: ... @abc.abstractmethod - def map(self, func) -> "CTableABC": + def map(self, func: typing.Callable[[K, V], typing.Tuple[K_OUT, V_OUT]]) -> "CTableABC[K_OUT, V_OUT]": """ apply `func` to each data @@ -172,7 +175,7 @@ def map(self, func) -> "CTableABC": ... @abc.abstractmethod - def mapValues(self, func): + def mapValues(self, func: typing.Callable[[V], V_OUT]) -> "CTableABC[K, V_OUT]": """ apply `func` to each value of data @@ -328,7 +331,7 @@ def flatMap(self, func): ... @abc.abstractmethod - def reduce(self, func): + def reduce(self, func: typing.Callable[[V, V], V]) -> V: """ reduces all value in pair of table by a binary function `func` @@ -445,7 +448,7 @@ def filter(self, func): ... @abc.abstractmethod - def join(self, other, func): + def join(self, other, func: typing.Callable[[typing.Any, typing.Any], typing.Any]) -> "CTableABC": """ returns intersection of this table and the other table. @@ -549,7 +552,7 @@ class CSessionABC(metaclass=ABCMeta): """ @abc.abstractmethod - def load(self, address: Address, partitions, schema: dict, **kwargs) -> CTableABC: + def load(self, uri, schema: dict, options: dict = None) -> CTableABC: """ load a table from given address @@ -557,15 +560,16 @@ def load(self, address: Address, partitions, schema: dict, **kwargs) -> CTableAB ---------- address: AddressABC address to load table from - partitions: int - number of partitions of loaded table schema: dict schema associate with this table + options: dict + options associate with this table load Returns ------- CTableABC a table in memory + """ ... diff --git a/python/fate/arch/computing/__init__.py b/python/fate/arch/computing/__init__.py index e6fd05b951..20fe23f87c 100644 --- a/python/fate/arch/computing/__init__.py +++ b/python/fate/arch/computing/__init__.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._computing import CTableABC + +from ._profile import enable_profile_remote, profile_ends, profile_start from ._type import ComputingEngine def is_table(v): + from fate.arch.abc import CTableABC + return isinstance(v, CTableABC) -__all__ = ["is_table", "ComputingEngine"] +__all__ = ["is_table", "ComputingEngine", "profile_start", "profile_ends"] diff --git a/python/fate/arch/computing/_address.py b/python/fate/arch/computing/_address.py deleted file mode 100644 index 7cad6dbee5..0000000000 --- a/python/fate/arch/computing/_address.py +++ /dev/null @@ -1,242 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import abc - - -class Address(metaclass=abc.ABCMeta): - ... - - -class StandaloneAddress(Address): - def __init__( - self, - home=None, - name=None, - namespace=None, - storage_type=None, - ): - self.home = home - self.name = name - self.namespace = namespace - self.storage_type = storage_type - - def __hash__(self): - return (self.home, self.name, self.namespace, self.storage_type).__hash__() - - def __str__(self): - return f"StandaloneAddress(name={self.name}, namespace={self.namespace})" - - def __repr__(self): - return self.__str__() - - @property - def connector(self): - return {"home": self.home} - - -class EggRollAddress(Address): - def __init__(self, home=None, name=None, namespace=None): - self.name = name - self.namespace = namespace - self.home = home - - def __hash__(self): - return (self.home, self.name, self.namespace).__hash__() - - def __str__(self): - return f"EggRollAddress(name={self.name}, namespace={self.namespace})" - - def __repr__(self): - return self.__str__() - - @property - def connector(self): - return {"home": self.home} - - -class HDFSAddress(Address): - def __init__(self, name_node=None, path=None): - self.name_node = name_node - self.path = path - - def __hash__(self): - return (self.name_node, self.path).__hash__() - - def __str__(self): - return f"HDFSAddress(name_node={self.name_node}, path={self.path})" - - def __repr__(self): - return self.__str__() - - @property - def connector(self): - return {"name_node": self.name_node} - - -class PathAddress(Address): - def __init__(self, path): - self.path = path - - def __hash__(self): - return self.path.__hash__() - - def __str__(self): - return f"PathAddress(path={self.path})" - - def __repr__(self): - return self.__str__() - - -class ApiAddress(Address): - def __init__(self, method="POST", url=None, header=None, body=None): - self.method = method - self.url = url - self.header = header if header else {} - self.body = body if body else {} - - def __hash__(self): - return (self.method, self.url).__hash__() - - def __str__(self): - return f"ApiAddress(url={self.url})" - - def __repr__(self): - return self.__str__() - - -class MysqlAddress(Address): - def __init__( - self, - user=None, - passwd=None, - host=None, - port=None, - db=None, - name=None, - ): - self.user = user - self.passwd = passwd - self.host = host - self.port = port - self.db = db - self.name = name - - def __hash__(self): - return (self.host, self.port, self.db, self.name).__hash__() - - def __str__(self): - return f"MysqlAddress(db={self.db}, name={self.name})" - - def __repr__(self): - return self.__str__() - - @property - def connector(self): - return { - "user": self.user, - "passwd": self.passwd, - "host": self.host, - "port": self.port, - "db": self.db, - } - - -class HiveAddress(Address): - def __init__( - self, - host=None, - name=None, - port=10000, - username=None, - database="default", - auth_mechanism="PLAIN", - password=None, - ): - self.host = host - self.username = username - self.port = port - self.database = database - self.auth_mechanism = auth_mechanism - self.password = password - self.name = name - - def 
__hash__(self): - return (self.host, self.port, self.database, self.name).__hash__() - - def __str__(self): - return f"HiveAddress(database={self.database}, name={self.name})" - - def __repr__(self): - return self.__str__() - - @property - def connector(self): - return { - "host": self.host, - "port": self.port, - "username": self.username, - "password": self.password, - "auth_mechanism": self.auth_mechanism, - "database": self.database, - } - - -class LinkisHiveAddress(Address): - def __init__( - self, - host="127.0.0.1", - port=9001, - username="", - database="", - name="", - run_type="hql", - execute_application_name="hive", - source={}, - params={}, - ): - self.host = host - self.port = port - self.username = username - self.database = database if database else f"{username}_ind" - self.name = name - self.run_type = run_type - self.execute_application_name = execute_application_name - self.source = source - self.params = params - - def __hash__(self): - return (self.host, self.port, self.database, self.name).__hash__() - - def __str__(self): - return f"LinkisHiveAddress(database={self.database}, name={self.name})" - - def __repr__(self): - return self.__str__() - - -class LocalFSAddress(Address): - def __init__(self, path): - self.path = path - - def __hash__(self): - return (self.path).__hash__() - - def __str__(self): - return f"LocalFSAddress(path={self.path})" - - def __repr__(self): - return self.__str__() diff --git a/python/fate/arch/computing/_profile.py b/python/fate/arch/computing/_profile.py index 80e28d795c..c426bb4be9 100644 --- a/python/fate/arch/computing/_profile.py +++ b/python/fate/arch/computing/_profile.py @@ -76,7 +76,7 @@ def __init__(self, function_name: str, function_stack_list): self._start = time.time() function_stack = "\n".join(function_stack_list) - self._hash = hashlib.blake2b(function_stack.encode("utf-8"), digest_size=5).hexdigest() + self._hash = hashlib.blake2b(f"{function_name}#{function_stack}".encode("utf-8"), digest_size=5).hexdigest() if self._hash not in self._STATS: self._STATS[self._hash] = _ComputingTimerItem(function_name, function_stack) @@ -145,7 +145,7 @@ def computing_statistics_table(cls, timer_aggregator: _TimerItem = None): if timer_aggregator: timer_aggregator.union(total) - return base_table.get_string(), detailed_base_table.get_string() + return str(base_table), str(detailed_base_table) class _FederationTimer(object): @@ -183,7 +183,7 @@ def federation_statistics_table(cls, timer_aggregator: _TimerItem = None): if timer_aggregator: timer_aggregator.union(total) - return base_table.get_string() + return str(base_table) class _FederationRemoteTimer(_FederationTimer): @@ -313,9 +313,9 @@ def profile_ends(): def _pretty_table_str(v): - from ..computing._computing import CTableABC + from ..computing import is_table - if isinstance(v, CTableABC): + if is_table(v): return f"Table(partition={v.partitions})" else: return f"{type(v).__name__}" diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/eggroll/_csession.py index b916430a2b..f391969e2a 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/eggroll/_csession.py @@ -15,10 +15,10 @@ # import logging -from typing import Optional -from ...unify import uuid -from .._computing import CSessionABC +from fate.arch.abc import CSessionABC + +from ...unify import URI, uuid from .._profile import computing_profile from ._table import Table @@ -51,35 +51,37 @@ def session_id(self): return self._session_id @computing_profile 
- def load(self, address, partitions: Optional[int], schema: dict, **kwargs): - - from .._address import EggRollAddress + def load(self, uri: URI, schema: dict, options: dict = None) -> Table: from ._type import EggRollStoreType - if isinstance(address, EggRollAddress): - options = kwargs.get("option", {}) - if partitions is not None: - options["total_partitions"] = partitions - options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB) - options["create_if_missing"] = False - rp = self._rpc.load(namespace=address.namespace, name=address.name, options=options) - if rp is None or rp.get_partitions() == 0: - raise RuntimeError(f"no exists: {address.name}, {address.namespace}") - - if options["store_type"] != EggRollStoreType.ROLLPAIR_IN_MEMORY: - rp = rp.save_as( - name=f"{address.name}_{uuid()}", - namespace=self.session_id, - partition=partitions, - options={"store_type": EggRollStoreType.ROLLPAIR_IN_MEMORY}, - ) + if uri.scheme != "eggroll": + raise ValueError(f"uri scheme {uri.scheme} not supported with eggroll backend") + try: + _, namespace, name = uri.path_splits() + except Exception as e: + raise ValueError(f"uri {uri} not valid, demo format: eggroll:///namespace/name") from e + + if options is None: + options = {} + if "store_type" not in options: + options["store_type"] = EggRollStoreType.ROLLPAIR_LMDB + options["create_if_missing"] = False + rp = self._rpc.load(namespace=namespace, name=name, options=options) + if rp is None or rp.get_partitions() == 0: + raise RuntimeError(f"no exists: {name}, {namespace}") + + if options["store_type"] != EggRollStoreType.ROLLPAIR_IN_MEMORY: + rp = rp.save_as( + name=f"{name}_{uuid()}", + namespace=self.session_id, + partition=rp.get_partitions(), + options={"store_type": EggRollStoreType.ROLLPAIR_IN_MEMORY}, + ) table = Table(rp=rp) table.schema = schema return table - raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend") - @computing_profile def parallelize(self, data, partition: int, include_key: bool, **kwargs) -> Table: options = dict() diff --git a/python/fate/arch/computing/eggroll/_table.py b/python/fate/arch/computing/eggroll/_table.py index 8b97c2139e..4745c88cf8 100644 --- a/python/fate/arch/computing/eggroll/_table.py +++ b/python/fate/arch/computing/eggroll/_table.py @@ -18,7 +18,9 @@ import logging import typing -from .._computing import CTableABC +from fate.arch.abc import CTableABC + +from ...unify import URI from .._profile import computing_profile from .._type import ComputingEngine @@ -44,23 +46,31 @@ def copy(self): return Table(self._rp.map_values(lambda x: x)) @computing_profile - def save(self, address, partitions, schema: dict, **kwargs): - options = kwargs.get("options", {}) - from .._address import EggRollAddress - from ._type import EggRollStoreType + def save(self, uri: URI, schema: dict, options: dict = None): + if options is None: + options = {} - if isinstance(address, EggRollAddress): - options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB) - self._rp.save_as( - name=address.name, - namespace=address.namespace, - partition=partitions, - options=options, - ) - schema.update(self.schema) - return + from ._type import EggRollStoreType - raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend") + if uri.scheme != "eggroll": + raise ValueError(f"uri scheme {uri.scheme} not supported with eggroll backend") + try: + _, namespace, name = uri.path_splits() + except Exception as e: + raise 
ValueError(f"uri {uri} not supported with eggroll backend") from e + + if "store_type" not in options: + options["store_type"] = EggRollStoreType.ROLLPAIR_LMDB + + partitions = options.get("partitions", self.partitions) + self._rp.save_as( + name=name, + namespace=namespace, + partition=partitions, + options=options, + ) + schema.update(self.schema) + return @computing_profile def collect(self, **kwargs) -> list: diff --git a/python/fate/arch/computing/spark/_csession.py b/python/fate/arch/computing/spark/_csession.py index 70db4c01af..581aeacd1b 100644 --- a/python/fate/arch/computing/spark/_csession.py +++ b/python/fate/arch/computing/spark/_csession.py @@ -14,11 +14,16 @@ # limitations under the License. # import logging +import typing from typing import Iterable -from .._computing import Address, CSessionABC +from fate.arch.abc import CSessionABC + +from ...unify import URI from ._table import from_hdfs, from_hive, from_localfs, from_rdd +if typing.TYPE_CHECKING: + from ._table import Table LOGGER = logging.getLogger(__name__) @@ -30,43 +35,50 @@ class CSession(CSessionABC): def __init__(self, session_id): self._session_id = session_id - def load(self, address: Address, partitions, schema, **kwargs): - from .._address import HDFSAddress + def load(self, uri: URI, schema, options: dict = None) -> "Table": + if not options: + options = {} + partitions = options.get("partitions", None) - if isinstance(address, HDFSAddress): + if uri.scheme == "hdfs": + in_serialized = (options.get("in_serialized", True),) + id_delimiter = (options.get("id_delimiter", ","),) table = from_hdfs( - paths=f"{address.name_node}/{address.path}", + paths=uri.original_uri, partitions=partitions, - in_serialized=kwargs.get("in_serialized", True), - id_delimiter=kwargs.get("id_delimiter", ","), + in_serialized=in_serialized, + id_delimiter=id_delimiter, ) table.schema = schema return table - from .._address import HiveAddress, LinkisHiveAddress - - if isinstance(address, (HiveAddress, LinkisHiveAddress)): + if uri.scheme == "hive": + try: + (path,) = uri.path_splits() + database_name, table_name = path.split(".") + except Exception as e: + raise ValueError(f"invalid hive uri {uri}, demo uri: hive://localhost:10000/database.table") from e table = from_hive( - tb_name=address.name, - db_name=address.database, + tb_name=table_name, + db_name=database_name, partitions=partitions, ) table.schema = schema return table - from .._address import LocalFSAddress - - if isinstance(address, LocalFSAddress): + if uri.scheme == "file": + in_serialized = (options.get("in_serialized", True),) + id_delimiter = (options.get("id_delimiter", ","),) table = from_localfs( - paths=address.path, + paths=uri.path, partitions=partitions, - in_serialized=kwargs.get("in_serialized", True), - id_delimiter=kwargs.get("id_delimiter", ","), + in_serialized=in_serialized, + id_delimiter=id_delimiter, ) table.schema = schema return table - raise NotImplementedError(f"address type {type(address)} not supported with spark backend") + raise NotImplementedError(f"uri type {uri} not supported with spark backend") def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs): # noinspection PyPackageRequirements diff --git a/python/fate/arch/computing/spark/_table.py b/python/fate/arch/computing/spark/_table.py index d8860d03c9..8c831c6811 100644 --- a/python/fate/arch/computing/spark/_table.py +++ b/python/fate/arch/computing/spark/_table.py @@ -21,17 +21,16 @@ from itertools import chain import pyspark +from fate.arch.abc import 
CTableABC from pyspark.rddsampler import RDDSamplerBase from scipy.stats import hypergeom -from .._computing import CTableABC from .._profile import computing_profile from .._type import ComputingEngine from ._materialize import materialize, unmaterialize LOGGER = logging.getLogger(__name__) - _HDFS_DELIMITER = "\t" @@ -80,33 +79,35 @@ def copy(self): return Table(_map_value(self._rdd, lambda x: x)) @computing_profile - def save(self, address, partitions, schema, **kwargs): - from .._address import HDFSAddress - - if isinstance(address, HDFSAddress): - self._rdd.map(lambda x: hdfs_serialize(x[0], x[1])).repartition(partitions).saveAsTextFile( - f"{address.name_node}/{address.path}" - ) + def save(self, uri, schema: dict, options: dict = None): + if options is None: + options = {} + partitions = options.get("partitions") + if uri.scheme == "hdfs": + table = self._rdd.map(lambda x: hdfs_serialize(x[0], x[1])) + if partitions: + table = table.repartition(partitions) + table.saveAsTextFile(uri.original_uri) schema.update(self.schema) return - from .._address import HiveAddress, LinkisHiveAddress - - if isinstance(address, (HiveAddress, LinkisHiveAddress)): - LOGGER.debug(f"partitions: {partitions}") - _repartition = self._rdd.map(lambda x: hive_to_row(x[0], x[1])).repartition(partitions) - _repartition.toDF().write.saveAsTable(f"{address.database}.{address.name}") + if uri.scheme == "hive": + table = self._rdd.map(lambda x: hive_to_row(x[0], x[1])) + if partitions: + table = table.repartition(partitions) + table.toDF().write.saveAsTable(uri.original_uri) schema.update(self.schema) return - from .._address import LocalFSAddress - - if isinstance(address, LocalFSAddress): - self._rdd.map(lambda x: hdfs_serialize(x[0], x[1])).repartition(partitions).saveAsTextFile(address.path) + if uri.scheme == "file": + table = self._rdd.map(lambda x: hdfs_serialize(x[0], x[1])) + if partitions: + table = table.repartition(partitions) + table.saveAsTextFile(uri.path) schema.update(self.schema) return - raise NotImplementedError(f"address type {type(address)} not supported with spark backend") + raise NotImplementedError(f"uri type {uri} not supported with spark backend") @property def partitions(self): @@ -216,7 +217,10 @@ def from_hdfs(paths: str, partitions, in_serialized=True, id_delimiter=None): sc = SparkContext.getOrCreate() fun = hdfs_deserialize if in_serialized else lambda x: (x.partition(id_delimiter)[0], x.partition(id_delimiter)[2]) - rdd = materialize(sc.textFile(paths, partitions).map(fun).repartition(partitions)) + rdd = sc.textFile(paths, partitions).map(fun) + if partitions is not None: + rdd = rdd.repartition(partitions) + rdd = materialize(rdd) return Table(rdd=rdd) diff --git a/python/fate/arch/computing/standalone/_csession.py b/python/fate/arch/computing/standalone/_csession.py index a6927abf6c..51418013a6 100644 --- a/python/fate/arch/computing/standalone/_csession.py +++ b/python/fate/arch/computing/standalone/_csession.py @@ -17,22 +17,25 @@ from collections.abc import Iterable from typing import Optional +from fate.arch.abc import CSessionABC + from ..._standalone import Session -from ...unify import generate_computing_uuid, uuid -from .._computing import Address, CSessionABC +from ...unify import URI, generate_computing_uuid, uuid from ._table import Table LOGGER = logging.getLogger(__name__) class CSession(CSessionABC): - def __init__(self, session_id: Optional[str] = None, options: Optional[dict] = None): + def __init__( + self, session_id: Optional[str] = None, logger_config: 
Optional[dict] = None, options: Optional[dict] = None + ): if session_id is None: session_id = generate_computing_uuid() if options is None: options = {} max_workers = options.get("task_cores", None) - self._session = Session(session_id, max_workers=max_workers) + self._session = Session(session_id, max_workers=max_workers, logger_config=logger_config) def get_standalone_session(self): return self._session @@ -41,25 +44,25 @@ def get_standalone_session(self): def session_id(self): return self._session.session_id - def load(self, address: Address, partitions: int, schema: dict, **kwargs): - from .._address import StandaloneAddress - from ._type import StandaloneStoreType - - if isinstance(address, StandaloneAddress): - raw_table = self._session.load(address.name, address.namespace) - if address.storage_type != StandaloneStoreType.ROLLPAIR_IN_MEMORY: - partitions = raw_table.partitions if partitions is None else partitions - raw_table = raw_table.save_as( - name=f"{address.name}_{uuid()}", - namespace=address.namespace, - partition=partitions, - need_cleanup=True, - ) - table = Table(raw_table) - table.schema = schema - return table - - raise NotImplementedError(f"address type {type(address)} not supported with standalone backend") + def load(self, uri: URI, schema: dict, options: dict = None): + if uri.scheme != "standalone": + raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend") + try: + *database, namespace, name = uri.path_splits() + except Exception as e: + raise ValueError(f"uri `{uri}` not valid, demo format: standalone://database_path/namespace/name") from e + + raw_table = self._session.load(name=name, namespace=namespace) + partitions = raw_table.partitions + raw_table = raw_table.save_as( + name=f"{name}_{uuid()}", + namespace=namespace, + partition=partitions, + need_cleanup=True, + ) + table = Table(raw_table) + table.schema = schema + return table def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs): table = self._session.parallelize(data=data, partition=partition, include_key=include_key, **kwargs) diff --git a/python/fate/arch/computing/standalone/_table.py b/python/fate/arch/computing/standalone/_table.py index 7302b5b18f..12d2a28425 100644 --- a/python/fate/arch/computing/standalone/_table.py +++ b/python/fate/arch/computing/standalone/_table.py @@ -17,7 +17,9 @@ import logging import typing -from .._computing import CTableABC +from fate.arch.abc import CTableABC + +from ...unify import URI from .._profile import computing_profile from .._type import ComputingEngine @@ -46,20 +48,23 @@ def copy(self): return Table(self._table.mapValues(lambda x: x)) @computing_profile - def save(self, address, partitions, schema, **kwargs): - from .._address import StandaloneAddress - - if isinstance(address, StandaloneAddress): - self._table.save_as( - name=address.name, - namespace=address.namespace, - partition=partitions, - need_cleanup=False, - ) - schema.update(self.schema) - return - - raise NotImplementedError(f"address type {type(address)} not supported with standalone backend") + def save(self, uri: URI, schema, options: dict = None): + if options is None: + options = {} + + if uri.scheme != "standalone": + raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend") + try: + *database, namespace, name = uri.path_splits() + except Exception as e: + raise ValueError(f"uri `{uri}` not supported with standalone backend") from e + self._table.save_as( + name=name, + namespace=namespace, + 
partition=options.get("partitions", self.partitions), + need_cleanup=False, + ) + schema.update(self.schema) @computing_profile def count(self) -> int: diff --git a/python/fate/arch/context/__init__.py b/python/fate/arch/context/__init__.py index 10f371d63e..9b029362f8 100644 --- a/python/fate/arch/context/__init__.py +++ b/python/fate/arch/context/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._context import Context, Namespace +from ._cipher import CipherKit +from ._context import Context -__all__ = ["Context", "Namespace"] +__all__ = ["Context", "CipherKit"] diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index 19db601827..01c9267891 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -12,16 +12,170 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fate.interface import CipherKit as CipherKitInterface -from ..tensor._phe import PHECipher +import logging +import typing + from ..unify import device +logger = logging.getLogger(__name__) + + +class CipherKit: + def __init__(self, device: device, cipher_mapping: typing.Optional[dict] = None) -> None: + self._device = device + if cipher_mapping is None: + self._cipher_mapping = {} + else: + self._cipher_mapping = cipher_mapping + self._allow_custom_random_seed = False + self._custom_random_seed = 42 -class CipherKit(CipherKitInterface): - def __init__(self, device: device) -> None: - self.device = device + def set_phe(self, device: device, options: typing.Optional[dict]): + if "phe" not in self._cipher_mapping: + self._cipher_mapping["phe"] = {} + self._cipher_mapping["phe"][device] = options + + def _set_default_phe(self): + if "phe" not in self._cipher_mapping: + self._cipher_mapping["phe"] = {} + if self._device not in self._cipher_mapping["phe"]: + if self._device == device.CPU: + self._cipher_mapping["phe"][device.CPU] = {"kind": "paillier", "key_length": 1024} + else: + logger.warning(f"no impl exists for device {self._device}, fallback to CPU") + self._cipher_mapping["phe"][device.CPU] = self._cipher_mapping["phe"].get( + device.CPU, {"kind": "paillier", "key_length": 1024} + ) @property def phe(self): - return PHECipher(self.device) + self._set_default_phe() + if self._device not in self._cipher_mapping["phe"]: + raise ValueError(f"no impl exists for device {self._device}") + return PHECipherBuilder(**self._cipher_mapping["phe"][self._device]) + + @property + def allow_custom_random_seed(self): + return self._allow_custom_random_seed + + def set_allow_custom_random_seed(self, allow_custom_random_seed): + self._allow_custom_random_seed = allow_custom_random_seed + + def set_custom_random_seed(self, custom_random_seed): + self._custom_random_seed = custom_random_seed + + def get_custom_random_seed(self): + return self._custom_random_seed + + +class PHECipherBuilder: + def __init__(self, kind, key_length) -> None: + self.kind = kind + self.key_length = key_length + + def setup(self, options: typing.Optional[dict] = None): + if options is None: + kind = self.kind + key_size = self.key_length + else: + kind = options.get("kind", self.kind) + key_size = options.get("key_length", 1024) + + if kind == "paillier": + from fate.arch.protocol.phe.paillier import 
evaluator, keygen + from fate.arch.tensor.phe import PHETensorCipher + + sk, pk, coder = keygen(key_size) + tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) + return PHECipher(kind, key_size, pk, sk, evaluator, coder, tensor_cipher, True, True, True) + + if kind == "ou": + from fate.arch.protocol.phe.ou import evaluator, keygen + from fate.arch.tensor.phe import PHETensorCipher + + sk, pk, coder = keygen(key_size) + tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) + return PHECipher(kind, key_size, pk, sk, evaluator, coder, tensor_cipher, False, False, True) + + elif kind == "mock": + from fate.arch.protocol.phe.mock import evaluator, keygen + from fate.arch.tensor.phe import PHETensorCipher + + sk, pk, coder = keygen(key_size) + tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) + return PHECipher(kind, key_size, pk, sk, evaluator, coder, tensor_cipher, True, False, False) + + else: + raise ValueError(f"Unknown PHE keygen kind: {self.kind}") + + +class PHECipher: + def __init__( + self, + kind, + key_size, + pk, + sk, + evaluator, + coder, + tensor_cipher, + can_support_negative_number, + can_support_squeeze, + can_support_pack, + ) -> None: + self._kind = kind + self._key_size = key_size + self._pk = pk + self._sk = sk + self._coder = coder + self._evaluator = evaluator + self._tensor_cipher = tensor_cipher + self._can_support_negative_number = can_support_negative_number + self._can_support_squeeze = can_support_squeeze + self._can_support_pack = can_support_pack + + @property + def kind(self): + return self._kind + + @property + def can_support_negative_number(self): + return self._can_support_negative_number + + @property + def can_support_squeeze(self): + return self._can_support_squeeze + + @property + def can_support_pack(self): + return self._can_support_pack + + @property + def key_size(self): + return self._key_size + + def get_tensor_encryptor(self): + return self._tensor_cipher.pk + + def get_tensor_coder(self): + return self._tensor_cipher.coder + + def get_tensor_decryptor(self): + return self._tensor_cipher.sk + + @property + def pk(self): + return self._pk + + @property + def coder(self): + return self._coder + + @property + def sk(self): + return self._sk + + @property + def evaluator(self): + return self._evaluator diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 871799c1e3..4309aaf655 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -13,28 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from contextlib import contextmanager -from copy import copy -from typing import Iterator, List, Optional +from typing import Iterable, Literal, Optional, Tuple, TypeVar, overload -from fate.interface import T_ROLE, ComputingEngine -from fate.interface import Context as ContextInterface -from fate.interface import FederationEngine, MetricsHandler, PartyMeta +from fate.arch.abc import CSessionABC, FederationEngine from ..unify import device from ._cipher import CipherKit -from ._federation import GC, Parties, Party -from ._namespace import Namespace -from .io.kit import IOKit -from .metric import MetricsWrap +from ._federation import Parties, Party +from ._metrics import InMemoryMetricsHandler, MetricsWrap +from ._namespace import NS, default_ns logger = logging.getLogger(__name__) +T = TypeVar("T") -class Context(ContextInterface): - """ - implement fate.interface.ContextInterface +class Context: + """ Note: most parameters has default dummy value, which is convenient when used in script. please pass in custom implements as you wish @@ -42,112 +37,181 @@ class Context(ContextInterface): def __init__( self, - context_name: Optional[str] = None, device: device = device.CPU, - computing: Optional[ComputingEngine] = None, - federation: Optional[FederationEngine] = None, - metrics_handler: Optional[MetricsHandler] = None, - namespace: Optional[Namespace] = None, + computing: Optional["CSessionABC"] = None, + federation: Optional["FederationEngine"] = None, + metrics_handler: Optional = None, + namespace: Optional[NS] = None, + cipher: Optional[CipherKit] = None, ) -> None: - self.context_name = context_name - self.metrics = MetricsWrap(metrics_handler) - - if namespace is None: - namespace = Namespace() - self.namespace = namespace - self.super_namespace = Namespace() - - self.cipher: CipherKit = CipherKit(device) - self._io_kit: IOKit = IOKit() - + self._device = device self._computing = computing self._federation = federation - self._role_to_parties = None + self._metrics_handler = metrics_handler + self._namespace = namespace + self._cipher = cipher - self._gc = GC() + if self._namespace is None: + self._namespace = default_ns + if self._cipher is None: + self._cipher: CipherKit = CipherKit(device) + + self._role_to_parties = None self._is_destroyed = False - def with_namespace(self, namespace: Namespace): - context = copy(self) - context.namespace = namespace - return context + @property + def device(self): + return self._device - def into_group_namespace(self, group_name: str, group_id: str): - context = copy(self) - context.metrics = context.metrics.into_group(group_name, group_id) - context.namespace = self.namespace.sub_namespace(f"{group_name}_{group_id}") - return context + @property + def namespace(self): + return self._namespace - def range(self, end): - for i in range(end): - yield i, self.with_namespace(self.namespace.sub_namespace(f"{i}")) + @property + def cipher(self): + return self._cipher - def iter(self, iterable): - for i, it in enumerate(iterable): - yield self.with_namespace(self.namespace.sub_namespace(f"{i}")), it + def set_cipher(self, cipher_mapping): + self._cipher = CipherKit(self._device, {"phe": {self._device: cipher_mapping["phe"]}}) + + def set_metric_handler(self, metrics_handler): + self._metrics_handler = metrics_handler + + @property + def metrics(self): + if self._metrics_handler is None: + self._metrics_handler = InMemoryMetricsHandler() + return MetricsWrap(self._metrics_handler, self._namespace) + + def with_namespace(self, namespace: 
NS): + return Context( + device=self._device, + computing=self._computing, + federation=self._federation, + metrics_handler=self._metrics_handler, + namespace=namespace, + cipher=self._cipher, + ) @property def computing(self): return self._get_computing() @property - def federation(self) -> FederationEngine: + def federation(self) -> "FederationEngine": return self._get_federation() - @contextmanager - def sub_ctx(self, namespace: str) -> Iterator["Context"]: - try: - yield self.with_namespace(self.namespace.sub_namespace(namespace)) - finally: - ... + def sub_ctx(self, name: str) -> "Context": + return self.with_namespace(self._namespace.sub_ns(name=name)) + + def indexed_ctx(self, index: int) -> "Context": + return self.with_namespace(self._namespace.indexed_ns(index)) + + @property + def on_iterations(self) -> "Context": + return self.sub_ctx("iterations") + + @property + def on_batches(self) -> "Context": + return self.sub_ctx("batches") - def set_federation(self, federation: FederationEngine): + @property + def on_cross_validations(self) -> "Context": + return self.sub_ctx("cross_validations") + + @overload + def ctxs_range(self, end: int) -> Iterable[Tuple[int, "Context"]]: + ... + + @overload + def ctxs_range(self, start: int, end: int) -> Iterable[Tuple[int, "Context"]]: + ... + + def ctxs_range(self, *args, **kwargs) -> Iterable[Tuple[int, "Context"]]: + + """ + create contexes with namespaces indexed from 0 to end(excluded) + """ + + if "start" in kwargs: + start = kwargs["start"] + if "end" not in kwargs: + raise ValueError("End value must be provided") + end = kwargs["end"] + if len(args) > 0: + raise ValueError("Too many arguments") + else: + if "end" in kwargs: + end = kwargs["end"] + if len(args) > 1: + raise ValueError("Too many arguments") + elif len(args) == 0: + raise ValueError("Start value must be provided") + else: + start = args[0] + else: + if len(args) == 1: + start, end = 0, args[0] + elif len(args) == 2: + start, end = args + else: + raise ValueError("Too few arguments") + + for i in range(start, end): + yield i, self.with_namespace(self._namespace.indexed_ns(index=i)) + + def ctxs_zip(self, iterable: Iterable[T]) -> Iterable[Tuple["Context", T]]: + """ + zip contexts with iterable with namespaces indexed from 0 + """ + for i, it in enumerate(iterable): + yield self.with_namespace(self._namespace.indexed_ns(index=i)), it + + def set_federation(self, federation: "FederationEngine"): self._federation = federation @property def guest(self) -> Party: - return Party( - self._get_federation(), - self._get_parties("guest")[0], - self.namespace, - ) + return self._get_parties("guest")[0] @property def hosts(self) -> Parties: - return Parties( - self._get_federation(), - self._get_federation().local_party, - self._get_parties("host"), - self.namespace, - ) + return self._get_parties("host") @property def arbiter(self) -> Party: - return Party( - self._get_federation(), - self._get_parties("arbiter")[0], - self.namespace, - ) + return self._get_parties("arbiter")[0] @property def local(self): - return self._get_federation().local_party + role, party_id = self._get_federation().local_party + for party in self._get_parties(role): + if party.party[1] == party_id: + return party + raise RuntimeError("local party not found") + + @property + def is_on_guest(self): + return self._federation.local_party[0] == "guest" + + @property + def is_on_host(self): + return self._federation.local_party[0] == "host" + + @property + def is_on_arbiter(self): + return 
self._federation.local_party[0] == "arbiter" @property def parties(self) -> Parties: - return Parties( - self._get_federation(), - self._get_federation().local_party, - self._get_parties(), - self.namespace, - ) + return self._get_parties() - def _get_parties(self, role: Optional[T_ROLE] = None) -> List[PartyMeta]: + def _get_parties(self, role: Optional[Literal["guest", "host", "arbiter"]] = None) -> Parties: # update role to parties mapping if self._role_to_parties is None: self._role_to_parties = {} - for party in self._get_federation().parties: - self._role_to_parties.setdefault(party[0], []).append(party) + for i, party in enumerate(self._get_federation().parties): + self._role_to_parties.setdefault(party[0], []).append((i, party)) parties = [] if role is None: @@ -155,10 +219,16 @@ def _get_parties(self, role: Optional[T_ROLE] = None) -> List[PartyMeta]: parties.extend(role_parties) else: if role not in self._role_to_parties: - raise RuntimeError(f"no {role} party has configurated") + raise RuntimeError(f"no {role} party has configured") else: parties.extend(self._role_to_parties[role]) - return parties + parties.sort(key=lambda x: x[0]) + return Parties( + self, + self._get_federation(), + parties, + self._namespace, + ) def _get_federation(self): if self._federation is None: @@ -170,12 +240,6 @@ def _get_computing(self): raise RuntimeError(f"computing not set") return self._computing - def reader(self, uri, **kwargs): - return self._io_kit.reader(self, uri, **kwargs) - - def writer(self, uri, **kwargs): - return self._io_kit.writer(self, uri, **kwargs) - def destroy(self): if not self._is_destroyed: try: diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index 95f5a71bfc..6708468c9e 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -14,19 +14,21 @@ # limitations under the License. 
import io import pickle -from typing import Any, List, Optional, TypeVar, Union +import struct +import typing +from typing import Any, List, Tuple, TypeVar, Union -from fate.interface import FederationEngine -from fate.interface import Parties as PartiesInterface -from fate.interface import Party as PartyInterface -from fate.interface import PartyMeta +from fate.arch.abc import FederationEngine, PartyMeta from ..computing import is_table from ..federation._gc import IterationGC -from ._namespace import Namespace +from ._namespace import NS T = TypeVar("T") +if typing.TYPE_CHECKING: + from fate.arch.context import Context + class GC: def __init__(self) -> None: @@ -46,30 +48,47 @@ def get_or_set_pull_gc(self, key): class _KeyedParty: def __init__(self, party: Union["Party", "Parties"], key) -> None: - self.party = party - self.key = key + self._party = party + self._key = key def put(self, value): - return self.party.put(self.key, value) + return self._party.put(self._key, value) def get(self): - return self.party.get(self.key) + return self._party.get(self._key) -class Party(PartyInterface): - def __init__(self, federation, party: PartyMeta, namespace, key=None) -> None: +class Party: + def __init__(self, ctx: "Context", federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None: + self._ctx = ctx + self._party = party self.federation = federation - self.party = party + self.rank = rank self.namespace = namespace - self.key = key def __call__(self, key: str) -> "_KeyedParty": return _KeyedParty(self, key) + @property + def party(self) -> PartyMeta: + return self._party + + @property + def role(self) -> str: + return self.party[0] + + @property + def party_id(self) -> str: + return self.party[1] + + @property + def name(self) -> str: + return f"{self.party[0]}_{self.party[1]}" + def put(self, *args, **kwargs): if args: assert len(args) == 2 and isinstance(args[0], str), "invalid position parameter" - assert not kwargs, "keywords paramters not allowed when position parameter provided" + assert not kwargs, "keywords parameters not allowed when position parameter provided" kvs = [args] else: kvs = kwargs.items() @@ -78,83 +97,117 @@ def put(self, *args, **kwargs): return _push(self.federation, k, self.namespace, [self.party], v) def get(self, name: str): - return _pull(self.federation, name, self.namespace, [self.party])[0] + return _pull(self._ctx, self.federation, name, self.namespace, [self.party])[0] + def get_int(self, name: str): + ... 
-class Parties(PartiesInterface): + +class Parties: def __init__( self, + ctx: "Context", federation: FederationEngine, - party: PartyMeta, - parties: List[PartyMeta], - namespace: Namespace, + parties: List[Tuple[int, PartyMeta]], + namespace: NS, ) -> None: + self._ctx = ctx self.federation = federation - self.party = party self.parties = parties self.namespace = namespace - def __getitem__(self, key: int) -> Party: - return Party(self.federation, self.parties[key], self.namespace) - - def __call__(self, key: str) -> "_KeyedParty": - return _KeyedParty(self, key) + @property + def ranks(self): + return [p[0] for p in self.parties] - def get_neighbor(self, shift: int, module: bool = False) -> Party: - start_index = self.get_local_index() - if start_index is None: - raise RuntimeError(f"local party `{self.party}` not in `{self.parties}`") - target_index = start_index + shift - if module: - target_index = target_index % module + def __getitem__(self, key: int) -> Party: + rank, party = self.parties[key] + return Party(self._ctx, self.federation, party, rank, self.namespace) - if 0 <= target_index < len(self.parties): - return self(target_index) - else: - raise IndexError(f"target index `{target_index}` out of bound") + def __iter__(self): + return iter([Party(self._ctx, self.federation, party, rank, self.namespace) for rank, party in self.parties]) - def get_neighbors(self) -> "Parties": - parties = [party for party in self.parties if party != self.party] - return Parties(self.federation, self.party, parties, self.namespace) + def __len__(self) -> int: + return len(self.parties) - def get_local_index(self) -> Optional[int]: - if self.party not in self.parties: - return None - else: - return self.parties.index(self.party) + def __call__(self, key: str) -> "_KeyedParty": + return _KeyedParty(self, key) def put(self, *args, **kwargs): if args: assert len(args) == 2 and isinstance(args[0], str), "invalid position parameter" - assert not kwargs, "keywords paramters not allowed when position parameter provided" + assert not kwargs, "keywords parameters not allowed when position parameter provided" kvs = [args] else: kvs = kwargs.items() for k, v in kvs: - return _push(self.federation, k, self.namespace, self.parties, v) + return _push(self.federation, k, self.namespace, [p[1] for p in self.parties], v) def get(self, name: str): - return _pull(self.federation, name, self.namespace, self.parties) + return _pull(self._ctx, self.federation, name, self.namespace, [p[1] for p in self.parties]) def _push( federation: FederationEngine, name: str, - namespace: Namespace, + namespace: NS, parties: List[PartyMeta], value, ): - tag = namespace.fedeation_tag() + tag = namespace.federation_tag _TableRemotePersistentPickler.push(value, federation, name, tag, parties) +class Serde: + @classmethod + def encode_int(cls, value: int) -> bytes: + return struct.pack("!q", value) # '!q' is for long long (8 bytes) + + @classmethod + def decode_int(cls, value: bytes) -> int: + return struct.unpack("!q", value)[0] + + @classmethod + def encode_str(cls, value: str) -> bytes: + utf8_str = value.encode("utf-8") + return struct.pack("!I", len(utf8_str)) + utf8_str # prepend length of string + + @classmethod + def decode_str(cls, value: bytes) -> str: + length = struct.unpack("!I", value[:4])[0] # get length of string + return value[4 : 4 + length].decode("utf-8") # decode string + + @classmethod + def encode_bytes(cls, value: bytes) -> bytes: + return struct.pack("!I", len(value)) + value # prepend length of bytes + + 
@classmethod + def decode_bytes(cls, value: bytes) -> bytes: + length = struct.unpack("!I", value[:4])[0] # get length of bytes + return value[4 : 4 + length] # extract bytes + + @classmethod + def encode_float(cls, value: float) -> bytes: + return struct.pack("!d", value) + + @classmethod + def decode_float(cls, value: bytes) -> float: + return struct.unpack("!d", value)[0] + + +def _push_int(federation: FederationEngine, name: str, namespace: NS, parties: List[PartyMeta], value: int): + tag = namespace.federation_tag + federation.push(v=Serde.encode_int(value), name=name, tag=tag, parties=parties) + + def _pull( + ctx: "Context", federation: FederationEngine, name: str, - namespace: Namespace, + namespace: NS, parties: List[PartyMeta], ): - tag = namespace.fedeation_tag() + tag = namespace.federation_tag raw_values = federation.pull( name=name, tag=tag, @@ -162,11 +215,16 @@ def _pull( ) values = [] for party, buffers in zip(parties, raw_values): - values.append(_TableRmotePersistentUnpickler.pull(buffers, federation, name, tag, party)) + values.append(_TableRemotePersistentUnpickler.pull(buffers, ctx, federation, name, tag, party)) return values -class _TablePersistantId: +class _TablePersistentId: + def __init__(self, key) -> None: + self.key = key + + +class _ContextPersistentId: def __init__(self, key) -> None: self.key = key @@ -194,11 +252,16 @@ def _get_next_table_key(self): return f"{self._name}__table_persistent_{self._table_index}__" def persistent_id(self, obj: Any) -> Any: + from fate.arch.context import Context + if is_table(obj): key = self._get_next_table_key() self._federation.push(v=obj, name=key, tag=self._tag, parties=self._parties) self._table_index += 1 - return _TablePersistantId(key) + return _TablePersistentId(key) + if isinstance(obj, Context): + key = f"{self._name}__context__" + return _ContextPersistentId(key) @classmethod def push( @@ -215,15 +278,17 @@ def push( federation.push(v=f.getvalue(), name=name, tag=tag, parties=parties) -class _TableRmotePersistentUnpickler(pickle.Unpickler): +class _TableRemotePersistentUnpickler(pickle.Unpickler): def __init__( self, + ctx: "Context", federation: FederationEngine, name: str, tag: str, party: PartyMeta, f, ): + self._ctx = ctx self._federation = federation self._name = name self._tag = tag @@ -231,19 +296,22 @@ def __init__( super().__init__(f) def persistent_load(self, pid: Any) -> Any: - if isinstance(pid, _TablePersistantId): + if isinstance(pid, _TablePersistentId): table = self._federation.pull(pid.key, self._tag, [self._party])[0] return table + if isinstance(pid, _ContextPersistentId): + return self._ctx @classmethod def pull( cls, buffers, + ctx: "Context", federation: FederationEngine, name: str, tag: str, party: PartyMeta, ): with io.BytesIO(buffers) as f: - unpickler = _TableRmotePersistentUnpickler(federation, name, tag, party, f) + unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f) return unpickler.load() diff --git a/python/fate/arch/context/_metrics.py b/python/fate/arch/context/_metrics.py index ae946a49c4..6428db03d7 100644 --- a/python/fate/arch/context/_metrics.py +++ b/python/fate/arch/context/_metrics.py @@ -12,3 +12,191 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import time +import typing +from typing import Dict, List, Optional, Tuple, Union + +import pydantic + +from ._namespace import NS, IndexedNS + + +class BaseMetricsHandler: + def log_metrics(self, metrics: Union["StepMetrics", "OneTimeMetrics"]): + if isinstance(metrics, StepMetrics): + self._log_step_metrics(metrics) + elif isinstance(metrics, OneTimeMetrics): + self._log_one_time_metrics(metrics) + else: + raise ValueError(f"metrics `{metrics}` not allowed") + + def _log_step_metrics(self, metrics: "StepMetrics"): + raise NotImplementedError + + def _log_one_time_metrics(self, metrics: "OneTimeMetrics"): + raise NotImplementedError + + +class InMemoryMetricsHandler(BaseMetricsHandler): + def __init__(self): + self._step_metrics: typing.Dict[typing.Any, "StepMetrics"] = {} + self._one_time_metrics: typing.Dict[typing.Any, "OneTimeMetrics"] = {} + + def _log_step_metrics(self, metrics: "StepMetrics"): + if (metrics.name, tuple(metrics.groups)) in self._step_metrics: + self._step_metrics[(metrics.name, tuple(metrics.groups))] = self._step_metrics[ + (metrics.name, tuple(metrics.groups)) + ].merge(metrics) + else: + self._step_metrics[(metrics.name, tuple(metrics.groups))] = metrics + + def _log_one_time_metrics(self, metrics: "OneTimeMetrics"): + if (metrics.name, tuple(metrics.groups)) in self._one_time_metrics: + raise ValueError(f"duplicated metrics: `{metrics.name}` already exists") + else: + self._one_time_metrics[(metrics.name, tuple(metrics.groups))] = metrics + + def get_metrics(self): + metrics = [] + for k, v in self._step_metrics.items(): + metrics.append(v.dict()) + for k, v in self._one_time_metrics.items(): + metrics.append(v.dict()) + return metrics + + +class MetricsWrap: + def __init__(self, handler, namespace: NS) -> None: + self.handler = handler + self.namespace = namespace + + def log_metrics(self, data, name: str, type: Optional[str] = None): + return self.handler.log_metrics( + OneTimeMetrics( + name=name, + type=type, + groups=[*self.namespace.metric_groups, self.namespace.get_group()], + data=data, + ) + ) + + def log_step(self, value, name: str, type: Optional[str] = None): + if isinstance(self.namespace, IndexedNS): + step = self.namespace.index + else: + raise RuntimeError( + "log step metric only allowed in indexed namespace since the step is inferred from namespace" + ) + timestamp = time.time() + return self.handler.log_metrics( + StepMetrics( + name=name, + type=type, + groups=self.namespace.metric_groups, + step_axis=self.namespace.name, + data=[dict(metric=value, step=step, timestamp=timestamp)], + ) + ) + + def log_scalar(self, name: str, scalar: float): + return self.log_step(value=scalar, name=name, type="scalar") + + def log_loss(self, name: str, loss: float): + return self.log_step(value=loss, name=name, type="loss") + + def log_accuracy(self, name: str, accuracy: float): + return self.log_step(value=accuracy, name=name, type="accuracy") + + def log_auc(self, name: str, auc: float): + return self.log_step(value=auc, name=name, type="auc") + + def log_roc(self, name: str, data: List[Tuple[float, float]]): + return self.log_metrics(data=data, name=name, type="roc") + + +class OneTimeMetrics: + def __init__( + self, name: str, type: Optional[str], groups: List[Tuple[str, Optional[int]]], data: Union[List, Dict] + ) -> None: + self.name = name + self.groups = groups + self.type = type + self.data = data + + def dict(self): + return self.to_record().dict() + + def to_record(self): + return MetricRecord( + name=self.name, + 
groups=[MetricRecord.Group(name=k, index=v) for k, v in self.groups], + type=self.type, + step_axis=None, + data=self.data, + ) + + def __str__(self): + return str(self.dict()) + + def __repr__(self): + return self.__str__() + + +class StepMetrics: + def __init__( + self, name: str, type: Optional[str], groups: List[Tuple[str, Optional[int]]], step_axis: str, data: List + ) -> None: + self.name = name + self.type = type + self.groups = groups + self.step_axis = step_axis + self.data = data + + def merge(self, metrics: "StepMetrics"): + if ( + isinstance(metrics, StepMetrics) + and metrics.type == self.type + and metrics.name == self.name + and metrics.step_axis == self.step_axis + and metrics.groups == self.groups + ): + return StepMetrics( + name=self.name, + type=self.type, + groups=self.groups, + step_axis=self.step_axis, + data=[*self.data, *metrics.data], + ) + + raise RuntimeError(f"metrics merge not allowed: `{metrics}` with `{self}`") + + def dict(self) -> dict: + return self.to_record().dict() + + def to_record(self) -> "MetricRecord": + return MetricRecord( + name=self.name, + type=self.type, + groups=[MetricRecord.Group(name=g[0], index=g[1]) for g in self.groups], + step_axis=self.step_axis, + data=self.data, + ) + + def __str__(self): + return f"StepMetrics(name={self.name}, type={self.type}, groups={self.groups}, step_axis={self.step_axis}, data={self.data})" + + def __repr__(self): + return self.__str__() + + +class MetricRecord(pydantic.BaseModel): + class Group(pydantic.BaseModel): + name: str + index: Optional[int] + + name: str + type: Optional[str] + groups: List[Group] + step_axis: Optional[str] + data: Union[List, Dict] diff --git a/python/fate/arch/context/_mlmd.py b/python/fate/arch/context/_mlmd.py deleted file mode 100644 index aa05eed8a4..0000000000 --- a/python/fate/arch/context/_mlmd.py +++ /dev/null @@ -1,239 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import json - -from ml_metadata import metadata_store -from ml_metadata.proto import metadata_store_pb2 - - -class MachineLearningMetadata: - def __init__(self, backend="sqlite", metadata={}) -> None: - self.store = self.create_store(backend, metadata) - self._task_context_type_id = None # context type - self._task_type_id = None # execution type - self._data_type_id = None # data artifact - self._model_type_id = None # model artifact - self._metric_type_id = None # metric artifact - self._parameter_type_id = None # parameter artifact - - @classmethod - def create_store(cls, backend, metadata): - connection_config = metadata_store_pb2.ConnectionConfig() - if backend == "sqlite": - connection_config.sqlite.filename_uri = metadata["filename_uri"] - connection_config.sqlite.connection_mode = metadata.get("connection_mode", 3) - return metadata_store.MetadataStore(connection_config) - - def get_artifacts(self, taskid): - context_id = self.get_or_create_task_context(taskid).id - artifacts = self.store.get_artifacts_by_context(context_id) - # parameters - parameters = [] - input_data, output_data = [], [] - input_model, output_model = [], [] - input_metric, output_metric = [], [] - - def _to_dict(artifact): - return dict( - uri=artifact.uri, - name=artifact.properties["name"].string_value, - metadata=json.loads(artifact.properties["metadata"].string_value), - ) - - for artifact in artifacts: - if self.parameter_type_id == artifact.type_id: - parameters.append( - dict( - name=artifact.properties["name"].string_value, - value=json.loads(artifact.properties["value"].string_value), - type=artifact.properties["type"].string_value, - ) - ) - if artifact.type_id in {self.data_type_id, self.model_type_id, self.metric_type_id}: - is_input = artifact.properties["is_input"].bool_value - - if self.data_type_id == artifact.type_id: - if is_input: - input_data.append(_to_dict(artifact)) - else: - output_data.append(_to_dict(artifact)) - - if self.model_type_id == artifact.type_id: - if is_input: - input_model.append(_to_dict(artifact)) - else: - output_model.append(_to_dict(artifact)) - - if self.metric_type_id == artifact.type_id: - if is_input: - input_metric.append(_to_dict(artifact)) - else: - output_metric.append(_to_dict(artifact)) - return dict( - parameters=parameters, - input=dict(data=input_data, model=input_model, metric=input_metric), - output=dict(data=output_data, model=output_model, metric=output_metric), - ) - - def get_or_create_task_context(self, taskid): - task_context_run = self.store.get_context_by_type_and_name("TaskContext", taskid) - if task_context_run is None: - task_context_run = metadata_store_pb2.Context() - task_context_run.type_id = self.task_context_type_id - task_context_run.name = taskid - [task_context_run_id] = self.store.put_contexts([task_context_run]) - task_context_run.id = task_context_run_id - return task_context_run - - def put_task_to_task_context(self, taskid): - association = metadata_store_pb2.Association() - association.execution_id = self.get_or_create_task(taskid).id - association.context_id = self.get_or_create_task_context(taskid).id - self.store.put_attributions_and_associations([], [association]) - - def put_artifact_to_task_context(self, taskid, artifact_id): - attribution = metadata_store_pb2.Attribution() - attribution.artifact_id = artifact_id - attribution.context_id = self.get_or_create_task_context(taskid).id - self.store.put_attributions_and_associations([attribution], []) - - def update_task_state(self, taskid, state, exception=None): - task_run 
= self.get_or_create_task(taskid) - task_run.properties["state"].string_value = state - if exception is not None: - task_run.properties["exception"].string_value = exception - self.store.put_executions([task_run]) - - def get_or_create_task(self, taskid): - task_run = self.store.get_execution_by_type_and_name("Task", taskid) - if task_run is None: - task_run = metadata_store_pb2.Execution() - task_run.type_id = self.task_type_id - task_run.name = taskid - task_run.properties["state"].string_value = "INIT" - task_run.properties["safe_terminate"].bool_value = False - [task_run_id] = self.store.put_executions([task_run]) - task_run.id = task_run_id - return task_run - - def get_task_safe_terminate_flag(self, taskid: str): - task_run = self.get_or_create_task(taskid) - return task_run.properties["safe_terminate"].bool_value - - def set_task_safe_terminate_flag(self, taskid: str): - task_run = self.get_or_create_task(taskid) - task_run.properties["safe_terminate"].bool_value = True - self.store.put_executions([task_run]) - - def record_input_event(self, execution_id, artifact_id): - event = metadata_store_pb2.Event() - event.artifact_id = artifact_id - event.execution_id = execution_id - event.type = metadata_store_pb2.Event.DECLARED_INPUT - self.store.put_events([event]) - - def record_output_event(self, execution_id, artifact_id): - event = metadata_store_pb2.Event() - event.artifact_id = artifact_id - event.execution_id = execution_id - event.type = metadata_store_pb2.Event.DECLARED_OUTPUT - self.store.put_events([event]) - - def add_parameter(self, name: str, value): - artifact = metadata_store_pb2.Artifact() - artifact.properties["name"].string_value = name - artifact.properties["type"].string_value = str(type(value)) - artifact.properties["value"].string_value = json.dumps(value) - artifact.type_id = self.parameter_type_id - [artifact_id] = self.store.put_artifacts([artifact]) - return artifact_id - - def add_data_artifact(self, name: str, uri: str, metadata: dict, is_input): - return self.add_artifact(self.data_type_id, name, uri, metadata, is_input) - - def add_model_artifact(self, name: str, uri: str, metadata: dict, is_input): - return self.add_artifact(self.model_type_id, name, uri, metadata, is_input) - - def add_metric_artifact(self, name: str, uri: str, metadata: dict, is_input): - return self.add_artifact(self.metric_type_id, name, uri, metadata, is_input) - - def add_artifact(self, type_id: int, name: str, uri: str, metadata: dict, is_input): - artifact = metadata_store_pb2.Artifact() - artifact.uri = uri - artifact.properties["name"].string_value = name - artifact.properties["is_input"].bool_value = is_input - artifact.properties["metadata"].string_value = json.dumps(metadata) - artifact.type_id = type_id - [artifact_id] = self.store.put_artifacts([artifact]) - return artifact_id - - @property - def task_context_type_id(self): - if self._task_context_type_id is None: - job_type = metadata_store_pb2.ContextType() - job_type.name = "TaskContext" - job_type.properties["jobid"] = metadata_store_pb2.STRING - self._task_context_type_id = self.store.put_context_type(job_type) - return self._task_context_type_id - - @property - def task_type_id(self): - if self._task_type_id is None: - task_type = metadata_store_pb2.ExecutionType() - task_type.name = "Task" - task_type.properties["state"] = metadata_store_pb2.STRING - task_type.properties["exception"] = metadata_store_pb2.STRING - task_type.properties["safe_terminate"] = metadata_store_pb2.BOOLEAN - self._task_type_id = 
self.store.put_execution_type(task_type) - return self._task_type_id - - @property - def parameter_type_id(self): - if self._parameter_type_id is None: - artifact_type = metadata_store_pb2.ArtifactType() - artifact_type.name = "Parameter" - artifact_type.properties["name"] = metadata_store_pb2.STRING - artifact_type.properties["type"] = metadata_store_pb2.STRING - artifact_type.properties["value"] = metadata_store_pb2.STRING - self._parameter_type_id = self.store.put_artifact_type(artifact_type) - return self._parameter_type_id - - @property - def data_type_id(self): - if self._data_type_id is None: - self._data_type_id = self.create_artifact_type("Data") - return self._data_type_id - - @property - def model_type_id(self): - if self._model_type_id is None: - self._model_type_id = self.create_artifact_type("Model") - return self._model_type_id - - @property - def metric_type_id(self): - if self._metric_type_id is None: - self._metric_type_id = self.create_artifact_type("Metric") - return self._metric_type_id - - def create_artifact_type(self, name): - artifact_type = metadata_store_pb2.ArtifactType() - artifact_type.name = name - artifact_type.properties["uri"] = metadata_store_pb2.STRING - artifact_type.properties["name"] = metadata_store_pb2.STRING - artifact_type.properties["is_input"] = metadata_store_pb2.BOOLEAN - artifact_type.properties["metadata"] = metadata_store_pb2.STRING - artifact_type_id = self.store.put_artifact_type(artifact_type) - return artifact_type_id diff --git a/python/fate/arch/context/_namespace.py b/python/fate/arch/context/_namespace.py index 802a748523..ab3b5f54a2 100644 --- a/python/fate/arch/context/_namespace.py +++ b/python/fate/arch/context/_namespace.py @@ -13,78 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from contextlib import contextmanager -from typing import Generator, overload +from typing import List, Optional, Tuple logger = logging.getLogger(__name__) +_NS_FEDERATION_SPLIT = "." -class Namespace: - """ - Summary, Metrics may be namespace awared: - ``` - namespace = Namespace() - ctx = Context(...summary=XXXSummary(namespace)) - ``` - """ - - def __init__(self, namespaces=None) -> None: - if namespaces is None: - namespaces = [] - self.namespaces = namespaces - - @contextmanager - def into_subnamespace(self, subnamespace: str): - self.namespaces.append(subnamespace) - try: - yield self - finally: - self.namespaces.pop() + +class NS: + def __init__(self, name, deep, parent: Optional["NS"] = None) -> None: + self.name = name + self.deep = deep + self.parent = parent + + if self.parent is None: + self._federation_tag = self.get_name() + self._metric_groups = [] + else: + self._federation_tag = f"{self.parent._federation_tag}{_NS_FEDERATION_SPLIT}{self.get_name()}" + self._metric_groups = [*self.parent._metric_groups, self.parent.get_group()] + + @property + def federation_tag(self): + return self._federation_tag @property - def namespace(self): - return ".".join(self.namespaces) - - def fedeation_tag(self) -> str: - return ".".join(self.namespaces) - - def sub_namespace(self, namespace): - return Namespace([*self.namespaces, namespace]) - - @overload - @contextmanager - def iter_namespaces( - self, start: int, stop: int, *, prefix_name="" - ) -> Generator[Generator["Namespace", None, None], None, None]: - ... 
- - @overload - @contextmanager - def iter_namespaces( - self, stop: int, *, prefix_name="" - ) -> Generator[Generator["Namespace", None, None], None, None]: - ... - - @contextmanager - def iter_namespaces(self, *args, prefix_name=""): - assert 0 < len(args) <= 2, "position argument should be 1 or 2" - if len(args) == 1: - start, stop = 0, args[0] - if len(args) == 2: - start, stop = args[0], args[1] - - prev_namespace_state = self._namespace_state - - def _state_iterator() -> Generator["Namespace", None, None]: - for i in range(start, stop): - # the tags in the iteration need to be distinguishable - template_formated = f"{prefix_name}iter_{i}" - self._namespace_state = IterationState(prev_namespace_state.sub_namespace(template_formated)) - yield self - - # with context returns iterator of Contexts - # namespaec state inside context is changed alone with iterator comsued - yield _state_iterator() - - # restore namespace state when leaving with context - self._namespace_state = prev_namespace_state + def metric_groups(self) -> List[Tuple[str, Optional[int]]]: + return self._metric_groups + + def get_name(self): + return self.name + + def get_group(self): + return self.name, None + + def __str__(self) -> str: + return f"{self.__class__.__name__}(name={self.name}, deep={self.deep}" + + def indexed_ns(self, index: int): + return IndexedNS(index=index, name=self.name, deep=self.deep, parent=self.parent) + + def sub_ns(self, name: str): + return NS(name=name, deep=self.deep + 1, parent=self) + + +class IndexedNS(NS): + def __init__(self, index, name: str, deep: int, parent: Optional["NS"] = None) -> None: + self.index = index + super().__init__(name=name, deep=deep, parent=parent) + + def get_name(self): + return f"{self.name}-{self.index}" + + def get_group(self): + return self.name, self.index + + def __str__(self) -> str: + return f"{self.__class__.__name__}(index={self.index}, name={self.name}, deep={self.deep})" + + +default_ns = NS(name="default", deep=0) diff --git a/python/fate/arch/context/io/data/csv.py b/python/fate/arch/context/io/data/csv.py deleted file mode 100644 index 68cb6afddb..0000000000 --- a/python/fate/arch/context/io/data/csv.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
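The NS/IndexedNS classes added in _namespace.py above replace the old context-manager based Namespace. A minimal sketch of how federation tags and metric groups compose, using only the API shown in this diff (the "fit"/"epoch" names are illustrative):

ns = NS(name="default", deep=0)                  # equivalent to the module-level default_ns
fit_ns = ns.sub_ns("fit")                        # deep=1, parent=ns
epoch_ns = fit_ns.sub_ns("epoch").indexed_ns(3)  # indexed variant of the "epoch" level

print(fit_ns.federation_tag)    # "default.fit"
print(epoch_ns.federation_tag)  # "default.fit.epoch-3"
print(epoch_ns.get_group())     # ("epoch", 3)
print(epoch_ns.metric_groups)   # [("default", None), ("fit", None)]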
-from ....unify import URI -from .df import Dataframe - - -class CSVReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.metadata = metadata - - def read_dataframe(self): - import inspect - - from fate.arch import dataframe - - kwargs = {} - p = inspect.signature(dataframe.CSVReader.__init__).parameters - parameter_keys = p.keys() - for k, v in self.metadata.items(): - if k in parameter_keys: - kwargs[k] = v - - dataframe_reader = dataframe.CSVReader(**kwargs).to_frame(self.ctx, self.uri.path) - # s_df = dataframe.serialize(self.ctx, dataframe_reader) - # dataframe_reader = dataframe.deserialize(self.ctx, s_df) - return Dataframe(dataframe_reader, dataframe_reader.shape[1], dataframe_reader.shape[0]) - - -class CSVWriter: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.metadata = metadata - - def write_dataframe(self, df): - ... diff --git a/python/fate/arch/context/io/data/df.py b/python/fate/arch/context/io/data/df.py deleted file mode 100644 index 2a86b33990..0000000000 --- a/python/fate/arch/context/io/data/df.py +++ /dev/null @@ -1,25 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -class Dataframe: - def __init__(self, frames, num_features, num_samples) -> None: - self.data = frames - self.num_features = num_features - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def to_local(self): - return self.data.to_local() diff --git a/python/fate/arch/context/io/data/eggroll.py b/python/fate/arch/context/io/data/eggroll.py deleted file mode 100644 index df4f2d95bf..0000000000 --- a/python/fate/arch/context/io/data/eggroll.py +++ /dev/null @@ -1,127 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
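The CSVReader deleted above forwards only those metadata keys that match the constructor signature of fate.arch.dataframe.CSVReader. The same keyword-filtering pattern in isolation, with a hypothetical target function standing in for the real constructor:

import inspect

def filter_kwargs(func, metadata: dict) -> dict:
    # Keep only metadata entries that are actual parameters of `func`.
    parameter_keys = inspect.signature(func).parameters.keys()
    return {k: v for k, v in metadata.items() if k in parameter_keys}

def target(delimiter=",", label_name=None):  # hypothetical stand-in
    return delimiter, label_name

filter_kwargs(target, {"delimiter": ";", "unused": 1, "label_name": "y"})
# -> {"delimiter": ";", "label_name": "y"}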
- -from ....unify import EggrollURI - - -class EggrollDataFrameWriter: - def __init__(self, ctx, uri: EggrollURI, metadata: dict) -> None: - self.ctx = ctx - self.uri = EggrollMetaURI(uri) - self.metadata = metadata - - def write_dataframe(self, df): - from fate.arch import dataframe - from fate.arch.computing._address import EggRollAddress - - table = dataframe.serialize(self.ctx, df) - schema = {} - table.save( - address=EggRollAddress(name=self.uri.get_data_name(), namespace=self.uri.get_data_namespace()), - partitions=int(self.metadata.get("num_partitions", table.partitions)), - schema=schema, - **self.metadata, - ) - # save meta - meta_table = self.ctx.computing.parallelize([("schema", schema)], partition=1, include_key=True) - meta_table.save( - address=EggRollAddress(name=self.uri.get_meta_name(), namespace=self.uri.get_meta_namespace()), - partitions=1, - schema={}, - **self.metadata, - ) - - -class EggrollDataFrameReader: - def __init__(self, ctx, uri: EggrollURI, metadata: dict) -> None: - self.ctx = ctx - self.uri = EggrollMetaURI(uri) - self.metadata = metadata - - def read_dataframe(self): - from fate.arch import dataframe - - from .df import Dataframe - - table = load_table(self.ctx, self.uri, self.metadata) - df = dataframe.deserialize(self.ctx, table) - return Dataframe(df, df.shape[1], df.shape[0]) - - -class EggrollRawTableReader: - def __init__(self, ctx, name: str, uri: EggrollURI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = EggrollMetaURI(uri) - self.metadata = metadata - - def read_dataframe(self): - import inspect - - from fate.arch import dataframe - - from .df import Dataframe - - table = load_table(self.ctx, self.uri, self.metadata) - - kwargs = {} - p = inspect.signature(dataframe.RawTableReader.__init__).parameters - parameter_keys = p.keys() - for k, v in table.schema.items(): - if k in parameter_keys: - kwargs[k] = v - - dataframe_reader = dataframe.RawTableReader(**kwargs).to_frame(self.ctx, table) - return Dataframe(dataframe_reader, dataframe_reader.shape[1], dataframe_reader.shape[0]) - - -class EggrollMetaURI: - def __init__(self, uri: EggrollURI) -> None: - self.uri = uri - - def get_data_namespace(self): - return self.uri.namespace - - def get_data_name(self): - return self.uri.name - - def get_meta_namespace(self): - return self.uri.namespace - - def get_meta_name(self): - return f"{self.uri.name}.meta" - - -def load_table(ctx, uri: EggrollMetaURI, metadata: dict): - from fate.arch.computing._address import EggRollAddress - - meta_key, meta = list( - ctx.computing.load( - address=EggRollAddress(name=uri.get_meta_name(), namespace=uri.get_meta_namespace()), - partitions=1, - schema={}, - **metadata, - ).collect() - )[0] - assert meta_key == "schema" - num_partitions = metadata.get("num_partitions") - table = ctx.computing.load( - address=EggRollAddress(name=uri.get_data_name(), namespace=uri.get_data_namespace()), - partitions=num_partitions, - schema=meta, - **metadata, - ) - - return table diff --git a/python/fate/arch/context/io/data/file.py b/python/fate/arch/context/io/data/file.py deleted file mode 100644 index 5a3d120917..0000000000 --- a/python/fate/arch/context/io/data/file.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ....unify import FileURI -from .df import Dataframe - - -class FileDataFrameWriter: - def __init__(self, ctx, name: str, uri: FileURI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = FileMetaURI(uri) - self.metadata = metadata - - def write_dataframe(self, df): - import json - - from fate.arch import dataframe - - table = dataframe.serialize(self.ctx, df) - with open(self.uri.get_data_path(), "w") as f: - json.dump(list(table.collect()), f) - with open(self.uri.get_meta_path(), "w") as f: - json.dump(table.schema, f) - - -class FileDataFrameReader: - def __init__(self, ctx, name: str, uri: FileURI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = FileMetaURI(uri) - self.metadata = metadata - - def read_dataframe(self): - import json - - from fate.arch import dataframe - - with open(self.uri.get_meta_path(), "r") as fin: - schema = json.load(fin) - with open(self.uri.get_data_path(), "r") as fin: - data = json.load(fin) - - table = self.ctx.computing.parallelize(data, include_key=True, partition=1) - table.schema = schema - df = dataframe.deserialize(self.ctx, table) - - return Dataframe(df, df.shape[1], df.shape[0]) - - -class FileMetaURI: - def __init__(self, uri: FileURI) -> None: - self.uri = uri - - def get_data_path(self): - return self.uri.path - - def get_meta_path(self): - return f"{self.uri.path}.meta" diff --git a/python/fate/arch/context/io/kit.py b/python/fate/arch/context/io/kit.py deleted file mode 100644 index 531f5c4059..0000000000 --- a/python/fate/arch/context/io/kit.py +++ /dev/null @@ -1,165 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Protocol - -from fate.components import Artifact, DatasetArtifact, MetricArtifact, ModelArtifact - -from ...unify import URI, EggrollURI - - -class Reader(Protocol): - ... - - -class Writer(Protocol): - ... 
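The FileDataFrameWriter/FileDataFrameReader removed above persist a serialized dataframe as JSON at the artifact path and its schema at the same path with a ".meta" suffix. A standalone sketch of that convention (function names here are illustrative, not part of FATE):

import json

def write_local(path: str, rows: list, schema: dict) -> None:
    # data at <path>, schema at <path>.meta
    with open(path, "w") as f:
        json.dump(rows, f)
    with open(f"{path}.meta", "w") as f:
        json.dump(schema, f)

def read_local(path: str):
    with open(f"{path}.meta") as f:
        schema = json.load(f)
    with open(path) as f:
        rows = json.load(f)
    return rows, schema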
- - -class IOKit: - @staticmethod - def _parse_args(arg, **kwargs): - name = "" - metadata = {} - if hasattr(arg, "uri"): - uri = arg.uri - name = arg.name - metadata = arg.metadata - elif isinstance(arg[0], str): - uri = arg[0] - else: - raise ValueError(f"invalid arguments: {arg} and {kwargs}") - if "name" in kwargs: - name = kwargs["name"] - if "metadata" in kwargs: - metadata = kwargs["metadata"] - for k, v in kwargs.items(): - if k not in ["name", "metadata"]: - metadata[k] = v - - uri = URI.from_string(uri) - format = metadata.get("format") - return format, name, uri, metadata - - def reader(self, ctx, artifact, **kwargs): - name = artifact.name - metadata = artifact.metadata - if "metadata" in kwargs: - metadata = kwargs["metadata"] - for k, v in kwargs.items(): - if k not in ["name", "metadata"]: - metadata[k] = v - writer_format = metadata.get("format") - if "name" in kwargs: - name = kwargs["name"] - - if isinstance(artifact, MetricArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - from .metric.file import FileMetricsReader - - return FileMetricsReader(ctx, name, uri, metadata) - if uri.schema in ["http", "https"]: - from .metric.http import HTTPMetricsReader - - return HTTPMetricsReader(ctx, name, uri, metadata) - - if isinstance(artifact, ModelArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - from .model.file import FileModelReader - - return FileModelReader(ctx, name, uri, metadata) - - if uri.schema in ["http", "https"]: - from .model.http import HTTPModelReader - - return HTTPModelReader(ctx, name, uri, metadata) - - if isinstance(artifact, DatasetArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - if writer_format == "csv": - from .data.csv import CSVReader - - return CSVReader(ctx, name, uri, metadata) - - elif writer_format == "dataframe": - from .data.file import FileDataFrameReader - - return FileDataFrameReader(ctx, name, uri.to_schema(), {}) - elif uri.schema == "eggroll": - if writer_format == "dataframe": - from .data.eggroll import EggrollDataFrameReader - - return EggrollDataFrameReader(ctx, uri.to_schema(), {}) - elif writer_format == "raw_table": - from .data.eggroll import EggrollRawTableReader - - return EggrollRawTableReader(ctx, name, uri.to_schema(), {}) - - raise NotImplementedError(f"{artifact}") - - def writer(self, ctx, artifact: Artifact, **kwargs) -> "Writer": - name = artifact.name - metadata = artifact.metadata - if "metadata" in kwargs: - metadata = kwargs["metadata"] - for k, v in kwargs.items(): - if k not in ["name", "metadata"]: - metadata[k] = v - writer_format = metadata.get("format") - if "name" in kwargs: - name = kwargs["name"] - - if isinstance(artifact, MetricArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - from .metric.file import FileMetricsWriter - - return FileMetricsWriter(ctx, name, uri, metadata) - - if uri.schema in ["http", "https"]: - from .metric.http import HTTPMetricsWriter - - return HTTPMetricsWriter(ctx, name, uri, metadata) - if isinstance(artifact, ModelArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - from .model.file import FileModelWriter - - return FileModelWriter(ctx, name, uri) - if uri.schema in ["http", "https"]: - from .model.http import HTTPModelWriter - - return HTTPModelWriter(ctx, name, uri, metadata) - - if isinstance(artifact, DatasetArtifact): - uri = URI.from_string(artifact.uri) - if uri.schema == "file": - if writer_format == "csv": - from .data.csv import CSVWriter 
- - return CSVWriter(ctx, name, uri, metadata) - - elif writer_format == "dataframe": - from .data.file import FileDataFrameWriter - - return FileDataFrameWriter(ctx, name, uri.to_schema(), {}) - elif uri.schema == "eggroll": - if writer_format == "dataframe": - from .data.eggroll import EggrollDataFrameWriter - - return EggrollDataFrameWriter(ctx, uri.to_schema(), {}) - raise NotImplementedError(f"{artifact}") diff --git a/python/fate/arch/context/io/metric/file.py b/python/fate/arch/context/io/metric/file.py deleted file mode 100644 index a5dd7d6b38..0000000000 --- a/python/fate/arch/context/io/metric/file.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -from typing import Union - -from ....unify import URI -from ...metric import InCompleteMetrics, Metrics - - -class FileMetricsWriter: - def __init__(self, ctx, name: str, uri: URI, metadata) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - - def write_metric(self, metrics: Union[Metrics, InCompleteMetrics]): - if isinstance(metrics, Metrics): - with open(self.uri.path, "w") as f: - json.dump(metrics.dict(), f) - else: - # read - if not os.path.exists(self.uri.path): - merged = metrics - else: - with open(self.uri.path, "r") as f: - merged = metrics.from_dict(json.load(f)).merge(metrics) - - with open(self.uri.path, "w") as f: - json.dump(merged.dict(), f) - - -class FileMetricsReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.metadata = metadata - - def read_metric(self): - with open(self.uri.path, "r") as fin: - metric_dict = json.loads(fin.read()) - return metric_dict diff --git a/python/fate/arch/context/io/metric/http.py b/python/fate/arch/context/io/metric/http.py deleted file mode 100644 index a12fd1162a..0000000000 --- a/python/fate/arch/context/io/metric/http.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
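The FileMetricsWriter deleted above rewrites its JSON file on every call: complete metrics overwrite the file, while incomplete metrics are merged with whatever is already on disk. The read-merge-write flow in isolation (the `merge` callable is a hypothetical stand-in for InCompleteMetrics.merge):

import json
import os

def write_merged(path: str, record: dict, merge) -> None:
    # Merge with the previous record if one exists, then rewrite the file.
    if os.path.exists(path):
        with open(path) as f:
            record = merge(json.load(f), record)
    with open(path, "w") as f:
        json.dump(record, f)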
-import logging -from typing import Union - -import requests - -from ....unify import URI -from ...metric import InCompleteMetrics, Metrics - - -class HTTPMetricsWriter: - def __init__(self, ctx, name: str, uri: URI, metadata) -> None: - self.name = name - self.ctx = ctx - self.entrypoint = f"{uri.schema}://{uri.authority}{uri.path}" - - def write_metric(self, metrics: Union[Metrics, InCompleteMetrics]): - if isinstance(metrics, Metrics): - response = requests.post(url=self.entrypoint, json={"data": metrics.dict(), "incomplte": False}) - else: - response = requests.post(url=self.entrypoint, json={"data": metrics.dict(), "incomplete": True}) - logging.info(response.text) - - -class HTTPMetricsReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.entrypoint = f"{uri.schema}://{uri.authority}{uri.path}" - - def read_metric(self): - metric_dict = requests.get(url=self.entrypoint).json().get("data", {}) - return metric_dict diff --git a/python/fate/arch/context/io/model/file.py b/python/fate/arch/context/io/model/file.py deleted file mode 100644 index 386f236372..0000000000 --- a/python/fate/arch/context/io/model/file.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json - -from ....unify import URI - - -class FileModelWriter: - def __init__(self, ctx, name: str, uri: URI) -> None: - self.ctx = ctx - self.name = name - self.path = uri.path - - def write_model(self, model): - with open(self.path, "w") as f: - json.dump(model, f) - - -class FileModelReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.metadata = metadata - - def read_model(self): - with open(self.uri.path, "r") as fin: - return json.loads(fin.read()) diff --git a/python/fate/arch/context/io/model/http.py b/python/fate/arch/context/io/model/http.py deleted file mode 100644 index 903225bd0d..0000000000 --- a/python/fate/arch/context/io/model/http.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -import requests - -from ....unify import URI - -logger = logging.getLogger(__name__) - - -class HTTPModelWriter: - def __init__(self, ctx, name: str, uri: URI, metadata) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.entrypoint = f"{self.uri.schema}://{self.uri.authority}{self.uri.path}" - - def write_model(self, model): - logger.debug(self.entrypoint) - response = requests.post(url=self.entrypoint, json={"data": model}) - logger.debug(response.text) - - -class HTTPModelReader: - def __init__(self, ctx, name: str, uri: URI, metadata: dict) -> None: - self.name = name - self.ctx = ctx - self.uri = uri - self.entrypoint = f"{self.uri.schema}://{self.uri.authority}{self.uri.path}" - self.metadata = metadata - - def read_model(self): - return requests.get(url=self.entrypoint).json().get("data", {}) diff --git a/python/fate/arch/context/metric/_handler.py b/python/fate/arch/context/metric/_handler.py deleted file mode 100644 index 7e59e091a8..0000000000 --- a/python/fate/arch/context/metric/_handler.py +++ /dev/null @@ -1,38 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Union - -from fate.interface import MetricsHandler - -from ._type import InCompleteMetrics, Metrics - - -class NoopMetricsHandler(MetricsHandler): - def __init__(self) -> None: - self._metrics = {} - - def log_metrics(self, metrics: Union[Metrics, InCompleteMetrics]): - if isinstance(metrics, Metrics): - if metrics.name in self._metrics: - raise ValueError(f"duplicated metircs: `{metrics.name}` already exists") - else: - self._metrics[metrics.name] = metrics - elif isinstance(metrics, InCompleteMetrics): - if metrics.name not in self._metrics: - self._metrics[metrics.name] = metrics - else: - self._metrics[metrics.name].merge(metrics) - else: - raise ValueError(f"metrics `{metrics}` not allowed") diff --git a/python/fate/arch/context/metric/_incomplte_metrics.py b/python/fate/arch/context/metric/_incomplte_metrics.py deleted file mode 100644 index 08d2b95607..0000000000 --- a/python/fate/arch/context/metric/_incomplte_metrics.py +++ /dev/null @@ -1,56 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from ._type import InCompleteMetrics - - -class StepMetrics(InCompleteMetrics): - complete = False - - def __init__(self, name, type, data, namespace, groups, metadata) -> None: - self.name = name - self.type = type - self.namespace = namespace - self.groups = groups - self.data = data - self.metadata = metadata - - def merge(self, metrics: InCompleteMetrics): - if not isinstance(metrics, StepMetrics): - raise ValueError(f"can't merge metrics type `{metrics}` with StepMetrics") - if metrics.type != self.type or metrics.nemaspace != self.namespace: - raise ValueError(f"can't merge metrics type `{metrics}` with StepMetrics named `{self.name}`") - # TODO: compare groups - return StepMetrics( - name=self.name, - type=self.type, - namespace=self.namespace, - groups=self.groups, - data=[*self.data, *metrics.data], - metadata=self.metadata, - ) - - def dict(self) -> dict: - return dict( - name=self.name, - namespace=self.nemaspace, - groups=self.groups, - type=self.type, - metadata=self.metadata, - data=self.data, - ) - - @classmethod - def from_dict(cls, d) -> "StepMetrics": - return StepMetrics(**d) diff --git a/python/fate/arch/context/metric/_metric.py b/python/fate/arch/context/metric/_metric.py deleted file mode 100644 index 43c6967c4a..0000000000 --- a/python/fate/arch/context/metric/_metric.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ._type import Metric - - -class ScalarMetric(Metric): - type = "scalar" - - def __init__(self, scalar) -> None: - self.scalar = scalar - - def dict(self): - return self.scalar - - -class LossMetric(Metric): - type = "loss" - - def __init__(self, loss) -> None: - self.loss = loss - - def dict(self) -> dict: - return self.loss - - -class AccuracyMetric(Metric): - type = "accuracy" - - def __init__(self, accuracy) -> None: - self.accuracy = accuracy - - def dict(self) -> dict: - return self.accuracy - - -class AUCMetric(Metric): - type = "auc" - - def __init__(self, auc) -> None: - self.auc = auc - - def dict(self) -> dict: - return self.auc diff --git a/python/fate/arch/context/metric/_type.py b/python/fate/arch/context/metric/_type.py deleted file mode 100644 index 6705d35879..0000000000 --- a/python/fate/arch/context/metric/_type.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import abc -from typing import Dict, Optional - - -class Metric(metaclass=abc.ABCMeta): - type: str - - @abc.abstractmethod - def dict(self) -> dict: - ... - - -class Metrics(metaclass=abc.ABCMeta): - name: str - type: str - nemaspace: Optional[str] = None - groups: Dict[str, str] = {} - - @abc.abstractmethod - def dict(self) -> dict: - ... - - -class InCompleteMetrics(metaclass=abc.ABCMeta): - name: str - type: str - nemaspace: Optional[str] = None - groups: Dict[str, str] = {} - - @abc.abstractmethod - def dict(self) -> dict: - ... - - @abc.abstractmethod - def merge(self, metrics: "InCompleteMetrics") -> "InCompleteMetrics": - ... - - @abc.abstractclassmethod - def from_dict(cls, d) -> "InCompleteMetrics": - ... diff --git a/python/fate/arch/context/metric/_wrap.py b/python/fate/arch/context/metric/_wrap.py deleted file mode 100644 index 00d33c9e5e..0000000000 --- a/python/fate/arch/context/metric/_wrap.py +++ /dev/null @@ -1,93 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple, Union - -from fate.interface import MetricsHandler -from fate.interface import MetricsWrap as MetricsWrapProtocol - -from ._handler import NoopMetricsHandler -from ._incomplte_metrics import StepMetrics -from ._metric import AccuracyMetric, AUCMetric, LossMetric, ScalarMetric -from ._metrics import ROCMetrics -from ._type import InCompleteMetrics, Metric, Metrics - - -class MetricsWrap(MetricsWrapProtocol): - def __init__(self, handler: Optional[MetricsHandler], groups=None) -> None: - if handler is None: - self.handler = NoopMetricsHandler() - else: - self.handler = handler - if groups is None: - self.groups = {} - else: - self.groups = groups - - def into_group(self, group_name: str, group_id: str) -> "MetricsWrap": - if group_name in self.groups: - raise ValueError( - f"can't into group named `{group_name}` since `{group_name}` already in groups `{self.groups}`" - ) - groups = {**self.groups} - groups.update({group_name: group_id}) - return MetricsWrap(self.handler, groups) - - def log_metrics(self, metrics: Union[Metrics, InCompleteMetrics]): - if self.groups: - for group_name, group_id in self.groups.items(): - if group_name in metrics.groups: - if (to_add_group_id := metrics.groups[group_name]) != group_id: - raise ValueError( - f"try to add group named `{group_name}`, but group id `{group_id}` not equals `{to_add_group_id}`" - ) - else: - metrics.groups[group_name] = group_id - return self.handler.log_metrics(metrics) - - def log_meta(self, meta): - return self.log_metrics(meta) - - def log_metric( - self, name: str, metric: Metric, step=None, timestamp=None, namespace=None, groups=None, metadata=None - ): - if groups is None: - groups = {} - if metadata is None: - metadata = {} - return self.log_metrics( - StepMetrics( - name=name, - type=metric.type, - namespace=namespace, - groups=groups, - data=[dict(metric=metric.dict(), step=step, timestamp=timestamp)], - metadata=metadata, - ) - ) - - def 
log_scalar(self, name: str, metric: float, step=None, timestamp=None): - return self.log_metric(name, ScalarMetric(metric), step, timestamp) - - def log_loss(self, name: str, loss: float, step, timestamp=None): - return self.log_metric(name, LossMetric(loss), step, timestamp) - - def log_accuracy(self, name: str, accuracy: float, step=None, timestamp=None): - return self.log_metric(name, AccuracyMetric(accuracy), step, timestamp) - - def log_auc(self, name: str, auc: float, step=None, timestamp=None): - return self.log_metric(name, AUCMetric(auc), step, timestamp) - - def log_roc(self, name: str, data: List[Tuple[float, float]]): - return self.log_metrics(ROCMetrics(name, data)) diff --git a/python/fate/arch/dataframe/__init__.py b/python/fate/arch/dataframe/__init__.py index 3c766750fd..f0a0b1ed66 100644 --- a/python/fate/arch/dataframe/__init__.py +++ b/python/fate/arch/dataframe/__init__.py @@ -12,24 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ._dataframe import DataFrame from ._frame_reader import ( CSVReader, ImageReader, PandasReader, - RawTableReader, + TableReader, TorchDataSetReader, ) from .io import build_schema, deserialize, parse_schema, serialize -from .utils import DataLoader +from .utils import DataLoader, BatchEncoding +from .utils import KFold __all__ = [ "PandasReader", "CSVReader", - "RawTableReader", + "TableReader", "ImageReader", "TorchDataSetReader", "parse_schema", "build_schema", "serialize", "deserialize", + "DataFrame", + "KFold", + "DataLoader", + "BatchEncoding" ] diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index b8d22659b7..162ee728d7 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -12,153 +12,444 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
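The dataframe/__init__.py change above renames RawTableReader to TableReader and adds DataFrame, KFold and BatchEncoding to the package's public surface. Imports against the new API, using only names listed in __all__ in this diff:

from fate.arch.dataframe import (
    BatchEncoding,
    CSVReader,
    DataFrame,
    DataLoader,
    KFold,
    PandasReader,
    TableReader,
)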
+# import copy import operator +import typing +from typing import List, Union + +import numpy as np +import pandas as pd -import torch -from fate.arch.computing import is_table +from fate.arch.tensor import DTensor +from .manager import DataManager, Schema -from .ops import arith_method, stat_method, transform_to_predict_result -from .storage import Index, ValueStore +if typing.TYPE_CHECKING: + from fate.arch.histogram import DistributedHistogram, HistogramBuilder -# TODO: record data type, support multiple data types class DataFrame(object): - def __init__(self, ctx, schema, index=None, match_id=None, values=None, label=None, weight=None): + def __init__(self, ctx, block_table, partition_order_mappings, data_manager: DataManager): self._ctx = ctx - self._index = index - self._match_id = match_id - self._values = values - self._label = label - self._weight = weight - self._schema = Schema(**schema) - - self.__shape = None + self._block_table = block_table + self._partition_order_mappings = partition_order_mappings + self._data_manager = data_manager + + self._sample_id_indexer = None + self._match_id_indexer = None + self._sample_id = None + self._match_id = None + self._label = None + self._weight = None + + self._count = None self._columns = None - self._tensor_label = None + @property + def sample_id(self): + if self._sample_id is None: + self._sample_id = self.__extract_fields( + with_sample_id=True, with_match_id=False, with_label=False, with_weight=False + ) + return self._sample_id @property - def index(self): - return self._index + def match_id(self): + if self._match_id is None: + self._match_id = self.__extract_fields( + with_sample_id=False, with_match_id=True, with_label=False, with_weight=False + ) + + return self._match_id @property def values(self): - return self._values + """ + as values maybe bigger than match_id/sample_id/weight/label, we will not cached them + """ + if not len(self.schema.columns): + return None + + return self.__extract_fields( + with_sample_id=True, with_match_id=True, with_label=False, with_weight=False, columns=self.columns.tolist() + ) @property def label(self): + if not self.schema.label_name: + return None + + if self._label is None: + self._label = self.__extract_fields( + with_sample_id=True, with_match_id=True, with_label=True, with_weight=False + ) + return self._label @property def weight(self): + if not self.schema.weight_name: + return None + + if self._weight is None: + self._weight = self.__extract_fields( + with_sample_id=True, with_match_id=True, with_label=False, with_weight=True + ) + return self._weight @property - def match_id(self): - return self._match_id + def shape(self) -> "tuple": + if self._count is None: + items = 0 + for _, v in self._partition_order_mappings.items(): + items += v["end_index"] - v["start_index"] + 1 + self._count = items + + return self._count, len(self._data_manager.schema.columns) @property - def shape(self): - if self.__shape: - return self.__shape + def schema(self) -> "Schema": + return self._data_manager.schema - if self._values is None: - self.__shape = (self._index.count(), 0) - else: - self.__shape = (self._index.count(), len(self._schema.header)) + @property + def columns(self): + return self.schema.columns - return self.__shape + @property + def block_table(self): + return self._block_table + + @block_table.setter + def block_table(self, block_table): + self._block_table = block_table @property - def schema(self) -> "Schema": - return self._schema + def partition_order_mappings(self): + return 
self._partition_order_mappings @property - def columns(self) -> "ColumnObject": - if not self._columns: - self._columns = ColumnObject(self._schema.header) - else: - return self._columns + def data_manager(self) -> "DataManager": + return self._data_manager + + @data_manager.setter + def data_manager(self, data_manager): + self._data_manager = data_manager + + @property + def dtypes(self): + return self._data_manager.dtypes + + def as_tensor(self, dtype=None): + """ + df.weight.as_tensor() + df.label.as_tensor() + df.values.as_tensor() + """ + from .ops._transformer import transform_to_tensor + + return transform_to_tensor( + self._block_table, self._data_manager, dtype, partition_order_mappings=self.partition_order_mappings + ) + + def as_pd_df(self) -> "pd.DataFrame": + from .ops._transformer import transform_to_pandas_dataframe - def max(self, *args, **kwargs) -> "DataFrame": - return stat_method(self._values, "max", *args, index=self._schema.header, **kwargs) + return transform_to_pandas_dataframe(self._block_table, self._data_manager) - def min(self, *args, **kwargs) -> "DataFrame": - return stat_method(self._values, "min", *args, index=self._schema.header, **kwargs) + def apply_row(self, func, columns=None, with_label=False, with_weight=False, enable_type_align_checking=False): + from .ops._apply_row import apply_row + + return apply_row( + self, + func, + columns=columns, + with_label=with_label, + with_weight=with_weight, + enable_type_align_checking=enable_type_align_checking, + ) + + def create_frame(self, with_label=False, with_weight=False, columns: Union[list, pd.Index] = None) -> "DataFrame": + if columns is not None and isinstance(columns, pd.Index): + columns = columns.tolist() + + return self.__extract_fields( + with_sample_id=True, with_match_id=True, with_label=with_label, with_weight=with_weight, columns=columns + ) + + def empty_frame(self) -> "DataFrame": + return DataFrame( + self._ctx, + self._ctx.computing.parallelize([], include_key=False, partition=self._block_table.partitions), + partition_order_mappings=dict(), + data_manager=self._data_manager.duplicate(), + ) - def mean(self, *args, **kwargs) -> "DataFrame": - return stat_method(self._values, "mean", *args, index=self._schema.header, **kwargs) + def drop(self, index) -> "DataFrame": + from .ops._dimension_scaling import drop - def sum(self, *args, **kwargs) -> "DataFrame": - return stat_method(self._values, "sum", *args, index=self._schema.header, **kwargs) + return drop(self, index) - def std(self, *args, **kwargs) -> "DataFrame": - return stat_method(self._values, "std", *args, index=self._schema.header, **kwargs) + def fillna(self, value): + from .ops._fillna import fillna + + return fillna(self, value) + + def get_dummies(self, dtype="int32"): + from .ops._encoder import get_dummies + + return get_dummies(self, dtype=dtype) + + def isna(self): + from .ops._missing import isna + + return isna(self) + + def isin(self, values): + from .ops._isin import isin + + return isin(self, values) + + def na_count(self): + return self.isna().sum() + + def max(self) -> "pd.Series": + from .ops._stat import max + + return max(self) + + def min(self, *args, **kwargs) -> "pd.Series": + from .ops._stat import min + + return min(self) + + def mean(self, *args, **kwargs) -> "pd.Series": + from .ops._stat import mean + + return mean(self) + + def sum(self, *args, **kwargs) -> "pd.Series": + from .ops._stat import sum + + return sum(self) + + def std(self, ddof=1, **kwargs) -> "pd.Series": + from .ops._stat import std + + 
return std(self, ddof=ddof) + + def var(self, ddof=1, **kwargs): + from .ops._stat import var + + return var(self, ddof=ddof) + + def variation(self, ddof=1): + from .ops._stat import variation + + return variation(self, ddof=ddof) + + def skew(self, unbiased=False): + from .ops._stat import skew + + return skew(self, unbiased=unbiased) + + def kurt(self, unbiased=False): + from .ops._stat import kurt + + return kurt(self, unbiased=unbiased) + + def sigmoid(self) -> "DataFrame": + from .ops._activation import sigmoid + + return sigmoid(self) + + def rename( + self, + sample_id_name: str = None, + match_id_name: str = None, + label_name: str = None, + weight_name: str = None, + columns: dict = None, + ): + self._data_manager.rename( + sample_id_name=sample_id_name, + match_id_name=match_id_name, + label_name=label_name, + weight_name=weight_name, + columns=columns, + ) def count(self) -> "int": return self.shape[0] - def __add__(self, other) -> "DataFrame": - return self._arithmetic_operate(operator.add, other) + def describe(self, ddof=1, unbiased=False): + from .ops._stat import describe + + return describe(self, ddof=ddof, unbiased=unbiased) + + def quantile(self, q, relative_error: float = 1e-4): + from .ops._quantile import quantile - def __sub__(self, other) -> "DataFrame": - return self._arithmetic_operate(operator.sub, other) + return quantile(self, q, relative_error) + + def qcut(self, q: int): + from .ops._quantile import qcut + + return qcut(self, q) + + def bucketize(self, boundaries: Union[dict, pd.DataFrame]) -> "DataFrame": + from .ops._encoder import bucketize + + return bucketize(self, boundaries) + + def distributed_hist_stat(self, + histogram_builder: "HistogramBuilder", + position: "DataFrame" = None, + targets: Union[dict, "DataFrame"] = None, + ) -> "DistributedHistogram": + from .ops._histogram import distributed_hist_stat + + if targets is None: + raise ValueError("To use distributed hist stat, targets should not be None") + if position is None: + position = self.create_frame() + position["node_idx"] = 0 + + return distributed_hist_stat(self, histogram_builder, position, targets) + + def replace(self, to_replace=None) -> "DataFrame": + from .ops._replace import replace + + return replace(self, to_replace) + + def __add__(self, other: Union[int, float, list, "np.ndarray", "DataFrame", "pd.Series"]) -> "DataFrame": + return self.__arithmetic_operate(operator.add, other) + + def __radd__(self, other: Union[int, float, list, "np.ndarray", "pd.Series"]) -> "DataFrame": + return self + other + + def __sub__(self, other: Union[int, float, list, "np.ndarray", "pd.Series"]) -> "DataFrame": + return self.__arithmetic_operate(operator.sub, other) + + def __rsub__(self, other: Union[int, float, list, "np.ndarray", "pd.Series"]) -> "DataFrame": + return self * (-1) + other def __mul__(self, other) -> "DataFrame": - return self._arithmetic_operate(operator.mul, other) + return self.__arithmetic_operate(operator.mul, other) + + def __rmul__(self, other) -> "DataFrame": + return self * other def __truediv__(self, other) -> "DataFrame": - return self._arithmetic_operate(operator.truediv, other) + return self.__arithmetic_operate(operator.truediv, other) + + def __pow__(self, power) -> "DataFrame": + return self.__arithmetic_operate(operator.pow, power) + + def __lt__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.lt, other) + + def __le__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.le, other) - def _arithmetic_operate(self, op, other) -> 
"DataFrame": - ret_value = arith_method(self._values, other, op) - attrs_dict = self._retrieval_attr() - attrs_dict["values"] = ret_value - return DataFrame(**attrs_dict) + def __gt__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.gt, other) - def __getattr__(self, attr): - if attr not in self.schema.header: - raise ValueError(f"DataFrame does not has attribute {attr}") + def __ge__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.ge, other) - if isinstance(self._values, ValueStore): - value = getattr(self._values, attr) + def __eq__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.eq, other) + + def __ne__(self, other) -> "DataFrame": + return self.__cmp_operate(operator.ne, other) + + def __invert__(self): + from .ops._unary_operator import invert + + return invert(self) + + def __arithmetic_operate(self, op, other) -> "DataFrame": + from .ops._arithmetic import arith_operate + + return arith_operate(self, other, op) + + def __cmp_operate(self, op, other) -> "DataFrame": + from .ops._cmp import cmp_operate + + return cmp_operate(self, other, op) + + def __setattr__(self, key, value): + property_attr_mapping = dict(block_table="_block_table", data_manager="_data_manager") + if key not in ["label", "weight"] and key not in property_attr_mapping: + self.__dict__[key] = value + return + + if key in property_attr_mapping: + self.__dict__[property_attr_mapping[key]] = value + return + + if key == "label": + if self._label is not None: + self.__dict__["_label"] = None + from .ops._set_item import set_label_or_weight + + set_label_or_weight(self, value, key_type=key) else: - col_idx = self.schema.header.index(attr) - value = self._values[:, col_idx] + if self._weight is not None: + self.__dict__["_weight"] = None + from .ops._set_item import set_label_or_weight - schema = dict(sid=self.schema.sid, header=[attr]) + set_label_or_weight(self, value, key_type=key) - return DataFrame(self._ctx, schema=schema, values=value) + def __getitem__(self, items) -> "DataFrame": + if isinstance(items, DataFrame): + from .ops._where import where - def __getitem__(self, items): - indexes = self.__get_index_by_column_names(items) - ret_tensor = self._values[:, indexes] + return where(self, items) - header_mapping = dict(zip(self._schema.header, range(len(self._schema.header)))) - new_schema = copy.deepcopy(self._schema) - new_header = items if isinstance(items, list) else [items] - new_anonymous_header = [] + if isinstance(items, DTensor): + from .ops._dimension_scaling import retrieval_row + + return retrieval_row(self, items) + + if isinstance(items, pd.Index): + items = items.tolist() + elif not isinstance(items, list): + items = [items] for item in items: - index = header_mapping[item] - new_anonymous_header.append(self._schema.anonymous_header[index]) + if item not in self._data_manager.schema.columns: + raise ValueError(f"DataFrame does not has attribute {item}") - new_schema["header"] = new_header - new_schema["anonymous__header"] = new_anonymous_header + return self.__extract_fields(with_sample_id=True, with_match_id=True, columns=items) - return DataFrame( - self._ctx, index=self._index, values=ret_tensor, label=self._label, weight=self._weight, schema=new_schema - ) + def __setitem__(self, keys, items): + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, pd.Series): + keys = keys.tolist() + + state = 0 + column_set = set(self._data_manager.schema.columns) + for key in keys: + if key not in column_set: + state |= 1 + else: + state |= 2 
- def __setitem__(self, keys, item): - if not isinstance(item, DataFrame): - raise ValueError("Using syntax df[[col1, col2...]] = rhs, rhs should be a dataframe") + if state == 3: + raise ValueError(f"setitem operation does not support a mix of old and new columns") - indexes = self.__get_index_by_column_names(keys) - self._values[:, indexes] = item._values + from .ops._set_item import set_item - return self + set_item(self, keys, items, state) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state_dict): + self.__dict__.update(state_dict) def __len__(self): return self.count() @@ -187,196 +478,106 @@ def __get_index_by_column_names(self, column_names): return indexes - def loc(self, ids, with_partition_id=True): - # this is very costly, use iloc is better - # TODO: if data is not balance, repartition is need? - if isinstance(ids, int): - ids = [ids] - - indexes = self._index.get_indexer(ids, with_partition_id) - - return self.iloc(indexes) - - def iloc(self, indexes): - # TODO: if data is not balance, repartition is need? - if self.is_local: - if is_table(indexes): - raise ValueError("Local dataframe does not support table indexer") - # indexes = indexes.reduce(lambda l1, l2: l1 + l2) - - weight = self._weight[indexes] if self._weight else None - label = self._label[indexes] if self._label else None - values = self._values[indexes] if self._values else None - match_id = self._match_id[indexes] if self._match_id else None - index = self._index[indexes] - elif isinstance(indexes, (int, list)) or is_table(indexes): - if isinstance(indexes, int): - indexes = [indexes] - - """ - indexer: [(old_partition_id, old_block_index), (new_partition_id, new_block_index)] - note: new_block_index may not be continuous - """ - if isinstance(indexes, list): - indexes = self._index.change_index_list_to_indexer(indexes) - """ - agg_indexer: key=old_partition_id, value=[old_block_index, (new_partition_id, new_block_index)] - """ - agg_indexer = Index.aggregate_indexer(indexes) - - # TODO: use distributed tensor slice api later - def _iloc_tensor(distributed_tensor): - blocks = distributed_tensor.storage.blocks - dtype = blocks.first()[1].dtype.name - - def _retrieval_func(kvs): - ret = dict() - for partition_id_key, (t, mappings) in kvs: - t = t.to_local().data.tolist() - for old_block_index, (new_partition_id, new_block_index) in mappings: - t_value = t[old_block_index] - - if new_partition_id not in ret: - ret[new_partition_id] = [] - ret[new_partition_id].append((new_block_index, t_value)) - - return list(ret.items()) - - blocks = blocks.join(agg_indexer, lambda ten, block_mapping: (ten, block_mapping)) - blocks = blocks.mapReducePartitions(_retrieval_func, lambda l1, l2: l1 + l2) - blocks = blocks.mapValues(lambda block: sorted(block, key=lambda buf: buf[0])) - blocks = blocks.mapValues( - lambda block: torch.tensor([value[1] for value in block], dtype=getattr(torch, dtype)) - ) - blocks = [block for pid, block in sorted(list(blocks.collect()))] - - from fate.arch import tensor - - return tensor.distributed_tensor(self._ctx, blocks, partitions=len(blocks)) - - weight = _iloc_tensor(self._weight) if self._weight else None - label = _iloc_tensor(self._label) if self._label else None - values = _iloc_tensor(self._values) if self._values else None - match_id = _iloc_tensor(self._match_id) if self._match_id else None - index = self._index[indexes] - else: - raise ValueError(f"iloc function dose not support args type={type(indexes)}") + def get_indexer(self, target): + if target not 
in ["sample_id", "match_id"]: + raise ValueError(f"Target should be sample_id or match_id, but {target} found") - return DataFrame( - self._ctx, self._schema.dict(), index=index, match_id=match_id, label=label, weight=weight, values=values - ) - - - @property - def is_local(self): - if self._values is not None: - return not self._values.is_distributed - if self._weight is not None: - return not self._weight.is_distributed - if self.label is not None: - return not self._label.is_distributed - if self._match_id is not None: - return not self._match_id.is_distributed - - return False - - def transform_to_predict_result( - self, predict_score, data_type="train", task_type="binary", classes=None, threshold=0.5 - ): - """ """ + if self.shape[0] == 0: + return self._ctx.computing.parallelize([], include_key=False, partition=self._block_table.partitions) - ret, header = transform_to_predict_result( - self._ctx, predict_score, data_type=data_type, task_type=task_type, classes=classes, threshold=threshold - ) - - transform_schema = {"header": header, "sid": self._schema.sid} - if self._schema.match_id_name: - transform_schema["match_id_name"] = self._schema.match_id_name - - if self._label: - transform_schema["label_name"] = self.schema.label_name - - return DataFrame( - ctx=self._ctx, - index=self._index, - match_id=self._match_id, - label=self.label, - values=ValueStore(self._ctx, ret, header), - schema=transform_schema, - ) - - -class ColumnObject(object): - def __init__(self, col_names): - self._col_names = col_names - - def __getitem__(self, items): - if isinstance(items, int): - return self._col_names[items] + target_name = getattr(self.schema, f"{target}_name") + indexer = self.__convert_to_table(target_name) + if target == "sample_id": + self._sample_id_indexer = indexer else: - ret_cols = [] - for item in items: - ret_cols.append(self._col_names[item]) - - return ColumnObject(ret_cols) + self._match_id_indexer = indexer - def tolist(self): - return self._col_names + return indexer - def __iter__(self): - return (col_name for col_name in self._col_names) + def loc(self, indexer, target="sample_id", preserve_order=False): + from .ops._indexer import loc + return loc(self, indexer, target=target, preserve_order=preserve_order) -class Schema(object): - def __init__( - self, sid=None, match_id_name=None, weight_name=None, label_name=None, header=None, anonymous_header=None - ): - self._sid = sid - self._match_id_name = match_id_name - self._weight_name = weight_name - self._label_name = label_name - self._header = header - self._anonymous_header = anonymous_header - - @property - def sid(self): - return self._sid - - @property - def match_id_name(self): - return self._match_id_name + def iloc(self, indexer: "DataFrame") -> "DataFrame": + from .ops._dimension_scaling import retrieval_row + return retrieval_row(self, indexer) - @property - def weight_name(self): - return self._weight_name + def loc_with_sample_id_replacement(self, indexer): + """ + indexer: table, + row: (key=random_key, + value=(sample_id, (src_block_id, src_block_offset)) + """ + from .ops._indexer import loc_with_sample_id_replacement - @property - def label_name(self): - return self._label_name + return loc_with_sample_id_replacement(self, indexer) - @property - def header(self): - return self._header + def flatten(self, key_type="block_id", with_sample_id=True): + """ + flatten data_frame + """ + from .ops._indexer import flatten_data + return flatten_data(self, key_type=key_type, with_sample_id=with_sample_id) - @property - 
def anonymous_header(self): - return self._anonymous_header + def copy(self) -> "DataFrame": + return DataFrame( + self._ctx, + self._block_table.mapValues(lambda v: v), + copy.deepcopy(self.partition_order_mappings), + self._data_manager.duplicate(), + ) - def dict(self): - schema = dict(sid=self._sid) + @classmethod + def from_flatten_data(cls, ctx, flatten_table, data_manager, key_type) -> "DataFrame": + from .ops._indexer import transform_flatten_data_to_df + return transform_flatten_data_to_df(ctx, flatten_table, data_manager, key_type) + + @classmethod + def hstack(cls, stacks: List["DataFrame"]) -> "DataFrame": + from .ops._dimension_scaling import hstack + + return hstack(stacks) + + @classmethod + def vstack(cls, stacks: List["DataFrame"]) -> "DataFrame": + from .ops._dimension_scaling import vstack + + return vstack(stacks) + + def sample(self, n: int = None, frac: float = None, random_state=None) -> "DataFrame": + from .ops._dimension_scaling import sample + + return sample(self, n, frac, random_state) + + def __extract_fields( + self, + with_sample_id=True, + with_match_id=True, + with_label=True, + with_weight=True, + columns: Union[str, list] = None, + ) -> "DataFrame": + from .ops._field_extract import field_extract + + return field_extract( + self, + with_sample_id=with_sample_id, + with_match_id=with_match_id, + with_label=with_label, + with_weight=with_weight, + columns=columns, + ) - if self._header: - schema["header"] = self._header - if self._anonymous_header: - schema["anonymous_header"] = self._anonymous_header + def __convert_to_table(self, target_name): + block_loc = self._data_manager.loc_block(target_name) + assert block_loc[1] == 0, "support only one indexer in current version" - if self._weight_name: - schema["weight_name"] = self._weight_name + from .ops._indexer import transform_to_table - if self._label_name: - schema["label_name"] = self._label_name + return transform_to_table(self._block_table, block_loc[0], self._partition_order_mappings) - if self._match_id_name: - schema["match_id_name"] = self._match_id_name + def data_overview(self, num=100): + from .ops._data_overview import collect_data - return schema + return collect_data(self, num=num) diff --git a/python/fate/arch/dataframe/_frame_reader.py b/python/fate/arch/dataframe/_frame_reader.py index ec70641406..07298e6334 100644 --- a/python/fate/arch/dataframe/_frame_reader.py +++ b/python/fate/arch/dataframe/_frame_reader.py @@ -13,33 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
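Taken together, the rewritten DataFrame above is block-based: sample_id, match_id, label, weight and column views are extracted lazily, statistics return pandas Series, and arithmetic/comparison operators build new DataFrames. A usage sketch against the methods defined in this diff (it assumes `df` and `df2` were produced by one of the readers in _frame_reader.py below; the column names are illustrative):

means = df.mean()                       # pandas Series indexed by column name
ids_and_label = df.create_frame(with_label=True)
sub = df[["x0", "x1"]]                  # column selection returns a DataFrame view
shifted = sub + 1.0                     # element-wise arithmetic with a scalar
mask = shifted >= 1.5                   # comparisons also return DataFrames
stacked = DataFrame.vstack([df, df2])   # row-wise concatenation
sampled = stacked.sample(frac=0.1, random_state=42)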
import functools -import typing - -import numpy as np import pandas as pd -import torch -from fate.arch import tensor +from typing import Union + +from .conf.default_config import DATAFRAME_BLOCK_ROW_SIZE +from .entity import types from ._dataframe import DataFrame -from .storage import Index +from .manager import DataManager -class RawTableReader(object): +class TableReader(object): def __init__( self, - delimiter: str = ",", - label_name: typing.Union[None, str] = None, + sample_id_name: str =None, + match_id_name: str = None, + match_id_list: list = None, + match_id_range: int = 0, + label_name: Union[None, str] = None, label_type: str = "int", - weight_name: typing.Union[None, str] = None, - dtype: str = "float32", + weight_name: Union[None, str] = None, + weight_type: str = "float32", + header: str = None, + delimiter: str = ",", + dtype: Union[str, dict] = "float32", + anonymous_site_name: str = None, + na_values: Union[str, list, dict] = None, input_format: str = "dense", + tag_with_value: bool = False, + tag_value_delimiter: str = ":", + block_row_size: int = None ): - self._delimiter = delimiter + self._sample_id_name = sample_id_name + self._match_id_name = match_id_name + self._match_id_list = match_id_list + self._match_id_range = match_id_range self._label_name = label_name self._label_type = label_type self._weight_name = weight_name + self._weight_type = weight_type + self._delimiter = delimiter + self._header = header self._dtype = dtype + self._anonymous_site_name = anonymous_site_name + self._na_values = na_values self._input_format = input_format + self._tag_with_value = tag_with_value + self._tag_value_delimiter = tag_value_delimiter + self._block_row_size = block_row_size if block_row_size is not None else DATAFRAME_BLOCK_ROW_SIZE + + self.check_params() + + def check_params(self): + if not self._sample_id_name: + raise ValueError("Please provide sample_id_name") + + if not isinstance(self._block_row_size, int) or self._block_row_size < 0: + raise ValueError("block_row_size should be positive integer") def to_frame(self, ctx, table): if self._input_format != "dense": @@ -48,57 +78,32 @@ def to_frame(self, ctx, table): return self._dense_format_to_frame(ctx, table) def _dense_format_to_frame(self, ctx, table): - schema = dict() - schema["sid"] = table.schema["sid"] - header = table.schema["header"].split(self._delimiter, -1) - + data_manager = DataManager(block_row_size=self._block_row_size) + columns = self._header.split(self._delimiter, -1) + columns.remove(self._sample_id_name) + retrieval_index_dict = data_manager.init_from_local_file( + sample_id_name=self._sample_id_name, columns=columns, match_id_list=self._match_id_list, + match_id_name=self._match_id_name, label_name=self._label_name, weight_name=self._weight_name, + label_type=self._label_type, weight_type=self._weight_type, + dtype=self._dtype, default_type=types.DEFAULT_DATA_TYPE) + + from .ops._indexer import get_partition_order_by_raw_table + partition_order_mappings = get_partition_order_by_raw_table(table, data_manager.block_row_size) + # partition_order_mappings = _get_partition_order(table) table = table.mapValues(lambda value: value.split(self._delimiter, -1)) - header_indexes = list(range(len(header))) - index_table, _block_partition_mapping, _global_ranks = _convert_to_order_indexes(table) - - data_dict = {} - if self._label_name: - if self._label_name not in header: - raise ValueError("Label name does not exist in header, please have a check") - label_idx = header.index(self._label_name) - 
header.remove(self._label_name) - header_indexes.remove(label_idx) - label_type = getattr(np, self._label_type) - label_table = table.mapValues(lambda value: [label_type(value[label_idx])]) - data_dict["label"] = _convert_to_tensor( - ctx, - label_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._label_type), - ) - schema["label_name"] = self._label_name - - if self._weight_name: - if self._weight_name not in header: - raise ValueError("Weight name does not exist in header, please have a check") - - weight_idx = header.index(self._weight_name) - header.remove(self._weight_name) - header_indexes.remove(weight_idx) - weight_table = table.mapValues(lambda value: [value[weight_idx]]) - data_dict["weight"] = _convert_to_tensor( - ctx, weight_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, "float64") - ) - - schema["weight_name"] = self._weight_name - - if header_indexes: - value_table = table.mapValues(lambda value: np.array(value)[header_indexes].astype(self._dtype).tolist()) - data_dict["values"] = _convert_to_tensor( - ctx, value_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, self._dtype) - ) - schema["header"] = header - - data_dict["index"] = _convert_to_index( - ctx, index_table, block_partition_mapping=_block_partition_mapping, global_ranks=_global_ranks + to_block_func = functools.partial(_to_blocks, + data_manager=data_manager, + retrieval_index_dict=retrieval_index_dict, + partition_order_mappings=partition_order_mappings) + block_table = table.mapPartitions( + to_block_func, + use_previous_behavior=False ) - return DataFrame(ctx=ctx, schema=schema, **data_dict) + return DataFrame(ctx=ctx, + block_table=block_table, + partition_order_mappings=partition_order_mappings, + data_manager=data_manager) class ImageReader(object): @@ -115,35 +120,48 @@ def __init__( class CSVReader(object): # TODO: fast data read - # TODO: a. support match_id, b. 
more id type def __init__( self, - id_name: typing.Union[None, str] = None, + sample_id_name: Union[None, str] = None, + match_id_list: Union[None, list] = None, + match_id_name: Union[None, str] = None, delimiter: str = ",", - label_name: typing.Union[None, str] = None, + label_name: Union[None, str] = None, label_type: str = "int", - weight_name: typing.Union[None, str] = None, + weight_name: Union[None, str] = None, + weight_type: str = "float32", dtype: str = "float32", + na_values: Union[None, str, list, dict] = None, partition: int = 4, + block_row_size: int = None ): - self._id_name = id_name + self._sample_id_name = sample_id_name + self._match_id_list = match_id_list + self._match_id_name = match_id_name self._delimiter = delimiter self._label_name = label_name self._label_type = label_type self._weight_name = weight_name + self._weight_type = weight_type self._dtype = dtype + self._na_values = na_values self._partition = partition + self._block_row_size = block_row_size if block_row_size is not None else DATAFRAME_BLOCK_ROW_SIZE def to_frame(self, ctx, path): # TODO: use table put data instead of read all data - df = pd.read_csv(path, delimiter=self._delimiter) + df = pd.read_csv(path, delimiter=self._delimiter, na_values=self._na_values) return PandasReader( - id_name=self._id_name, + sample_id_name=self._sample_id_name, + match_id_list=self._match_id_list, + match_id_name=self._match_id_name, label_name=self._label_name, label_type=self._label_type, weight_name=self._weight_name, + dtype=self._dtype, partition=self._partition, + block_row_size=self._block_row_size ).to_frame(ctx, df) @@ -173,142 +191,136 @@ def to_frame(self, ctx, dataset): class PandasReader(object): def __init__( self, - id_name: typing.Union[None, str] = None, + sample_id_name: Union[None, str] = None, + match_id_list: Union[None, list] = None, + match_id_name: Union[None, str] = None, label_name: str = None, - label_type: str = "int", - weight_name: typing.Union[None, str] = None, + label_type: str = "int32", + weight_name: Union[None, str] = None, + weight_type: str = "float32", dtype: str = "float32", partition: int = 4, + block_row_size: int = None, ): - self._id_name = id_name + self._sample_id_name = sample_id_name + self._match_id_list = match_id_list + self._match_id_name = match_id_name self._label_name = label_name self._label_type = label_type self._weight_name = weight_name + self._weight_type = weight_type self._dtype = dtype self._partition = partition + self._block_row_size = block_row_size if block_row_size is not None else DATAFRAME_BLOCK_ROW_SIZE - def to_frame(self, ctx, df: "pd.DataFrame"): - schema = dict() - if not self._id_name: - self._id_name = df.columns[0] - df = df.set_index(self._id_name) - - # TODO: need to ensure id's type is str? 
- df.index = df.index.astype("str") - - id_list = df.index.tolist() + if self._sample_id_name and not self._match_id_name: + raise ValueError(f"As sample_id {self._sample_id_name} is given, match_id should be given too") - index_table = ctx.computing.parallelize( - zip(id_list, range(df.shape[0])), include_key=True, partition=self._partition + def to_frame(self, ctx, df: "pd.DataFrame"): + if not self._sample_id_name: + self._sample_id_name = types.DEFAULT_SID_NAME + df.index.name = self._sample_id_name + else: + df = df.set_index(self._sample_id_name) + + data_manager = DataManager(block_row_size=self._block_row_size) + retrieval_index_dict = data_manager.init_from_local_file( + sample_id_name=self._sample_id_name, columns=df.columns.tolist(), match_id_list=self._match_id_list, + match_id_name=self._match_id_name, label_name=self._label_name, weight_name=self._weight_name, + label_type=self._label_type, weight_type=self._weight_type, + dtype=self._dtype, default_type=types.DEFAULT_DATA_TYPE) + + site_name = ctx.local.name + local_role = ctx.local.party[0] + + if local_role != "local": + data_manager.fill_anonymous_site_name(site_name=site_name) + + buf = zip(df.index.tolist(), df.values.tolist()) + table = ctx.computing.parallelize( + buf, include_key=True, partition=self._partition ) - index_table, _block_partition_mapping, _global_ranks = _convert_to_order_indexes(index_table) - - data_dict = {} - if self._label_name: - label_list = [[label] for label in df[self._label_name].tolist()] - label_table = ctx.computing.parallelize( - zip(id_list, label_list), include_key=True, partition=self._partition - ) - data_dict["label"] = _convert_to_tensor( - ctx, - label_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._label_type), - ) - df = df.drop(columns=self._label_name) - schema["label_name"] = self._label_name - - if self._weight_name: - weight_list = df[self._weight_name].tolist() - weight_table = ctx.computing.parallelize( - zip(id_list, weight_list), include_key=True, partition=self._partition - ) - data_dict["weight"] = _convert_to_tensor( - ctx, weight_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, "float64") - ) - - df = df.drop(columns=self._weight_name) - schema["weight_name"] = self._weight_name - - if df.shape[1]: - value_table = ctx.computing.parallelize( - zip(id_list, df.values), include_key=True, partition=self._partition - ) - data_dict["values"] = _convert_to_tensor( - ctx, value_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, self._dtype) - ) - schema["header"] = df.columns.to_list() - - data_dict["index"] = _convert_to_index( - ctx, index_table, block_partition_mapping=_block_partition_mapping, global_ranks=_global_ranks + from .ops._indexer import get_partition_order_by_raw_table + partition_order_mappings = get_partition_order_by_raw_table(table, data_manager.block_row_size) + # partition_order_mappings = _get_partition_order(table) + to_block_func = functools.partial(_to_blocks, + data_manager=data_manager, + retrieval_index_dict=retrieval_index_dict, + partition_order_mappings=partition_order_mappings) + + block_table = table.mapPartitions( + to_block_func, + use_previous_behavior=False ) - schema["sid"] = self._id_name - - return DataFrame(ctx=ctx, schema=schema, **data_dict) - - -def _convert_to_order_indexes(table): - def _get_block_summary(kvs): - key = next(kvs)[0] - block_size = 1 + sum(1 for kv in kvs) - return {key: block_size} - - def _order_indexes(kvs, 
rank_dict: dict = None): - bid = None - order_indexes = [] - for idx, (k, v) in enumerate(kvs): - if bid is None: - bid = rank_dict[k]["block_id"] - - order_indexes.append((k, (bid, idx))) - - return order_indexes + return DataFrame(ctx=ctx, + block_table=block_table, + partition_order_mappings=partition_order_mappings, + data_manager=data_manager) - block_summary = table.mapPartitions(_get_block_summary).reduce(lambda blk1, blk2: {**blk1, **blk2}) - start_index, block_id = 0, 0 - block_partition_mapping = dict() - global_ranks = [] - for blk_key, blk_size in block_summary.items(): - block_partition_mapping[blk_key] = dict( - start_index=start_index, end_index=start_index + blk_size - 1, block_id=block_id - ) - global_ranks.append(block_partition_mapping[blk_key]) +def _to_blocks(kvs, + data_manager=None, + retrieval_index_dict=None, + partition_order_mappings=None, + na_values=None): + """ + sample_id/match_id,label(maybe missing),weight(maybe missing),X + """ + block_id = None - start_index += blk_size - block_id += 1 + schema = data_manager.schema - order_func = functools.partial(_order_indexes, rank_dict=block_partition_mapping) - order_table = table.mapPartitions(order_func, use_previous_behavior=False) + splits = [[] for _ in range(data_manager.block_num)] + sample_id_block = data_manager.loc_block(schema.sample_id_name, with_offset=False) if schema.sample_id_name else None - return order_table, block_partition_mapping, global_ranks + match_id_block = data_manager.loc_block(schema.match_id_name, with_offset=False)if schema.match_id_name else None + match_id_column_index = retrieval_index_dict["match_id_index"] + label_block = data_manager.loc_block(schema.label_name, with_offset=False) if schema.label_name else None + label_column_index = retrieval_index_dict["label_index"] -def _convert_to_index(ctx, table, block_partition_mapping, global_ranks): - return Index(ctx, table, block_partition_mapping=block_partition_mapping, global_ranks=global_ranks) + weight_block = data_manager.loc_block(schema.weight_name, with_offset=False) if schema.weight_name else None + weight_column_index = retrieval_index_dict["weight_index"] + column_indexes = retrieval_index_dict["column_indexes"] -def _convert_to_tensor(ctx, table, block_partition_mapping, dtype): - # TODO: in mini-demo stage, distributed tensor only accept list, in future, replace this with distributed table. 
- convert_func = functools.partial(_convert_block, block_partition_mapping=block_partition_mapping, dtype=dtype) - blocks_with_id = list(table.mapPartitions(convert_func, use_previous_behavior=False).collect()) - blocks = [block_with_id[1] for block_with_id in sorted(blocks_with_id)] + columns = schema.columns + column_blocks_mapping = dict() + for col_id, col_name in zip(column_indexes, columns): + bid = data_manager.loc_block(col_name, with_offset=False) + if bid not in column_blocks_mapping: + column_blocks_mapping[bid] = [] - return tensor.distributed_tensor(ctx, blocks, partitions=len(blocks)) + column_blocks_mapping[bid].append(col_id) + block_row_size = data_manager.block_row_size -def _convert_block(kvs, block_partition_mapping, dtype, convert_type="tensor"): - ret = [] - block_id = None + lid = 0 for key, value in kvs: if block_id is None: - block_id = block_partition_mapping[key]["block_id"] - - ret.append(value) - - if convert_type == "tensor": - return [(block_id, torch.tensor(ret, dtype=dtype))] - else: - return [(block_id, pd.Index(ret, dtype=dtype))] + block_id = partition_order_mappings[key]["start_block_id"] + lid += 1 + + # columns = value.split(",", -1) + splits[sample_id_block].append(key) + if match_id_block: + splits[match_id_block].append(value[match_id_column_index]) + if label_block: + splits[label_block].append([value[label_column_index]]) + if weight_block: + splits[weight_block].append([value[weight_column_index]]) + + for bid, col_id_list in column_blocks_mapping.items(): + splits[bid].append([value[col_id] for col_id in col_id_list]) + + if lid % block_row_size == 0: + converted_blocks = data_manager.convert_to_blocks(splits) + yield block_id, converted_blocks + block_id += 1 + splits = [[] for _ in range(data_manager.block_num)] + + if lid % block_row_size: + converted_blocks = data_manager.convert_to_blocks(splits) + yield block_id, converted_blocks diff --git a/rust/tensor/rust_paillier/rust_paillier/par/__init__.py b/python/fate/arch/dataframe/conf/__init__.py similarity index 100% rename from rust/tensor/rust_paillier/rust_paillier/par/__init__.py rename to python/fate/arch/dataframe/conf/__init__.py diff --git a/python/fate/arch/dataframe/conf/default_config.py b/python/fate/arch/dataframe/conf/default_config.py new file mode 100644 index 0000000000..5b7aa0404b --- /dev/null +++ b/python/fate/arch/dataframe/conf/default_config.py @@ -0,0 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
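The _to_blocks generator above packs raw rows into fixed-size blocks, assigning one block id per block_row_size rows starting from the partition's start_block_id. A standalone sketch of the same chunking arithmetic (the default size is the DATAFRAME_BLOCK_ROW_SIZE constant defined just below):

def chunk_rows(rows, start_block_id, block_row_size=2 ** 7):
    """Yield (block_id, rows_in_block), mirroring the row grouping in _to_blocks."""
    block_id, buf = start_block_id, []
    for row in rows:
        buf.append(row)
        if len(buf) == block_row_size:
            yield block_id, buf
            block_id += 1
            buf = []
    if buf:  # tail block shorter than block_row_size
        yield block_id, buf

# e.g. 300 rows with the default size of 128 produce blocks of 128, 128 and 44 rows.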
+# +DATAFRAME_BLOCK_ROW_SIZE = 2**7 +BLOCK_COMPRESS_THRESHOLD = 5 diff --git a/rust/tensor/rust_paillier/.projectile b/python/fate/arch/dataframe/entity/__init__.py similarity index 100% rename from rust/tensor/rust_paillier/.projectile rename to python/fate/arch/dataframe/entity/__init__.py diff --git a/python/fate/arch/dataframe/entity/types.py b/python/fate/arch/dataframe/entity/types.py new file mode 100644 index 0000000000..5d40dca731 --- /dev/null +++ b/python/fate/arch/dataframe/entity/types.py @@ -0,0 +1,2 @@ +DEFAULT_DATA_TYPE = "float32" +DEFAULT_SID_NAME = "sample_id" \ No newline at end of file diff --git a/python/fate/arch/dataframe/io/_json_schema.py b/python/fate/arch/dataframe/io/_json_schema.py index b65c35601b..7e5bb514cf 100644 --- a/python/fate/arch/dataframe/io/_json_schema.py +++ b/python/fate/arch/dataframe/io/_json_schema.py @@ -12,90 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pandas as pd - -from ..storage import ValueStore - -FRAME_SCHEME = "fate.dataframe" +FRAME_SCHEME = "fate.arch.dataframe" def build_schema(data): - fields = [] - schema = data.schema - """ - index, match_id, label, weight, values - """ - fields.append(dict(type="str", name=schema.sid, property="index")) - - if schema.match_id_name is not None: - fields.append(dict(type="str", name=schema.match_id_name, property="match_id")) - - if schema.label_name is not None: - label = data.label - fields.append(dict(type=label.dtype.name, name=schema.label_name, property="label")) - - if schema.weight_name is not None: - weight = data.weight - fields.append(dict(type=weight.dtype.name, name=schema["weight_name"], property="weight")) - - if schema.header is not None: - values = data.values - columns = schema.header - if isinstance(values, ValueStore): - dtypes = values.dtypes - for col_name in columns: - fields.append( - dict( - type=dtypes[col_name].name, - name=col_name, - property="value", - source="fate.dataframe.value_store", - ) - ) - - else: - for col_name in columns: - fields.append(dict(type=values.dtype.name, name=col_name, property="value", source="fate.arch.tensor")) + meta = data.data_manager.serialize() built_schema = dict() - built_schema["fields"] = fields - built_schema["global_ranks"] = data.index.global_ranks - built_schema["block_partition_mapping"] = data.index.block_partition_mapping + built_schema["schema_meta"] = meta + built_schema["partition_order_mappings"] = data.partition_order_mappings built_schema["type"] = FRAME_SCHEME + return built_schema def parse_schema(schema): - if "type" not in schema or schema["type"] != FRAME_SCHEME: + if schema.get("type") != FRAME_SCHEME: raise ValueError(f"deserialize data error, schema type is not {FRAME_SCHEME}") - recovery_schema = dict() - column_info = dict() - fields = schema["fields"] - - for idx, field in enumerate(fields): - if field["property"] == "index": - recovery_schema["sid"] = field["name"] - column_info["index"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) - - elif field["property"] == "match_id": - recovery_schema["match_id_name"] = field["name"] - column_info["match_id"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) - - elif field["property"] == "label": - recovery_schema["label_name"] = field["name"] - column_info["label"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) - - elif field["property"] == "weight": - recovery_schema["weight_name"] = 
field["name"] - column_info["weight"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) - - elif field["property"] == "value": - header = [field["name"] for field in fields[idx:]] - recovery_schema["header"] = header - column_info["values"] = dict( - start_idx=idx, end_idx=idx + len(header) - 1, type=field["type"], source=field["source"] - ) - break + schema_meta = schema["schema_meta"] + partition_order_mappings = schema["partition_order_mappings"] - return recovery_schema, schema["global_ranks"], schema["block_partition_mapping"], column_info + return schema_meta, partition_order_mappings diff --git a/python/fate/arch/dataframe/io/_json_serialization.py b/python/fate/arch/dataframe/io/_json_serialization.py index 77ba44ce60..461f3ca5b8 100644 --- a/python/fate/arch/dataframe/io/_json_serialization.py +++ b/python/fate/arch/dataframe/io/_json_serialization.py @@ -12,152 +12,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools - -import numpy as np -import pandas as pd -import torch -from fate.arch import tensor -from fate.arch.context.io.data import df from .._dataframe import DataFrame -from ..storage import Index, ValueStore +from ..manager import DataManager from ._json_schema import build_schema, parse_schema -def _serialize_distributed(ctx, data): +def _serialize(ctx, data): """ index, match_id, label, weight, values """ - # TODO: tensor does not provide method to get raw values directly, so we use .storages.blocks first schema = build_schema(data) - tensors = [data.label, data.weight] - tensor_concat = None - for t in tensors: - if not t: - continue - - """ - distributed tensor - """ - t = t.storage.blocks - if tensor_concat is None: - tensor_concat = t - else: - tensor_concat = tensor_concat.join(tensor, lambda t1, t2: torch.concat([t1, t2], -1)) - - if data.values is not None: - if isinstance(data.values, ValueStore): - value_concat = data.values.values - if tensor_concat is not None: - value_concat = tensor_concat.join( - value_concat, lambda t1, t2: np.concatenate([t1.to_local().data.numpy(), t2.to_numpy()], axis=-1) - ) - else: - value_concat = data.values.storage.blocks.mapValues(lambda t: t.to_local().data) - if tensor_concat is not None: - value_concat = tensor_concat.join( - value_concat, lambda t1, t2: np.concatenate([t1.to_local().data.numpy(), t2.numpy()], axis=-1) - ) - - else: - value_concat = tensor_concat - if value_concat is not None: - value_concat = value_concat.mapValues(lambda t: t.to_local.data.numpy()) - - tensor_concat = value_concat - - index = Index.aggregate(data.index.values) - if tensor_concat is None: - """ - data only has index - """ - serialize_data = index - else: - - def _flatten(index: list, t): - flatten_ret = [] - for (_id, block_index), _t in zip(index, t): - flatten_ret.append([_id] + _t.tolist()) - - return flatten_ret - - serialize_data = index.join(tensor_concat, _flatten) + from ..ops._transformer import transform_block_table_to_list + serialize_data = transform_block_table_to_list(data.block_table, data.data_manager) serialize_data.schema = schema return serialize_data def serialize(ctx, data): - if isinstance(data, df.Dataframe): - data = data.data - - return _serialize_distributed(ctx, data) + return _serialize(ctx, data) def deserialize(ctx, data): - recovery_schema, global_ranks, block_partition_mapping, column_info = parse_schema(data.schema) - - def _recovery_index(kvs): - """ - 
TODO: index should provider deserialize method, implement it here for convenient - """ - start_index = column_info["index"]["start_idx"] - indexes = [] - for key, values in kvs: - for offset, v in enumerate(values): - indexes.append((v[start_index], (key, offset))) - - return indexes - - def _recovery_tensor(value, tensor_info=None): - start_index = tensor_info["start_idx"] - end_index = tensor_info["end_idx"] - dtype = tensor_info["type"] - - ret_tensor = [] - for v in value: - ret_tensor.append(v[start_index : end_index + 1]) - - return torch.tensor(ret_tensor, dtype=getattr(torch, dtype)) - - def _recovery_distributed_value_store(value, value_info, header): - start_index = value_info["start_idx"] - end_index = value_info["end_idx"] - - filter_value = [] - for v in value: - filter_value.append(v[start_index : end_index + 1]) - - df = pd.DataFrame(filter_value, columns=header) - - return df + schema_meta, partition_order_mappings = parse_schema(data.schema) - def _to_distributed_tensor(tensor_list): - return tensor.distributed_tensor(ctx, tensor_list, partitions=len(tensor_list)) + data_manager = DataManager.deserialize(schema_meta) - ret_dict = dict() - ret_dict["index"] = Index( - ctx=ctx, - distributed_index=data.mapPartitions(_recovery_index, use_previous_behavior=False), - block_partition_mapping=block_partition_mapping, - global_ranks=global_ranks, - ) + site_name = ctx.local.name + data_manager.fill_anonymous_site_name(site_name) - tensor_keywords = ["weight", "label", "values"] - for keyword in tensor_keywords: - if keyword in column_info: - if keyword == "values" and column_info["values"]["source"] == "fate.dataframe.value_store": - continue - _recovery_func = functools.partial(_recovery_tensor, tensor_info=column_info[keyword]) - tensors = [tensor for key, tensor in sorted(list(data.mapValues(_recovery_func).collect()))] - ret_dict[keyword] = _to_distributed_tensor(tensors) + from ..ops._transformer import transform_list_to_block_table - if "values" in column_info and column_info["values"]["source"] == "fate.dataframe.value_store": - _recovery_df_func = functools.partial( - _recovery_distributed_value_store, value_info=column_info["values"], header=recovery_schema["header"] - ) - ret_dict["values"] = ValueStore(ctx, data.mapValues(_recovery_df_func), recovery_schema["header"]) + block_table = transform_list_to_block_table(data, data_manager) - return DataFrame(ctx, recovery_schema, **ret_dict) + return DataFrame(ctx, block_table, partition_order_mappings, data_manager) diff --git a/python/fate/arch/dataframe/manager/__init__.py b/python/fate/arch/dataframe/manager/__init__.py new file mode 100644 index 0000000000..1f829be469 --- /dev/null +++ b/python/fate/arch/dataframe/manager/__init__.py @@ -0,0 +1,3 @@ +from .schema_manager import Schema +from .data_manager import DataManager +from .block_manager import BlockType, Block diff --git a/python/fate/arch/dataframe/manager/block_manager.py b/python/fate/arch/dataframe/manager/block_manager.py new file mode 100644 index 0000000000..f10be7ec3d --- /dev/null +++ b/python/fate/arch/dataframe/manager/block_manager.py @@ -0,0 +1,802 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import bisect +import copy +import json +from enum import Enum +from typing import Dict, List, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from fate.arch.tensor.phe._tensor import PHETensor + +from .schema_manager import SchemaManager + + +class BlockType(str, Enum): + int32 = "int32" + int64 = "int64" + float32 = "float32" + float64 = "float64" + bool = "bool" + index = "index" + phe_tensor = "phe_tensor" + np_object = "np_object" + + @staticmethod + def promote_types(l_type: "BlockType", r_type: "BlockType"): + if l_type < r_type: + return r_type + else: + return l_type + + def __lt__(self, other): + if self == other: + return False + + if self == BlockType.bool: + return other != BlockType.bool + + if self == BlockType.index: + raise ValueError("Can not compare index types") + + if self == BlockType.np_object: + return False + + if other == BlockType.np_object: + return True + + if self == BlockType.int32: + return other not in [BlockType.bool, BlockType.int32, BlockType] + + if self == BlockType.int64: + return other not in [BlockType.bool, BlockType.int32, BlockType.int64] + + if self == BlockType.float32: + return other in [BlockType.float64, BlockType.phe_tensor, BlockType.np_object] + + if self == BlockType.float64: + return other in [BlockType.phe_tensor, BlockType.np_object] + + return False + + def __gt__(self, other): + if self == other: + return False + + return other < self + + @staticmethod + def get_block_type(data_type): + if isinstance(data_type, PHETensor) or type(data_type) == PHETensor: + return BlockType.phe_tensor + if hasattr(data_type, "dtype"): + data_type = data_type.dtype + if hasattr(data_type, "name"): + data_type = data_type.name + if isinstance(data_type, str): + try: + data_type = BlockType(data_type) + except ValueError: + data_type = "np_object" + return BlockType(data_type) + elif isinstance(data_type, (bool, np.bool_)) or data_type == torch.bool: + return BlockType.bool + elif isinstance(data_type, np.int64) or data_type == torch.int64: + return BlockType.int64 + elif isinstance(data_type, (int, np.int32)) or data_type == torch.int32: + return BlockType.int32 + elif isinstance(data_type, np.float64) or data_type == torch.float64: + return BlockType.float64 + elif isinstance(data_type, (float, np.float32)) or data_type == torch.float32: + return BlockType.float32 + else: + return BlockType.np_object + + @staticmethod + def is_tensor(block_type): + return block_type in [BlockType.bool, BlockType.int32, BlockType.int64, BlockType.float32, BlockType.float64] + + @staticmethod + def is_float(block_type): + return block_type in [BlockType.float32, BlockType.float64] + + @staticmethod + def is_integer(block_type): + return block_type in [BlockType.int32, BlockType.int64] + + @staticmethod + def is_arr(block_value): + if isinstance(block_value, (torch.Tensor, np.ndarray)) and block_value.shape: + return True + return isinstance(block_value, list) + + +class Block(object): + def __init__(self, field_indexes, block_type=None, should_compress=True): + self._block_type = block_type + self._field_indexes = field_indexes + 
self._should_compress = should_compress + + self._field_index_mapping = dict(zip(field_indexes, range(len(field_indexes)))) + + @property + def block_type(self): + return self._block_type + + @property + def field_indexes(self): + return self._field_indexes + + @field_indexes.setter + def field_indexes(self, field_indexes: Union[list, set]): + self._field_indexes = field_indexes + self._field_index_mapping = dict(zip(field_indexes, range(len(field_indexes)))) + + @property + def should_compress(self): + return self._should_compress + + @should_compress.setter + def should_compress(self, should_compress): + self._should_compress = should_compress + + @property + def is_single(self): + return len(self._field_indexes) == 1 + + def get_field_offset(self, idx): + return self._field_index_mapping[idx] + + def reset_field_indexes(self, dst_field_indexes): + field_indexes = [dst_field_indexes[src_field_index] for src_field_index in self._field_indexes] + self._field_index_mapping = dict(zip(field_indexes, range(len(field_indexes)))) + self._field_indexes = field_indexes + + def derive_block(self, field_indexes) -> Tuple["Block", bool, list]: + """ + assume that sub field indexes always in self._field_indexes + + return: BlockObject, RetrievalIndexInOriBlock: list + """ + src_field_indexes, dst_field_indexes = [], [] + field_indexes = sorted(field_indexes, key=lambda v: v[1]) + for src_field_index, dst_field_index in field_indexes: + src_field_indexes.append(src_field_index) + dst_field_indexes.append(dst_field_index) + + new_block = copy.deepcopy(self) + new_block.field_indexes = dst_field_indexes + # new_block = type(self)(dst_field_indexes) + new_block.should_compress = self._should_compress + + # TODO: can be optimize as sub_field_indexes is ordered, but this is not a bottle neck + changed = True + if len(src_field_indexes) == len(self._field_indexes): + is_monotonous = True + for i in range(1, len(src_field_indexes)): + if src_field_indexes[i] < src_field_indexes[i - 1]: + is_monotonous = False + + if is_monotonous: + retrieval_indexes = [i for i in range(len(self._field_indexes))] + changed = False + else: + retrieval_indexes = [bisect.bisect_left(self._field_indexes, col) for col in src_field_indexes] + else: + retrieval_indexes = [bisect.bisect_left(self._field_indexes, col) for col in src_field_indexes] + + return new_block, changed, retrieval_indexes + + def __str__(self): + field_indexes_format = ",".join(map(str, self._field_indexes)) + return f"block_type:{self._block_type}, fields=={field_indexes_format}" + + def is_numeric(self): + return self._block_type in {BlockType.int32, BlockType.int64, BlockType.float32, BlockType.float64} + + def is_phe_tensor(self): + return self._block_type == BlockType.phe_tensor + + def to_dict(self): + return dict( + block_type=json.dumps(self._block_type), + field_indexes=self._field_indexes, + should_compress=self._should_compress, + ) + + @staticmethod + def from_dict(s_dict): + block_type = json.loads(s_dict["block_type"]) + field_indexes = s_dict["field_indexes"] + should_compress = s_dict["should_compress"] + block = Block.get_block_by_type(block_type) + return block(field_indexes, should_compress=should_compress) + + @staticmethod + def get_block_by_type(block_type): + if not isinstance(block_type, BlockType): + block_type = BlockType.get_block_type(block_type) + + if block_type == block_type.int32: + return Int32Block + elif block_type == block_type.int64: + return Int64Block + elif block_type == block_type.float32: + return Float32Block + elif 
block_type == block_type.float64: + return Float64Block + elif block_type == block_type.bool: + return BoolBlock + elif block_type == block_type.index: + return IndexBlock + elif block_type == block_type.phe_tensor: + return PHETensorBlock + else: + return NPObjectBlock + + @staticmethod + def convert_block(block): + raise NotImplemented + + def convert_block_type(self, block_type): + converted_block = self.get_block_by_type(block_type)(self._field_indexes, block_type, self._should_compress) + + return converted_block + + # @classmethod + # def retrieval_row(cls, block, indexes): + # if isinstance(block, CiphertextVector): + # return block.slice_indexes(indexes) + # elif isinstance(block, pd.Index): + # if isinstance(indexes, list): + # return block[indexes] + # else: + # return pd.Index(block[indexes]) + # else: + # return block[indexes] + + @classmethod + def transform_block_to_list(cls, block): + return block.tolist() + + # @classmethod + # def transform_row_to_raw(cls, block, index): + # if isinstance(block, pd.Index): + # return block[index] + # elif isinstance(block, CiphertextVector): + # return block.slice_indexes([index]) + # else: + # return block[index].tolist() + + @classmethod + def vstack(cls, blocks): + ret = blocks[0] + if isinstance(ret, pd.Index): + for block in blocks[1:]: + ret = ret.append(block) + elif isinstance(ret, torch.Tensor): + ret = torch.vstack(blocks) + elif isinstance(ret, np.ndarray): + ret = np.vstack(blocks) + else: + raise ValueError(f"Not implemented block vstack for type {type(ret)}") + + return ret + + +class Int32Block(Block): + def __init__(self, *args, **kwargs): + super(Int32Block, self).__init__(*args, **kwargs) + self._block_type = BlockType.int32 + + @staticmethod + def convert_block(block): + if isinstance(block, torch.Tensor): + if block.dtype == torch.int32: + return block + else: + return block.to(torch.int32) + try: + return torch.tensor(block, dtype=torch.int32) + except ValueError: + return torch.tensor(np.array(block, dtype="int32"), dtype=torch.int32) + + @property + def dtype(self): + return torch.int32 + + +class Int64Block(Block): + def __init__(self, *args, **kwargs): + super(Int64Block, self).__init__(*args, **kwargs) + self._block_type = BlockType.int64 + + @staticmethod + def convert_block(block): + if isinstance(block, torch.Tensor): + if block.dtype == torch.int64: + return block + else: + return block.to(torch.int64) + try: + return torch.tensor(block, dtype=torch.int64) + except ValueError: + return torch.tensor(np.array(block, dtype="int64"), dtype=torch.int64) + + @property + def dtype(self): + return torch.int64 + + +class Float32Block(Block): + def __init__(self, *args, **kwargs): + super(Float32Block, self).__init__(*args, **kwargs) + self._block_type = BlockType.float32 + + @staticmethod + def convert_block(block): + if isinstance(block, torch.Tensor): + if block.dtype == torch.float32: + return block + else: + return block.to(torch.float32) + try: + return torch.tensor(block, dtype=torch.float32) + except ValueError: + return torch.tensor(np.array(block, dtype="float32"), dtype=torch.float32) + + @property + def dtype(self): + return torch.float32 + + +class Float64Block(Block): + def __init__(self, *args, **kwargs): + super(Float64Block, self).__init__(*args, **kwargs) + self._block_type = BlockType.float64 + + @staticmethod + def convert_block(block): + if isinstance(block, torch.Tensor): + if block.dtype == torch.float64: + return block + else: + return block.to(torch.float64) + try: + return torch.tensor(block, 
dtype=torch.float64) + except ValueError: + return torch.tensor(np.array(block, dtype="float64"), dtype=torch.float64) + + @property + def dtype(self): + return torch.float64 + + +class BoolBlock(Block): + def __init__(self, *args, **kwargs): + super(BoolBlock, self).__init__(*args, **kwargs) + self._block_type = BlockType.bool + + @staticmethod + def convert_block(block): + if isinstance(block, torch.Tensor): + if block.dtype == torch.bool: + return block + else: + return block.to(torch.bool) + try: + return torch.tensor(block, dtype=torch.bool) + except ValueError: + return torch.tensor(np.array(block, dtype="bool"), dtype=torch.bool) + + @property + def dtype(self): + return torch.bool + + +class IndexBlock(Block): + def __init__(self, *args, **kwargs): + super(IndexBlock, self).__init__(*args, **kwargs) + self._block_type = BlockType.index + + @staticmethod + def convert_block(block): + return pd.Index(block, dtype=str) + + @property + def dtype(self): + return np.dtype("O") + + +class PHETensorBlock(Block): + def __init__(self, *args, **kwargs): + kwargs["should_compress"] = False + + super(PHETensorBlock, self).__init__(*args, **kwargs) + self._block_type = BlockType.phe_tensor + self._pk = None + self._evaluator = None + self._coder = None + self._dtype = None + self._device = None + + def set_extra_kwargs(self, pk, evaluator, coder, dtype, device): + self._pk = pk + self._evaluator = evaluator + self._coder = coder + self._dtype = dtype + self._device = device + + def convert_block(self, block): + if isinstance(block, list): + block = self._evaluator.cat(block) + return block + + def convert_to_phe_tensor(self, block, shape): + if isinstance(block, PHETensor): + return block + + if isinstance(block, list): + block = block[0].cat(block[1:]) + + return PHETensor( + pk=self._pk, + evaluator=self._evaluator, + coder=self._coder, + shape=shape, + data=block, + device=self._device, + dtype=self._dtype, + ) + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + +class NPObjectBlock(Block): + def __init__(self, *args, **kwargs): + super(NPObjectBlock, self).__init__(*args, **kwargs) + self._block_type = BlockType.np_object + + @staticmethod + def convert_block(block): + return np.array(block, dtype=object) + + @property + def dtype(self): + return np.dtype("O") + + +class BlockManager(object): + def __init__(self): + """ + block manager managers the block structure of each partition, distributed always + Please note that we only compress numeric or bool or object type, not compress index type yet + + _blocks: list of Blocks, each element contains the attrs: axis_set + """ + self._blocks = [] + self._field_block_mapping = dict() + + def initialize_blocks(self, schema_manager: SchemaManager): + """ + sample_id + match_id + label + weight + fields + """ + schema = schema_manager.schema + sample_id_type = schema_manager.get_field_types(name=schema.sample_id_name) + + self._blocks.append( + Block.get_block_by_type(sample_id_type)( + [schema_manager.get_field_offset(schema.sample_id_name)], should_compress=False + ) + ) + + if schema.match_id_name: + dtype = schema_manager.get_field_types(name=schema.match_id_name) + self._blocks.append( + Block.get_block_by_type(dtype)( + [schema_manager.get_field_offset(schema.match_id_name)], should_compress=False + ) + ) + + if schema.label_name: + dtype = schema_manager.get_field_types(name=schema.label_name) + self._blocks.append( + Block.get_block_by_type(dtype)( + 
[schema_manager.get_field_offset(schema.label_name)], should_compress=False + ) + ) + + if schema.weight_name: + dtype = schema_manager.get_field_types(name=schema.weight_name) + self._blocks.append( + Block.get_block_by_type(dtype)( + [schema_manager.get_field_offset(schema.weight_name)], should_compress=False + ) + ) + + for column_name in schema.columns: + dtype = schema_manager.get_field_types(name=column_name) + self._blocks.append( + Block.get_block_by_type(dtype)([schema_manager.get_field_offset(column_name)], should_compress=True) + ) + + new_blocks, _1, _2 = self.compress() + + self.reset_blocks(new_blocks) + + def append_fields(self, field_indexes, block_types, should_compress=True): + block_num = len(self._blocks) + block_ids = [] + if isinstance(block_types, list): + for offset, (field_index, block_type) in enumerate(zip(field_indexes, block_types)): + block = Block.get_block_by_type(block_type) + self._blocks.append(block(field_indexes=[field_index], should_compress=should_compress)) + self._field_block_mapping[field_index] = (block_num + offset, 0) + block_ids.append(block_num + offset) + else: + block = Block.get_block_by_type(block_types) + self._blocks.append(block(field_indexes=field_indexes, should_compress=should_compress)) + block_ids.append(block_num) + for offset, field_index in enumerate(field_indexes): + self._field_block_mapping[field_index] = (block_num, offset) + + return block_ids + + def pop_blocks(self, block_indexes: List[int]): + block_index_set = set(block_indexes) + blocks = [] + field_block_mapping = dict() + + for bid, block in enumerate(self._blocks): + if bid not in block_index_set: + blocks.append(block) + + self._blocks = blocks + + def split_fields(self, field_indexes, block_types): + field_sets = set(field_indexes) + block_field_maps = dict() + for idx, field_index in enumerate(field_indexes): + block_id, offset = self.loc_block(field_index, with_offset=True) + if block_id not in block_field_maps: + block_field_maps[block_id] = [] + + block_type = block_types[idx].value if isinstance(block_types, list) else block_types.value + block_field_maps[block_id].append([field_index, offset, block_type]) + + cur_block_num = len(self._blocks) + narrow_blocks = [] + for block_id, field_with_offset_list in block_field_maps.items(): + if len(self._blocks[block_id].field_indexes) == len(field_with_offset_list): + if len(field_with_offset_list) == 1: + self._blocks[block_id] = Block.get_block_by_type(block_type)( + self._blocks[block_id].field_indexes, should_compress=self._blocks[block_id].should_compress + ) + else: + should_compress = self._blocks[block_id].should_compress + for idx, (field, offset, block_type) in enumerate(field_with_offset_list): + if not idx: + self._blocks[block_id] = Block.get_block_by_type(block_type)( + [field], should_compress=should_compress + ) + self._field_block_mapping[field] = (block_id, 0) + else: + self._blocks.append( + Block.get_block_by_type(block_type)([field], should_compress=should_compress) + ) + self._field_block_mapping[field] = (cur_block_num, 0) + cur_block_num += 1 + else: + narrow_field_indexes = [] + narrow_field_offsets = [] + for offset, field in enumerate(self._blocks[block_id].field_indexes): + if field not in field_sets: + narrow_field_indexes.append(field) + narrow_field_offsets.append(offset) + + narrow_blocks.append((block_id, narrow_field_offsets)) + + self._blocks[block_id] = Block.get_block_by_type(self._blocks[block_id].block_type)( + narrow_field_indexes, 
should_compress=self._blocks[block_id].should_compress + ) + for offset, narrow_field in enumerate(narrow_field_indexes): + self._field_block_mapping[narrow_field] = (block_id, offset) + + for field, offset, block_type in field_with_offset_list: + self._blocks.append( + Block.get_block_by_type(block_type)( + [field], should_compress=self._blocks[block_id].should_compress + ) + ) + self._field_block_mapping[field] = (cur_block_num, 0) + cur_block_num += 1 + + dst_blocks = [self._field_block_mapping[field][0] for field in field_indexes] + + return narrow_blocks, dst_blocks + + @property + def blocks(self): + return self._blocks + + @blocks.setter + def blocks(self, blocks): + self._blocks = blocks + + @property + def field_block_mapping(self): + return self._field_block_mapping + + @field_block_mapping.setter + def field_block_mapping(self, field_block_mapping): + self._field_block_mapping = field_block_mapping + + def reset_block_field_indexes(self, field_index_changes: Dict[int, int]): + field_block_mapping = dict() + for bid in range(len(self._blocks)): + self._blocks[bid].reset_field_indexes(field_index_changes) + for offset, field_index in enumerate(self._blocks[bid].field_indexes): + field_block_mapping[field_index] = (bid, offset) + + self._field_block_mapping = field_block_mapping + + def duplicate(self): + dup_block_manager = BlockManager() + dup_block_manager.blocks = copy.deepcopy(self._blocks) + dup_block_manager.field_block_mapping = copy.deepcopy(self._field_block_mapping) + + return dup_block_manager + + def get_numeric_block(self): + numeric_blocks = [] + for _blk in self._blocks: + if _blk.is_numeric: + numeric_blocks.append(_blk) + + return numeric_blocks + + def loc_block(self, field_index, with_offset=True) -> Union[Tuple[int, int], int]: + if with_offset: + return self._field_block_mapping[field_index] + else: + return self._field_block_mapping[field_index][0] + + def compress(self): + compressible_blocks = dict() + + has_compressed = False + for block_id, block in enumerate(self._blocks): + if block.should_compress: + has_compressed = True + + if block.block_type not in compressible_blocks: + compressible_blocks[block.block_type] = [] + compressible_blocks[block.block_type].append((block_id, block)) + + if not has_compressed: + return self._blocks, [], [] + + new_blocks, to_compress_block_loc = [], [] + non_compressed_block_changes = dict() + for block_type, block_list in compressible_blocks.items(): + _blocks = [] + _block_ids = [] + for block_id, block in block_list: + if block.should_compress: + _blocks.append(block) + _block_ids.append(block_id) + else: + non_compressed_block_changes[block_id] = len(new_blocks) + new_blocks.append(block) + + if len(_blocks) > 1: + dst_block_id = len(new_blocks) + block_loc = [] + """ + merge all field_indexes, use set instead of merge sort to avoid O(n * len(_blocks)) + """ + field_indexes_set = set() + for block in _blocks: + field_indexes_set |= set(block.field_indexes) + new_blocks.append( + Block.get_block_by_type(block_type)(sorted(list(field_indexes_set)), should_compress=True) + ) + + dst_loc_mappings = dict(zip(new_blocks[-1].field_indexes, range(len(new_blocks[-1].field_indexes)))) + for block_id, block in zip(_block_ids, _blocks): + new_field_indexes = [dst_loc_mappings[bid] for bid in block.field_indexes] + block_loc.append((block_id, new_field_indexes)) + + to_compress_block_loc.append((dst_block_id, block_loc)) + + elif _blocks: + non_compressed_block_changes[_block_ids[0]] = len(new_blocks) + 
new_blocks.append(_blocks[0])
+
+        return new_blocks, to_compress_block_loc, non_compressed_block_changes
+
+    def reset_blocks(self, blocks):
+        self._blocks = blocks
+        self._field_block_mapping = dict()
+        for bid, block in enumerate(self._blocks):
+            for idx in block.field_indexes:
+                self._field_block_mapping[idx] = (bid, block.get_field_offset(idx))
+
+    def apply(
+        self,
+    ):
+        """
+        convert some fields to other types
+
+        """
+
+    def to_dict(self):
+        """
+        serialize
+        """
+        return dict(blocks=[blk.to_dict() for blk in self._blocks])
+
+    @staticmethod
+    def from_dict(s_dict):
+        blocks = [Block.from_dict(block_s_dict) for block_s_dict in s_dict["blocks"]]
+        bm = BlockManager()
+        bm.blocks = blocks
+
+        return bm
+
+    def derive_new_block_manager(self, indexes: list) -> Tuple["BlockManager", List[Tuple[int, int, bool, List]]]:
+        """
+        derive a new block manager filtered by indexes
+
+        return: list, each element in order:
+            (src_block, src_block_dst_block, block_changed, old_block_indexes)
+        """
+        block_manager = BlockManager()
+        block_index_mapping = dict()
+
+        indexes = sorted(indexes)
+        for src_field_index, dst_field_index in indexes:
+            bid = self.loc_block(src_field_index, with_offset=False)
+            if bid not in block_index_mapping:
+                block_index_mapping[bid] = []
+            block_index_mapping[bid].append((src_field_index, dst_field_index))
+
+        derived_blocks = []
+        blocks_loc = []
+
+        new_block_id = 0
+        for bid, field_indexes in block_index_mapping.items():
+            block, block_changed, retrieval_indexes = self._blocks[bid].derive_block(field_indexes)
+            derived_blocks.append(block)
+            blocks_loc.append((bid, new_block_id, block_changed, retrieval_indexes))
+            new_block_id += 1
+
+        block_manager.reset_blocks(derived_blocks)
+
+        return block_manager, blocks_loc
diff --git a/python/fate/arch/dataframe/manager/data_manager.py b/python/fate/arch/dataframe/manager/data_manager.py
new file mode 100644
index 0000000000..7a01a03b08
--- /dev/null
+++ b/python/fate/arch/dataframe/manager/data_manager.py
@@ -0,0 +1,273 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
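A short sketch of how the Block hierarchy above is intended to be used: Block.get_block_by_type maps a dtype to a concrete block class, convert_block coerces raw rows into the backing container, and the BlockType ordering drives type promotion. The import path follows manager/__init__.py; the sample values are illustrative.

import torch
from fate.arch.dataframe.manager import Block, BlockType

float_block_cls = Block.get_block_by_type("float32")               # -> Float32Block
values = float_block_cls.convert_block([[0.1, 0.2], [0.3, 0.4]])   # torch.float32 tensor of shape (2, 2)
assert values.dtype == torch.float32

index_block_cls = Block.get_block_by_type(BlockType.index)         # -> IndexBlock, backed by pd.Index
sample_ids = index_block_cls.convert_block(["s1", "s2"])

# int32 is "smaller" than float64, so mixing the two promotes to float64
assert BlockType.promote_types(BlockType.int32, BlockType.float64) == BlockType.float64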
+# +import numpy as np +from .schema_manager import SchemaManager +from .block_manager import BlockManager +from .block_manager import BlockType +import pandas as pd +from ..entity import types +from typing import Union, List, Tuple +from ..conf.default_config import DATAFRAME_BLOCK_ROW_SIZE + + +class DataManager(object): + def __init__( + self, + schema_manager: SchemaManager = None, + block_manager: BlockManager = None, + block_row_size: int = DATAFRAME_BLOCK_ROW_SIZE + ): + self._schema_manager = schema_manager + self._block_manager = block_manager + self._block_row_size = block_row_size + + @property + def blocks(self): + return self._block_manager.blocks + + @property + def block_num(self): + return len(self._block_manager.blocks) + + @property + def block_row_size(self): + return self._block_row_size + + @property + def schema(self): + return self._schema_manager.schema + + @property + def dtypes(self): + field_name_list = self.get_field_name_list() + dtype_dict = dict() + for name in field_name_list: + block_id = self.loc_block(name, with_offset=False) + dtype_dict[name] = self._block_manager.blocks[block_id].dtype + + return pd.Series(dtype_dict) + + def add_label_or_weight(self, key_type, name, block_type): + field_index, field_index_changes = self._schema_manager.add_label_or_weight(key_type, name, block_type) + self._block_manager.reset_block_field_indexes(field_index_changes) + self._block_manager.append_fields([field_index], block_type, should_compress=False) + + def append_columns(self, columns: List[str], block_types: Union["BlockType", List["BlockType"]]) -> List[int]: + field_indexes = self._schema_manager.append_columns(columns, block_types) + block_indexes = self._block_manager.append_fields(field_indexes, block_types) + + return block_indexes + + def pop_blocks(self, block_indexes: List[int]): + field_indexes = [] + for block_index in block_indexes: + field_indexes.extend(self._block_manager.blocks[block_index].field_indexes) + + field_index_changes = self._schema_manager.pop_fields(field_indexes) + self._block_manager.pop_blocks(block_indexes) + self._block_manager.reset_block_field_indexes(field_index_changes) + + def split_columns(self, columns: List[str], block_types: Union["BlockType", List["BlockType"]]): + field_indexes = self._schema_manager.split_columns(columns, block_types) + narrow_blocks, dst_blocks = self._block_manager.split_fields(field_indexes, block_types) + + return narrow_blocks, dst_blocks + + def duplicate(self) -> "DataManager": + return DataManager( + self._schema_manager.duplicate(), + self._block_manager.duplicate() + ) + + def init_from_local_file(self, sample_id_name, columns, match_id_list, match_id_name, label_name, weight_name, + label_type, weight_type, dtype, default_type=types.DEFAULT_DATA_TYPE, + anonymous_site_name=None): + schema_manager = SchemaManager() + retrieval_index_dict = schema_manager.parse_local_file_schema(sample_id_name, + columns, + match_id_list, + match_id_name, + label_name, + weight_name, + anonymous_site_name=anonymous_site_name) + schema_manager.init_field_types(label_type, weight_type, dtype, + default_type=default_type) + block_manager = BlockManager() + block_manager.initialize_blocks(schema_manager) + + self._schema_manager = schema_manager + self._block_manager = block_manager + + return retrieval_index_dict + + def convert_to_blocks(self, splits): + converted_blocks = [] + for bid, block in enumerate(self._block_manager.blocks): + converted_blocks.append(block.convert_block(splits[bid])) + + return 
converted_blocks + + def derive_new_data_manager(self, with_sample_id, with_match_id, with_label, with_weight, columns) \ + -> Tuple["DataManager", List[Tuple[int, int, bool, List]]]: + schema_manager, derive_indexes = self._schema_manager.derive_new_schema_manager(with_sample_id=with_sample_id, + with_match_id=with_match_id, + with_label=with_label, + with_weight=with_weight, + columns=columns) + block_manager, blocks_loc = self._block_manager.derive_new_block_manager(derive_indexes) + + return DataManager( + schema_manager=schema_manager, + block_manager=block_manager + ), blocks_loc + + def loc_block(self, name: Union[str, List[str]], with_offset=True): + if isinstance(name, str): + field_index = self._schema_manager.get_field_offset(name) + return self._block_manager.loc_block(field_index, with_offset) + else: + loc_ret = [] + for _name in name: + field_index = self._schema_manager.get_field_offset(_name) + loc_ret.append(self._block_manager.loc_block(field_index, with_offset)) + + return loc_ret + + def fill_anonymous_site_name(self, site_name): + self._schema_manager.fill_anonymous_site_name(site_name) + + def get_fields_loc(self, with_sample_id=True, with_match_id=True, with_label=True, with_weight=True): + field_block_mapping = self._block_manager.field_block_mapping + fields_loc = [[]] * len(field_block_mapping) + for col_id, _block_id_tuple in field_block_mapping.items(): + fields_loc[col_id] = _block_id_tuple + + exclude_indexes = set() + if not with_sample_id and self.schema.sample_id_name: + exclude_indexes.add(self._schema_manager.get_field_offset(self.schema.sample_id_name)) + + if not with_match_id and self.schema.match_id_name: + exclude_indexes.add(self._schema_manager.get_field_offset(self.schema.match_id_name)) + + if not with_label and self.schema.label_name: + exclude_indexes.add(self._schema_manager.get_field_offset(self.schema.label_name)) + + if not with_weight and self.schema.weight_name: + exclude_indexes.add(self._schema_manager.get_field_offset(self.schema.weight_name)) + + if not exclude_indexes: + return fields_loc + + ret_fields_loc = [] + for field_index, field_loc in enumerate(fields_loc): + if field_index not in exclude_indexes: + ret_fields_loc.append(field_loc) + + return ret_fields_loc + + def get_field_name(self, field_index): + return self._schema_manager.get_field_name(field_index) + + def get_field_name_list(self, with_sample_id=True, with_match_id=True, with_label=True, with_weight=True): + return self._schema_manager.get_field_name_list(with_sample_id=with_sample_id, + with_match_id=with_match_id, + with_label=with_label, + with_weight=with_weight) + + def get_field_type_by_name(self, name): + return self._schema_manager.get_field_types(name) + + def get_field_offset(self, name): + return self._schema_manager.get_field_offset(name) + + def get_block(self, block_id): + return self._block_manager.blocks[block_id] + + def infer_operable_blocks(self): + operable_field_offsets = self._schema_manager.infer_operable_filed_offsets() + block_index_set = set(self._block_manager.loc_block(offset, with_offset=False) for offset in operable_field_offsets) + return sorted(list(block_index_set)) + + def infer_operable_field_names(self): + return self._schema_manager.infer_operable_field_names() + + def infer_non_operable_blocks(self): + non_operable_field_offsets = self._schema_manager.infer_non_operable_field_offsets() + block_index_set = set(self._block_manager.loc_block(offset, with_offset=False) for offset in non_operable_field_offsets) + return 
sorted(list(block_index_set)) + + def try_to_promote_types(self, + block_indexes: List[int], + block_type: Union[bool, list, int, float, np.dtype, BlockType]) -> List[Tuple[int, BlockType]]: + promote_types = [] + if isinstance(block_type, (bool, int, float, np.dtype)): + block_type = BlockType.get_block_type(block_type) + + if isinstance(block_type, BlockType): + for idx, bid in enumerate(block_indexes): + if self.get_block(bid).block_type < block_type: + promote_types.append( + (bid, block_type) + ) + else: + for idx, (bid, r_type) in enumerate(zip(block_indexes, block_type)): + block_type = BlockType.get_block_type(r_type) + if self.get_block(bid).block_type < block_type: + promote_types.append( + (bid, block_type) + ) + + return promote_types + + def promote_types(self, to_promote_blocks: list): + for bid, block_type in to_promote_blocks: + self._block_manager.blocks[bid] = self._block_manager.blocks[bid].convert_block_type(block_type) + for field_index in self._block_manager.blocks[bid].field_indexes: + self._schema_manager.set_field_type_by_offset(field_index, block_type.value) + + def compress_blocks(self): + new_blocks, to_compress_block_loc, non_compress_block_changes = self._block_manager.compress() + if to_compress_block_loc: + self._block_manager.reset_blocks(new_blocks) + + return to_compress_block_loc, non_compress_block_changes + + def rename(self, sample_id_name=None, match_id_name=None, label_name=None, weight_name=None, columns: dict = None): + self._schema_manager.rename(sample_id_name=sample_id_name, + match_id_name=match_id_name, + label_name=label_name, + weight_name=weight_name, + columns=columns) + + def serialize(self): + schema_serialization = self._schema_manager.serialize() + fields = schema_serialization["fields"] + for col_id, field in enumerate(fields): + block_id = self._block_manager.loc_block(col_id, with_offset=False) + should_compress = self._block_manager.blocks[block_id].should_compress + field["should_compress"] = should_compress + + schema_serialization["fields"] = fields + return schema_serialization + + @classmethod + def deserialize(cls, schema_meta): + data_manager = DataManager() + data_manager._schema_manager = SchemaManager.deserialize(schema_meta) + data_manager._block_manager = BlockManager() + data_manager._block_manager.initialize_blocks(data_manager._schema_manager) + + return data_manager diff --git a/python/fate/arch/dataframe/manager/schema_manager.py b/python/fate/arch/dataframe/manager/schema_manager.py new file mode 100644 index 0000000000..bdb763e01f --- /dev/null +++ b/python/fate/arch/dataframe/manager/schema_manager.py @@ -0,0 +1,657 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
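[Editor's note: a minimal usage sketch of the DataManager defined above. This sketch is not part of the patch; the header fields match_id, y, x0, x1 and the dtype strings are placeholders, and it assumes BlockManager.initialize_blocks behaves as it is called inside init_from_local_file.]

    from fate.arch.dataframe.manager.data_manager import DataManager

    # Wire a DataManager for a toy local file whose non-sample-id header is
    # match_id, y, x0, x1 (all names here are hypothetical).
    dm = DataManager()
    retrieval_indexes = dm.init_from_local_file(
        sample_id_name="sample_id",
        columns=["match_id", "y", "x0", "x1"],
        match_id_list=["match_id"],
        match_id_name="match_id",
        label_name="y",
        weight_name=None,
        label_type="int32",
        weight_type="float32",
        dtype="float32",
    )
    print(retrieval_indexes)   # offsets of match_id / label / feature columns in the raw header
    print(dm.dtypes)           # pandas Series mapping each field to its block dtype
    print(dm.loc_block("x0"))  # (block_id, offset) locating feature x0
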
+# +import copy +from typing import List, Union +import pandas as pd +from .utils import AnonymousGenerator + + +DEFAULT_LABEL_NAME = "label" +DEFAULT_WEIGHT_NAME = "weight" + + +class Schema(object): + def __init__( + self, + sample_id_name=None, + match_id_name=None, + weight_name=None, + label_name=None, + columns: Union[list, pd.Index] = None, + anonymous_label_name=None, + anonymous_weight_name=None, + anonymous_columns: Union[list, pd.Index] = None, + anonymous_summary: dict = None + ): + self._sample_id_name = sample_id_name + self._match_id_name = match_id_name + self._weight_name = weight_name + self._label_name = label_name + self._columns = pd.Index(columns) if columns else pd.Index([]) + self._anonymous_label_name = anonymous_label_name + self._anonymous_weight_name = anonymous_weight_name + self._anonymous_columns = pd.Index(anonymous_columns) if anonymous_columns else pd.Index([]) + self._anonymous_summary = anonymous_summary if anonymous_summary else dict() + + @property + def sample_id_name(self): + return self._sample_id_name + + @sample_id_name.setter + def sample_id_name(self, sample_id_name: str): + self._sample_id_name = sample_id_name + + @property + def match_id_name(self): + return self._match_id_name + + @match_id_name.setter + def match_id_name(self, match_id_name: str): + self._match_id_name = match_id_name + + @property + def weight_name(self): + return self._weight_name + + @weight_name.setter + def weight_name(self, weight_name: str): + self._weight_name = weight_name + + if self.anonymous_weight_name is None: + anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"]) + + self._anonymous_weight_name = anonymous_generator.add_anonymous_weight() + + @property + def anonymous_weight_name(self): + return self._anonymous_weight_name + + @anonymous_weight_name.setter + def anonymous_weight_name(self, anonymous_weight_name: str): + self._anonymous_weight_name = anonymous_weight_name + + @property + def label_name(self): + return self._label_name + + @label_name.setter + def label_name(self, label_name: str): + self._label_name = label_name + + if self._anonymous_label_name is None: + anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"]) + self._anonymous_label_name = anonymous_generator.add_anonymous_label() + + @property + def anonymous_label_name(self): + return self._anonymous_label_name + + @anonymous_label_name.setter + def anonymous_label_name(self, anonymous_label_name): + self._anonymous_label_name = anonymous_label_name + + @property + def columns(self) -> pd.Index: + return self._columns + + @columns.setter + def columns(self, columns: pd.Index): + self._columns = columns + + @property + def anonymous_columns(self) -> pd.Index: + return self._anonymous_columns + + @anonymous_columns.setter + def anonymous_columns(self, anonymous_columns: pd.Index): + self._anonymous_columns = anonymous_columns + + @property + def anonymous_summary(self) -> dict: + return self._anonymous_summary + + @anonymous_summary.setter + def anonymous_summary(self, anonymous_summary): + self._anonymous_summary = anonymous_summary + + def append_columns(self, names): + self._columns = self._columns.append(pd.Index(names)) + # TODO: extend anonymous column + anonymous_generator = AnonymousGenerator(site_name=self._anonymous_summary["site_name"]) + + anonymous_columns, anonymous_summary = anonymous_generator.add_anonymous_columns(names, self._anonymous_summary) + self._anonymous_columns = 
self._anonymous_columns.append(pd.Index(anonymous_columns)) + self._anonymous_summary = anonymous_summary + + def init_anonymous_names(self, anonymous_site_name): + anonymous_generator = AnonymousGenerator(anonymous_site_name) + anonymous_ret_dict = anonymous_generator.generate_anonymous_names(self) + self._set_anonymous_info_by_dict(anonymous_ret_dict) + + def fill_anonymous_site_name(self, anonymous_site_name): + anonymous_generator = AnonymousGenerator(anonymous_site_name) + anonymous_ret_dict = anonymous_generator.fill_anonymous_site_name( + anonymous_label_name=self.anonymous_label_name, + anonymous_weight_name=self._anonymous_weight_name, + anonymous_columns=self._anonymous_columns, + anonymous_summary=self._anonymous_summary + ) + + self._set_anonymous_info_by_dict(anonymous_ret_dict) + + def _set_anonymous_info_by_dict(self, anonymous_ret_dict): + if self._label_name: + self._anonymous_label_name = anonymous_ret_dict["anonymous_label_name"] + if self._weight_name: + self._anonymous_weight_name = anonymous_ret_dict["anonymous_weight_name"] + if self._columns is not None: + self._anonymous_columns = anonymous_ret_dict["anonymous_columns"] + + self._anonymous_summary = anonymous_ret_dict["anonymous_summary"] + + def pop_columns(self, names): + names = set(names) + if self._label_name in names: + names.remove(self._label_name) + self._label_name = None + if self._weight_name in names: + names.remove(self._weight_name) + self._weight_name = None + + columns = [] + for name in self._columns: + if name not in names: + columns.append(name) + self._columns = pd.Index(columns) + + # TODO: pop anonymous columns + + def __eq__(self, other: "Schema"): + return self.label_name == other.label_name and self.weight_name == other.weight_name \ + and self.sample_id_name == other.sample_id_name and self.match_id_name == other.match_id_name \ + and self.columns.tolist() == other.columns.tolist() + + def serialize(self): + s_obj = list() + s_obj.append( + dict(name=self._sample_id_name, + property="sample_id") + ) + + if self._match_id_name: + s_obj.append( + dict(name=self._match_id_name, + property="match_id") + ) + + if self._label_name: + s_obj.append( + dict(name=self._label_name, + anonymous_name=self._anonymous_label_name, + property="label") + ) + if self._weight_name: + s_obj.append( + dict(name=self._weight_name, + anonymous_name=self._anonymous_weight_name, + property="weight") + ) + + if len(self._columns): + for name, anonymous_name in zip(self._columns, self._anonymous_columns): + s_obj.append( + dict(name=name, + anonymous_name=anonymous_name, + property="column") + ) + + return dict(fields=s_obj, + anonymous_summary=self._anonymous_summary) + + +class SchemaManager(object): + def __init__(self): + self._schema = None + self._type_mapping = dict() + self._name_offset_mapping = dict() + self._offset_name_mapping = dict() + + @property + def schema(self): + return self._schema + + @schema.setter + def schema(self, schema): + self._schema = schema + + def rename(self, sample_id_name=None, match_id_name=None, label_name=None, weight_name=None, columns: dict = None): + attr_dict = { + "sample_id_name": sample_id_name, + "match_id_name": match_id_name, + "label_name": label_name, + "weight_name": weight_name + } + + for attr, value in attr_dict.items(): + if not value: + continue + o_name = getattr(self._schema, attr) + setattr(self._schema, attr, value) + self._rename_single_column(o_name, value) + + if columns: + for o_name, n_name in columns.items(): + self._rename_single_column(o_name, 
n_name) + + o_columns = self._schema.columns.tolist() + n_columns = [o_name if o_name not in columns else columns[o_name] for o_name in o_columns] + self._schema.columns = pd.Index(n_columns) + + def _rename_single_column(self, src, dst): + if src == dst: + return + + self._type_mapping[dst] = self._type_mapping[src] + self._type_mapping.pop(src) + + self._name_offset_mapping[dst] = self._name_offset_mapping[src] + offset = self._name_offset_mapping.pop(src) + + self._offset_name_mapping[offset] = dst + + def add_label_or_weight(self, key_type, name, block_type): + self._type_mapping[name] = block_type.value + + src_field_names = self.get_field_name_list() + if key_type == "label": + self._schema.label_name = name + else: + self._schema.weight_name = name + + dst_field_names = self.get_field_name_list() + + name_offset_mapping = dict() + offset_name_mapping = dict() + field_index_changes = dict() + + for offset, field_name in enumerate(dst_field_names): + name_offset_mapping[field_name] = offset + offset_name_mapping[offset] = field_name + + for field_name in src_field_names: + src_offset = self._name_offset_mapping[field_name] + dst_offset = name_offset_mapping[field_name] + field_index_changes[src_offset] = dst_offset + + self._name_offset_mapping = name_offset_mapping + self._offset_name_mapping = offset_name_mapping + + return self._name_offset_mapping[name], field_index_changes + + def append_columns(self, names, block_types): + field_index = len(self._name_offset_mapping) + for offset, name in enumerate(names): + if isinstance(block_types, list): + dtype = block_types[offset].value + else: + dtype = block_types.value + + self._type_mapping[name] = dtype + self._name_offset_mapping[name] = field_index + offset + self._offset_name_mapping[field_index + offset] = name + + self.schema.append_columns(names) + + return [field_index + offset for offset in range(len(names))] + + def pop_fields(self, field_indexes): + field_names = [self._offset_name_mapping[field_id] for field_id in field_indexes] + self._schema = copy.deepcopy(self._schema) + self._schema.pop_columns(field_names) + + field_index_set = set(field_indexes) + left_field_indexes = [] + for i in range(len(self._offset_name_mapping)): + if i not in field_index_set: + left_field_indexes.append(i) + + name_offset_mapping = dict() + offset_name_mapping = dict() + field_index_changes = dict() + for dst_field_id, src_field_id in enumerate(left_field_indexes): + name = self._offset_name_mapping[src_field_id] + name_offset_mapping[name] = dst_field_id + offset_name_mapping[dst_field_id] = name + field_index_changes[src_field_id] = dst_field_id + + self._name_offset_mapping = name_offset_mapping + self._offset_name_mapping = offset_name_mapping + + return field_index_changes + + def split_columns(self, names, block_types): + field_indexes = [self._name_offset_mapping[name] for name in names] + for offset, name in enumerate(names): + if isinstance(block_types, list): + self._type_mapping[name] = block_types[offset].value + else: + self._type_mapping[name] = block_types.value + + return field_indexes + + def duplicate(self): + dup_schema_manager = SchemaManager() + dup_schema_manager.schema = copy.deepcopy(self._schema) + dup_schema_manager._name_offset_mapping = copy.deepcopy(self._name_offset_mapping) + dup_schema_manager._type_mapping = copy.deepcopy(self._type_mapping) + dup_schema_manager._offset_name_mapping = copy.deepcopy(self._offset_name_mapping) + + return dup_schema_manager + + def get_all_keys(self): + return 
list(self._name_offset_mapping.keys()) + + def parse_local_file_schema(self, sample_id_name, columns, match_id_list, match_id_name, label_name, weight_name, + anonymous_site_name=None): + column_indexes = list(range(len(columns))) + + match_id_index, label_index, weight_index = None, None, None + if match_id_list: + if match_id_name and match_id_name not in match_id_list: + raise ValueError(f"{match_id_name} not exist match_id_list={match_id_list}") + if not match_id_name and len(match_id_list) > 1: + raise ValueError(f"Multi match id exists, specify one to be used") + + match_id_name = match_id_list[0] + elif match_id_name: + match_id_list = [match_id_name] + + if match_id_name: + match_id_index = self.extract_column_index_by_name(columns, column_indexes, match_id_name) + match_id_list.pop(match_id_list.index(match_id_name)) + if label_name: + label_index = self.extract_column_index_by_name(columns, column_indexes, label_name) + if weight_name: + weight_index = self.extract_column_index_by_name(columns, column_indexes, weight_name) + + for id_name in match_id_list: + idx = columns.index(id_name) + columns.pop(idx) + column_indexes.pop(idx) + + self._schema = Schema( + sample_id_name=sample_id_name, + match_id_name=match_id_name, + weight_name=weight_name, + label_name=label_name, + columns=columns + ) + + self._schema.init_anonymous_names(anonymous_site_name) + self.init_name_mapping() + + return dict( + match_id_index=match_id_index, + label_index=label_index, + weight_index=weight_index, + column_indexes=column_indexes + ) + + def fill_anonymous_site_name(self, anonymous_site_name): + self._schema.fill_anonymous_site_name(anonymous_site_name) + + @staticmethod + def extract_column_index_by_name(columns, column_indexes, name, drop=True): + try: + idx = columns.index(name) + ret = column_indexes[idx] + if drop: + columns.pop(idx) + column_indexes.pop(idx) + + return ret + except ValueError: + raise ValueError(f"{name} does not exist in {columns}") + + def init_field_types(self, label_type="float32", weight_type="float32", dtype="float32", + default_type="float32", match_id_type="index", sample_id_type="index"): + self._type_mapping[self._schema.sample_id_name] = "index" + + if self._schema.match_id_name: + self._type_mapping[self._schema.match_id_name] = "index" + + if self._schema.label_name: + self._type_mapping[self._schema.label_name] = label_type + + if self._schema.weight_name: + self._type_mapping[self._schema.weight_name] = weight_type + + if isinstance(dtype, str): + for column in self._schema.columns: + self._type_mapping[column] = dtype + elif isinstance(dtype, dict): + for column in self._schema.columns: + self._type_mapping[column] = dtype.get(column, default_type) + + def init_name_mapping(self): + offset = 0 + + if self._schema.sample_id_name: + offset = 1 + self._name_offset_mapping[self._schema.sample_id_name] = 0 + + if self._schema.match_id_name: + self._name_offset_mapping[self._schema.match_id_name] = offset + offset += 1 + + if self._schema.label_name: + self._name_offset_mapping[self._schema.label_name] = offset + offset += 1 + + if self._schema.weight_name: + self._name_offset_mapping[self._schema.weight_name] = offset + offset += 1 + + if len(self._schema.columns): + for idx, column_name in enumerate(self._schema.columns): + self._name_offset_mapping[column_name] = offset + idx + + for column_name, idx in self._name_offset_mapping.items(): + self._offset_name_mapping[idx] = column_name + + def get_field_offset(self, name): + if name not in 
self._name_offset_mapping: + raise ValueError(f"{name} does not exist in schema") + + return self._name_offset_mapping[name] + + def get_field_name(self, offset): + if offset >= len(self._offset_name_mapping): + raise ValueError(f"Offset={offset} is out out bound") + + return self._offset_name_mapping[offset] + + def get_field_name_list(self, with_sample_id=True, with_match_id=True, with_label=True, with_weight=True): + field_names = [] + if with_sample_id and self._schema.sample_id_name: + field_names.append(self._schema.sample_id_name) + + if with_match_id and self._schema.match_id_name: + field_names.append(self._schema.match_id_name) + + if with_label and self._schema.label_name: + field_names.append(self._schema.label_name) + + if with_weight and self._schema.weight_name: + field_names.append(self._schema.weight_name) + + field_names += self._schema.columns.tolist() + + return field_names + + def get_field_types(self, name=None, flatten=False): + if not name: + if not flatten: + return self._type_mapping + else: + types = [None] * len(self._type_mapping) + for idx, name in self._offset_name_mapping: + types[idx] = self._type_mapping[name] + return types + else: + return self._type_mapping[name] + + def set_field_type_by_offset(self, field_index, field_type): + name = self._offset_name_mapping[field_index] + self._type_mapping[name] = field_type + + def derive_new_schema_manager(self, with_sample_id=True, with_match_id=True, + with_label=True, with_weight=True, columns: Union[str, list] = None): + derived_schema_manager = SchemaManager() + derived_schema = Schema() + + indexes = [] + + derived_schema.anonymous_summary = self._schema.anonymous_summary + if with_sample_id: + derived_schema.sample_id_name = self._schema.sample_id_name + + if with_match_id: + derived_schema.match_id_name = self._schema.match_id_name + + if with_label: + derived_schema.label_name = self._schema.label_name + derived_schema.anonymous_label_name = self._schema.anonymous_label_name + + if with_weight: + derived_schema.weight_name = self._schema.weight_name + derived_schema.anonymous_weight_name = self._schema.anonymous_weight_name + + if columns: + if isinstance(columns, str): + columns = [columns] + + derived_columns = [] + derived_anonymous_columns = [] + + anonymous_mappings = dict(zip(self._schema.columns, self._schema.anonymous_columns)) + for column in columns: + anonymous_column = anonymous_mappings[column] + derived_columns.append(column) + derived_anonymous_columns.append(anonymous_column) + + derived_schema.columns = pd.Index(derived_columns) + derived_schema.anonymous_columns = pd.Index(derived_anonymous_columns) + + derived_schema_manager.schema = derived_schema + derived_schema_manager.init_name_mapping() + + for name in derived_schema_manager.get_all_keys(): + indexes.append((self.get_field_offset(name), derived_schema_manager.get_field_offset(name))) + derived_schema_manager._type_mapping[name] = self.get_field_types(name) + + return derived_schema_manager, sorted(indexes) + + def infer_operable_field_names(self) -> List[str]: + if len(self._schema.columns): + return self._schema.columns.tolist() + elif self._schema.weight_name: + return [self._schema.weight_name] + else: + return [self._schema.label_name] + + def infer_operable_filed_offsets(self) -> List[int]: + operable_field_names = self.infer_operable_field_names() + operable_field_offsets = [self._name_offset_mapping[field_name] for field_name in operable_field_names] + + return operable_field_offsets + + def 
infer_non_operable_field_names(self) -> List[str]: + operable_field_name_set = set(self.infer_operable_field_names()) + non_operable_field_names = [] + if self._schema.sample_id_name: + non_operable_field_names.append(self._schema.sample_id_name) + + if self._schema.match_id_name: + non_operable_field_names.append(self._schema.match_id_name) + + if self._schema.label_name and self._schema.label_name not in operable_field_name_set: + non_operable_field_names.append(self._schema.label_name) + + if self._schema.weight_name and self._schema.weight_name not in operable_field_name_set: + non_operable_field_names.append(self._schema.weight_name) + + return non_operable_field_names + + def infer_non_operable_field_offsets(self) -> List[int]: + non_operable_field_names = self.infer_non_operable_field_names() + non_operable_field_offsets = [self._name_offset_mapping[field_name] for field_name in non_operable_field_names] + + return non_operable_field_offsets + + def serialize(self): + """ + fields: list of field, + field: + name: column_name + anonymous_name: anonymous_column_name + dtype: data_type + property: sample_id/match_id/label/weight/column + """ + schema_serialization = self._schema.serialize() + + schema_serialization_with_type = [] + for field in schema_serialization["fields"]: + field["dtype"] = self._type_mapping[field["name"]] + schema_serialization_with_type.append(field) + + schema_serialization["fields"] = schema_serialization_with_type + return schema_serialization + + @classmethod + def deserialize(cls, schema_meta: dict): + schema_manager = SchemaManager() + + schema_dict = dict( + columns=[], + anonymous_columns=[], + anonymous_summary=schema_meta["anonymous_summary"] + ) + type_dict = {} + for field in schema_meta["fields"]: + p = field["property"] + name = field["name"] + if field["property"] in ["sample_id", "match_id", "weight", "label"]: + schema_dict[f"{p}_name"] = name + type_dict[f"{p}_type"] = field["dtype"] + if p in ["label", "weight"]: + schema_dict[f"anonymous_{p}_name"] = field["anonymous_name"] + else: + schema_dict["columns"].append(name) + + if "dtype" not in type_dict: + type_dict["dtype"] = dict() + type_dict["dtype"][name] = field["dtype"] + + if "anonymous_name" in field: + schema_dict["anonymous_columns"].append(field["anonymous_name"]) + + schema = Schema(**schema_dict) + schema_manager.schema = schema + schema_manager.init_name_mapping() + schema_manager.init_field_types(**type_dict) + + return schema_manager diff --git a/python/fate/arch/dataframe/storage/__init__.py b/python/fate/arch/dataframe/manager/utils/__init__.py similarity index 90% rename from python/fate/arch/dataframe/storage/__init__.py rename to python/fate/arch/dataframe/manager/utils/__init__.py index 8b1973877d..340504900d 100644 --- a/python/fate/arch/dataframe/storage/__init__.py +++ b/python/fate/arch/dataframe/manager/utils/__init__.py @@ -12,5 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._index import Index -from ._value_store import ValueStore +# +from ._anonymous_generator import AnonymousGenerator diff --git a/python/fate/arch/dataframe/manager/utils/_anonymous_generator.py b/python/fate/arch/dataframe/manager/utils/_anonymous_generator.py new file mode 100644 index 0000000000..0df5a88b34 --- /dev/null +++ b/python/fate/arch/dataframe/manager/utils/_anonymous_generator.py @@ -0,0 +1,125 @@ +# +# Copyright 2019 The FATE Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import copy +import pandas as pd + + +ANONYMOUS_COLUMN_PREFIX = "x" +ANONYMOUS_LABEL = "y" +ANONYMOUS_WEIGHT = "weight" +SPLICES = "_" +DEFAULT_SITE_NAME = "AnonymousRole_AnonymousPartyId" + + +class AnonymousGenerator(object): + def __init__(self, site_name=None): + self._site_name = site_name + + def _generate_anonymous_column(self, suf): + if self._site_name: + return SPLICES.join([self._site_name, suf]) + else: + return SPLICES.join([DEFAULT_SITE_NAME, suf]) + + def generate_anonymous_names(self, schema): + column_len = len(schema.columns.tolist()) + anonymous_label_name = None + anonymous_weight_name = None + + anonymous_columns = [self._generate_anonymous_column( + ANONYMOUS_COLUMN_PREFIX + str(i)) for i in range(column_len)] + + if schema.label_name: + anonymous_label_name = self._generate_anonymous_column(ANONYMOUS_LABEL) + + if schema.weight_name: + anonymous_weight_name = self._generate_anonymous_column(ANONYMOUS_WEIGHT) + + return dict( + anonymous_label_name=anonymous_label_name, + anonymous_weight_name=anonymous_weight_name, + anonymous_columns=anonymous_columns, + anonymous_summary=dict(column_len=column_len, + site_name=self._site_name + ) + ) + + def _check_site_name_consistency(self, anonymous_summary): + anonymous_site_name = anonymous_summary["site_name"] + + if anonymous_site_name and self._site_name is not None and anonymous_site_name != self._site_name: + raise ValueError(f"previous_site_name={anonymous_site_name} != current_site_name={self._site_name}") + + def add_anonymous_label(self): + return self._generate_anonymous_column(ANONYMOUS_LABEL) + + def add_anonymous_weight(self): + return self._generate_anonymous_column(ANONYMOUS_WEIGHT) + + def add_anonymous_columns(self, columns, anonymous_summary: dict): + self._check_site_name_consistency(anonymous_summary) + anonymous_summary = copy.deepcopy(anonymous_summary) + + column_len = anonymous_summary["column_len"] + anonymous_columns = [self._generate_anonymous_column(ANONYMOUS_COLUMN_PREFIX + str(i + column_len)) + for i in range(len(columns))] + + anonymous_summary["column_len"] = column_len + len(columns) + return anonymous_columns, anonymous_summary + + def fill_anonymous_site_name(self, anonymous_label_name, anonymous_weight_name, + anonymous_columns, anonymous_summary): + anonymous_summary = copy.deepcopy(anonymous_summary) + + self._check_site_name_consistency(anonymous_summary) + + if anonymous_summary["site_name"] is None: + anonymous_label_name = self._fill_site_name(anonymous_label_name) + anonymous_weight_name = self._fill_site_name(anonymous_weight_name) + anonymous_columns = self._fill_site_name(anonymous_columns) + anonymous_summary["site_name"] = self._site_name + + return dict( + anonymous_label_name=anonymous_label_name, + anonymous_weight_name=anonymous_weight_name, + anonymous_columns=anonymous_columns, + anonymous_summary=anonymous_summary + ) + + def _fill_site_name(self, name): + if name is None: + return name + + if 
isinstance(name, str): + site_name_pre, site_name_suf, suf = name.split(SPLICES, 2) + site_name = SPLICES.join([site_name_pre, site_name_suf]) + + if site_name != DEFAULT_SITE_NAME: + raise ValueError(f"To fill anonymous names with site_name, it shouldn't be fill before") + return self._generate_anonymous_column(suf) + else: + name = list(name) + ret = [] + for _name in name: + site_name_pre, site_name_suf, suf = _name.split(SPLICES, 2) + site_name = SPLICES.join([site_name_pre, site_name_suf]) + + if site_name != DEFAULT_SITE_NAME: + raise ValueError(f"To fill anonymous names with site_name, it shouldn't be fill before") + + ret.append(self._generate_anonymous_column(suf)) + + return pd.Index(ret) diff --git a/python/fate/arch/dataframe/ops/__init__.py b/python/fate/arch/dataframe/ops/__init__.py index 063d19301e..ae946a49c4 100644 --- a/python/fate/arch/dataframe/ops/__init__.py +++ b/python/fate/arch/dataframe/ops/__init__.py @@ -12,8 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._arithmetic import arith_method -from ._predict_result_transformaton import transform_to_predict_result -from ._stat import stat_method - -__all__ = ["arith_method", "transform_to_predict_result", "stat_method"] diff --git a/python/fate/arch/dataframe/ops/_activation.py b/python/fate/arch/dataframe/ops/_activation.py new file mode 100644 index 0000000000..02a75b61cb --- /dev/null +++ b/python/fate/arch/dataframe/ops/_activation.py @@ -0,0 +1,51 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
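[Editor's note: the AnonymousGenerator above is self-contained, so a short illustrative sketch (not part of the patch; guest_9999 is a made-up site name) shows the anonymous names it produces.]

    from fate.arch.dataframe.manager.utils import AnonymousGenerator

    gen = AnonymousGenerator(site_name="guest_9999")  # hypothetical site name
    print(gen.add_anonymous_label())                  # guest_9999_y
    print(gen.add_anonymous_weight())                 # guest_9999_weight

    # Extend an existing anonymous column list by two new features: the summary
    # records how many anonymous columns were already generated.
    cols, summary = gen.add_anonymous_columns(
        ["extra_a", "extra_b"],
        {"column_len": 3, "site_name": "guest_9999"},
    )
    print(cols)                   # ['guest_9999_x3', 'guest_9999_x4']
    print(summary["column_len"])  # 5
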
+import functools + +from .._dataframe import DataFrame +from ..manager import DataManager +from ..manager.block_manager import BlockType + + +def sigmoid(df: "DataFrame") -> "DataFrame": + data_manager = df.data_manager.duplicate() + operable_blocks = data_manager.infer_operable_blocks() + non_operable_blocks = data_manager.infer_non_operable_blocks() + for block_id in operable_blocks: + if not data_manager.blocks[block_id].is_numeric(): + raise ValueError("Sigmoid support only operates on numeric columns") + if data_manager.blocks[block_id].block_type in [BlockType.int32, BlockType.int64]: + data_manager.blocks[block_id] = data_manager.blocks[block_id].convert_block_type(BlockType.float32) + + def _sigmoid(blocks, op_blocks=None, reserved_blocks=None): + ret_blocks = [[] for i in range(len(op_blocks) + len(reserved_blocks))] + for bid in reserved_blocks: + ret_blocks[bid] = blocks[bid] + + for bid in op_blocks: + ret_blocks[bid] = blocks[bid].sigmoid() + + return ret_blocks + + _sigmoid_func = functools.partial(_sigmoid, op_blocks=operable_blocks, reserved_blocks=non_operable_blocks) + + block_table = df.block_table.mapValues(_sigmoid_func) + + return DataFrame( + df._ctx, + block_table, + df.partition_order_mappings, + data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_apply_row.py b/python/fate/arch/dataframe/ops/_apply_row.py new file mode 100644 index 0000000000..503a91b2f3 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_apply_row.py @@ -0,0 +1,161 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
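[Editor's note: a rough usage sketch for the activation op above, not taken from the patch; it assumes df is an already-materialized fate.arch.dataframe.DataFrame whose feature columns are numeric.]

    from fate.arch.dataframe.ops._activation import sigmoid

    # Only the operable (feature) blocks are transformed: int32/int64 blocks are
    # first promoted to float32, and id/label/weight blocks pass through untouched.
    scored = sigmoid(df)
    print(scored.data_manager.infer_operable_field_names())  # same feature names as df
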
+import functools +import pandas as pd +import torch + +from collections.abc import Iterable + +from .._dataframe import DataFrame +from ..manager.block_manager import Block, BlockType +from ..manager.data_manager import DataManager +from ..utils._auto_column_name_generated import generated_default_column_names + + +def apply_row(df: "DataFrame", func, + columns: list=None, with_label=False, with_weight=False, + enable_type_align_checking=True) -> "DataFrame": + """ + In current version, assume that the apply_row results' lengths are equal + """ + data_manager = df.data_manager + dst_data_manager, _ = data_manager.derive_new_data_manager(with_sample_id=True, + with_match_id=True, + with_label=not with_weight, + with_weight=not with_weight, + columns=None) + + non_operable_field_names = dst_data_manager.get_field_name_list() + non_operable_blocks = [data_manager.loc_block(field_name, + with_offset=False) for field_name in non_operable_field_names] + operable_blocks = data_manager.infer_operable_blocks() + is_numeric = True + for bid in operable_blocks: + if not data_manager.get_block(bid).is_numeric(): + is_numeric = False + break + block_column_in_orders = list() + if is_numeric: + for bid in operable_blocks: + field_indexes = data_manager.get_block(bid).field_indexes + block_column_in_orders.extend([data_manager.get_field_name(field_index) for field_index in field_indexes]) + + fields_loc = data_manager.get_fields_loc(with_sample_id=False, with_match_id=False, + with_label=with_label, with_weight=with_weight) + + fields_name = data_manager.get_field_name_list(with_sample_id=False, + with_match_id=False, + with_label=with_label, + with_weight=with_weight) + + _apply_func = functools.partial(_apply, func=func, src_operable_blocks=operable_blocks, src_field_names=fields_name, + src_fields_loc=fields_loc, src_non_operable_blocks=non_operable_blocks, + ret_columns=columns, dst_dm=dst_data_manager, is_numeric=is_numeric, + need_shuffle=True if block_column_in_orders == fields_name else False, + block_column_in_orders=block_column_in_orders, + enable_type_align_checking=enable_type_align_checking) + + dst_block_table_with_dm = df.block_table.mapValues(_apply_func) + + dst_data_manager = dst_block_table_with_dm.first()[1][1] + dst_block_table = dst_block_table_with_dm.mapValues(lambda blocks_with_dm: blocks_with_dm[0]) + + return DataFrame( + df._ctx, + dst_block_table, + df.partition_order_mappings, + dst_data_manager + ) + + +def _apply(blocks, func=None, src_operable_blocks=None, src_field_names=None, + src_fields_loc=None, src_non_operable_blocks=None, ret_columns=None, + dst_dm: "DataManager"=None, is_numeric=True, + block_column_in_orders=None, + need_shuffle=False, enable_type_align_checking=True): + dm = dst_dm.duplicate() + apply_blocks = [] + + if is_numeric: + apply_data = [] + for bid in src_operable_blocks: + apply_data.append(blocks[bid]) + apply_data = torch.hstack(apply_data) + apply_data = pd.DataFrame(apply_data, columns=block_column_in_orders) + if need_shuffle: + apply_data = apply_data[src_field_names] + else: + lines = len(blocks[0]) + flat_blocks = [Block.transform_block_to_list(block) for block in blocks] + apply_data = [[] for _ in range(lines)] + for bid, offset in src_fields_loc: + for lid in range(lines): + apply_data[lid].append(flat_blocks[bid][lid][offset]) + + apply_data = pd.DataFrame(apply_data, columns=src_field_names) + + apply_ret = apply_data.apply(lambda row: func(row), axis=1).values.tolist() + + if isinstance(apply_ret[0], Iterable): + first_row = 
list(apply_ret[0]) + ret_column_len = len(first_row) + block_types = [BlockType.np_object if BlockType.is_arr(value) else BlockType.get_block_type(value) for value in first_row] + apply_blocks = [[] for _ in range(ret_column_len)] + for ret in apply_ret: + for idx, value in enumerate(ret): + apply_blocks[idx].append([value]) + + if enable_type_align_checking: + block_type = BlockType.np_object if BlockType.is_arr(value) else BlockType.get_block_type(value) + if block_types[idx] < block_type: + block_types[idx] = block_type + else: + block_types = [BlockType.np_object if BlockType.is_arr(apply_ret[0]) else BlockType.get_block_type(apply_ret[0])] + apply_blocks.append([[ret] for ret in apply_ret]) + ret_column_len = 1 + + if enable_type_align_checking: + for ret in apply_ret: + block_type = BlockType.np_object if BlockType.is_arr(ret) else BlockType.get_block_type(ret) + if block_types[0] < block_type: + block_types[0] = block_type + + + if not ret_columns: + ret_columns = generated_default_column_names(ret_column_len) + + block_indexes = dm.append_columns( + ret_columns, block_types + ) + + ret_blocks = [[] for _ in range(len(src_non_operable_blocks) + ret_column_len)] + for idx, bid in enumerate(src_non_operable_blocks): + ret_blocks[idx] = blocks[bid] + + for idx, bid in enumerate(block_indexes): + if dm.blocks[bid].is_phe_tensor(): + single_value = apply_blocks[idx][0][0] + dm.blocks[bid].set_extra_kwargs(pk=single_value.pk, + evaluator=single_value.evaluator, + coder=single_value.coder, + dtype=single_value.dtype, + device=single_value.device) + ret = [v[0]._data for v in apply_blocks[idx]] + ret_blocks[bid] = dm.blocks[bid].convert_block(ret) + # ret_blocks[bid] = dm.blocks[bid].convert_to_phe_tensor(ret, shape=(len(ret), 1)) + else: + ret_blocks[bid] = dm.blocks[bid].convert_block(apply_blocks[idx]) + + return ret_blocks, dm diff --git a/python/fate/arch/dataframe/ops/_arithmetic.py b/python/fate/arch/dataframe/ops/_arithmetic.py index 6476621ef3..23c66e211f 100644 --- a/python/fate/arch/dataframe/ops/_arithmetic.py +++ b/python/fate/arch/dataframe/ops/_arithmetic.py @@ -14,18 +14,96 @@ # limitations under the License. 
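[Editor's note: an illustration of the row-wise API above, not part of the patch; df and the column names x0 and x1 are placeholders.]

    from fate.arch.dataframe.ops._apply_row import apply_row

    # Each row arrives as a pandas Series indexed by field name; returning an
    # iterable of length two produces two derived columns, named via `columns`.
    derived = apply_row(
        df,
        lambda row: [row["x0"] + row["x1"], row["x0"] * row["x1"]],
        columns=["sum_x", "prod_x"],
    )
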
import numpy as np import pandas as pd -import torch -from fate.arch import tensor +from fate.arch.computing import is_table +from .._dataframe import DataFrame +from ._promote_types import promote_types +from .utils.series_align import series_to_ndarray +from .utils.operators import binary_operate -def arith_method(lhs, rhs, op): - if isinstance(rhs, pd.Series): - rhs = tensor.tensor(torch.tensor(rhs.tolist(), dtype=getattr(torch, str(rhs.dtype)))) - elif isinstance(rhs, (int, float, np.int, np.int32, np.int64, np.float, np.float32, np.float64)): - pass - elif hasattr(rhs, "values"): - rhs = rhs.values +def arith_operate(lhs: DataFrame, rhs, op) -> "DataFrame": + data_manager = lhs.data_manager.duplicate() + block_indexes = data_manager.infer_operable_blocks() + column_names = data_manager.infer_operable_field_names() + + if isinstance(rhs, DataFrame): + rhs_column_names = rhs.data_manager.infer_operable_field_names() + if len(column_names) != len(rhs_column_names) or len(column_names) > 1: + raise ValueError(f"Operation={op} of two dataframe should have same column length=1") + + rhs_block_id = rhs.data_manager.infer_operable_blocks()[0] + block_table = _operate(lhs.block_table, rhs.block_table, op, block_indexes, rhs_block_id) + to_promote_blocks = data_manager.try_to_promote_types(block_indexes, + rhs.data_manager.get_block(rhs_block_id).block_type) + elif isinstance(rhs, (np.ndarray, list, pd.Series)): + if isinstance(rhs, pd.Series): + rhs = series_to_ndarray(rhs, column_names) + if isinstance(rhs, list): + rhs = np.array(rhs) + if len(rhs.shape) > 2: + raise ValueError("NdArray's Dimension should <= 2") + if len(column_names) != rhs.size: + raise ValueError(f"Size of List/NDArray should = {len(lhs.schema.columns)}") + rhs = rhs.reshape(-1) + field_indexes = [data_manager.get_field_offset(name) for name in column_names] + field_indexes_mappings = dict(zip(field_indexes, range(len(field_indexes)))) + rhs_blocks = [np.array([]) for i in range(data_manager.block_num)] + rhs_types = [] + for bid in block_indexes: + indexer = [field_indexes_mappings[field] for field in data_manager.get_block(bid).field_indexes] + rhs_blocks[bid] = rhs[indexer] + rhs_types.append(rhs_blocks[bid].dtype) + + block_table = binary_operate(lhs.block_table, rhs_blocks, op, block_indexes) + to_promote_blocks = data_manager.try_to_promote_types(block_indexes, rhs_types) + + elif isinstance(rhs, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool_)): + block_table = binary_operate(lhs.block_table, rhs, op, block_indexes) + to_promote_blocks = data_manager.try_to_promote_types(block_indexes, rhs) + else: + raise ValueError(f"Operation={op} between dataframe and {type(rhs)} is not implemented") + + if to_promote_blocks: + block_table, data_manager = promote_types(block_table, data_manager, to_promote_blocks) + + return type(lhs) ( + lhs._ctx, + block_table, + lhs.partition_order_mappings, + data_manager + ) + + +def _operate(lhs, rhs, op, block_indexes, rhs_block_id=None): + block_index_set = set(block_indexes) + if isinstance(rhs, list): + op_ret = lhs.mapValues( + lambda blocks: + [ + op(blocks[bid], rhs[bid]) if bid in block_index_set + else blocks[bid] + for bid in range(len(blocks)) + ] + ) + elif isinstance(rhs, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool_)): + op_ret = lhs.mapValues( + lambda blocks: + [ + op(blocks[bid], rhs) if bid in block_index_set + else blocks[bid] + for bid in range(len(blocks)) + ] + ) + elif is_table(rhs): + op_ret = lhs.join(rhs, + lambda 
blocks1, blocks2: + [ + op(blocks1[bid], blocks2[rhs_block_id]) if bid in block_index_set + else blocks1[bid] + for bid in range(len(blocks1)) + ] + ) else: - raise ValueError(f"{op.__name__} between DataFrame and {type(rhs)} is not supported") + raise ValueError(f"Not implement type between dataframe nad {type(rhs)}") - return op(lhs, rhs) + return op_ret diff --git a/python/fate/arch/dataframe/ops/_cmp.py b/python/fate/arch/dataframe/ops/_cmp.py new file mode 100644 index 0000000000..58450c200b --- /dev/null +++ b/python/fate/arch/dataframe/ops/_cmp.py @@ -0,0 +1,125 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pandas as pd +import torch +from .._dataframe import DataFrame +from .._dataframe import DataManager +from ..manager import BlockType +from .utils.operators import binary_operate +from .utils.series_align import series_to_ndarray +from ._compress_block import compress_blocks + + +def cmp_operate(lhs: DataFrame, rhs, op) -> "DataFrame": + data_manager = lhs.data_manager + block_indexes = data_manager.infer_operable_blocks() + column_names = data_manager.infer_operable_field_names() + + if isinstance(rhs, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool_)): + block_table = binary_operate(lhs.block_table, rhs, op, block_indexes) + + elif isinstance(rhs, (np.ndarray, list, pd.Series)): + if isinstance(rhs, pd.Series): + rhs = series_to_ndarray(rhs, column_names) + if isinstance(rhs, list): + rhs = np.array(rhs) + if len(rhs.shape) > 2: + raise ValueError("NdArray's Dimension should <= 2") + if len(column_names) != rhs.size: + raise ValueError(f"Size of List/NDArray/Series should = {len(lhs.schema.columns)}") + rhs = rhs.reshape(-1) + field_indexes = [data_manager.get_field_offset(name) for name in column_names] + field_indexes_mappings = dict(zip(field_indexes, range(len(field_indexes)))) + rhs_blocks = [np.array([]) for _ in range(data_manager.block_num)] + for bid in block_indexes: + indexer = [field_indexes_mappings[field] for field in data_manager.get_block(bid).field_indexes] + if BlockType.is_tensor(data_manager.get_block(bid).block_type): + rhs_blocks[bid] = torch.Tensor(rhs[indexer]) + else: + rhs_blocks[bid] = rhs[indexer] + + block_table = binary_operate(lhs.block_table, rhs_blocks, op, block_indexes) + + elif isinstance(rhs, DataFrame): + other_data_manager = rhs.data_manager + other_column_names = other_data_manager.infer_operable_field_names() + if set(column_names) != set(other_column_names): + raise ValueError("Comparison of two DataFrame should be identically-labeled") + lhs_block_loc = [data_manager.loc_block(name) for name in column_names] + rhs_block_loc = [other_data_manager.loc_block(name) for name in column_names] + field_indexes = [data_manager.get_field_offset(name) for name in column_names] + field_indexes_mappings = dict(zip(field_indexes, range(len(field_indexes)))) + indexers = [ + [field_indexes_mappings[field] for field in 
data_manager.get_block(bid).field_indexes] + for bid in block_indexes + ] + + block_table = _cmp_dfs(lhs.block_table, rhs.block_table, op, lhs_block_loc, rhs_block_loc, + block_indexes, indexers) + else: + raise ValueError(f"Not implement comparison of rhs type={type(rhs)}") + + block_table, data_manager = _merge_bool_blocks(block_table, data_manager, block_indexes) + return type(lhs)( + lhs._ctx, + block_table, + lhs.partition_order_mappings, + data_manager + ) + + +def _merge_bool_blocks(block_table, data_manager: DataManager, block_indexes): + """ + all blocks are bool type, they should be merge into one blocks + """ + dst_data_manager = data_manager.duplicate() + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, BlockType.bool)) + + dst_data_manager.promote_types(to_promote_types) + dst_block_table, dst_data_manager = compress_blocks(block_table, dst_data_manager) + + return dst_block_table, dst_data_manager + + +def _cmp_dfs(lhs_block_table, rhs_block_table, op, + lhs_block_loc, rhs_block_loc, + block_indexes, indexers): + + block_index_set = set(block_indexes) + + def _cmp_partition(l_blocks, r_blocks): + ret_blocks = [[] for i in range(l_blocks)] + for bid in range(len(l_blocks)): + if bid not in block_index_set: + ret_blocks[bid] = l_blocks[bid] + + for bid, indexer in zip(block_indexes, indexers): + cmp_ret = torch.empty(l_blocks[bid].shape) + for idx in indexer: + _, l_offset = lhs_block_loc[idx] + r_bid, r_offset = rhs_block_loc[idx] + cmp_ret[:, l_offset] = op(l_blocks[bid][l_offset], r_blocks[r_bid][r_offset]) + + ret_blocks[bid] = cmp_ret + + return ret_blocks + + block_table = lhs_block_table.join(rhs_block_table, + _cmp_partition) + + return block_table diff --git a/python/fate/arch/dataframe/ops/_compress_block.py b/python/fate/arch/dataframe/ops/_compress_block.py new file mode 100644 index 0000000000..c1f955c3eb --- /dev/null +++ b/python/fate/arch/dataframe/ops/_compress_block.py @@ -0,0 +1,60 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
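[Editor's note: a hedged sketch of the comparison path above, not taken from the patch; df is assumed to be an existing DataFrame, and torch.lt stands in for whatever operator the DataFrame wrapper passes down.]

    import torch
    from fate.arch.dataframe.ops._cmp import cmp_operate

    # Element-wise "df < 0.5" over the feature columns: every affected block
    # becomes a bool block, and _merge_bool_blocks then compresses them into a
    # single bool block.
    mask = cmp_operate(df, 0.5, torch.lt)
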
+# +import numpy as np +import torch +from ..manager import BlockType +from ..manager import DataManager +from ..conf.default_config import BLOCK_COMPRESS_THRESHOLD + + +def compress_blocks(block_table, data_manager: DataManager, force_compress=False): + compressed_data_manager = data_manager.duplicate() + to_compress_block_loc, non_compress_block_changes = compressed_data_manager.compress_blocks() + + compress_block_size = 0 + for _, block_loc in to_compress_block_loc: + compress_block_size += len(block_loc) + + if not to_compress_block_loc or (not force_compress and compress_block_size <= BLOCK_COMPRESS_THRESHOLD): + return block_table, data_manager + + def _compress(blocks): + ret_blocks = [[] for _ in range(compressed_data_manager.block_num)] + for src_bid, dst_bid in non_compress_block_changes.items(): + ret_blocks[dst_bid] = blocks[src_bid] + + lines = len(blocks[0]) + for dst_bid, block_loc in to_compress_block_loc: + block = compressed_data_manager.get_block(dst_bid) + field_len = len(block.field_indexes) + # TODO: empty block create logic should move to block_manager later, + # we pull it here as block_manager has more type like phe_tensor/pd.Index, which should not be considered in compressing + if BlockType.is_tensor(block.block_type): + block_buf = np.empty((lines, field_len), dtype=getattr(np, block.block_type.value)) + else: + block_buf = np.empty((lines, field_len), dtype=object) + + for src_bid, field_indexes in block_loc: + block_buf[:, field_indexes] = blocks[src_bid] + + if isinstance(block_buf, np.ndarray): + ret_blocks[dst_bid] = torch.from_numpy(block_buf) + + return ret_blocks + + block_table = block_table.mapValues(_compress) + + return block_table, compressed_data_manager diff --git a/python/fate/arch/tensor/types/_lstorage.py b/python/fate/arch/dataframe/ops/_data_overview.py similarity index 53% rename from python/fate/arch/tensor/types/_lstorage.py rename to python/fate/arch/dataframe/ops/_data_overview.py index 64446ec599..351511bfbd 100644 --- a/python/fate/arch/tensor/types/_lstorage.py +++ b/python/fate/arch/dataframe/ops/_data_overview.py @@ -12,24 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol - -from fate.arch.unify import device +# +from .._dataframe import DataFrame +from ..ops._transformer import transform_block_to_list -from ._dtype import dtype -from ._shape import Shape +def collect_data(df: DataFrame, num=100): + data = [] + fields_loc = df.data_manager.get_fields_loc() + for block_id, blocks in df.block_table.collect(): + data_list = transform_block_to_list(blocks, fields_loc) -class LStorage(Protocol): - device: device - dtype: dtype - shape: Shape + if len(data) + len(data_list) <= num: + data.extend(data_list) + else: + data.extend(data_list[:num]) - def tolist(self): - ... + num -= len(data_list) - def transpose(self) -> "LStorage": - ... + if num <= 0: + break - def to_local(self) -> "LStorage": - ... + return data diff --git a/python/fate/arch/dataframe/ops/_dimension_scaling.py b/python/fate/arch/dataframe/ops/_dimension_scaling.py new file mode 100644 index 0000000000..43dcbabe87 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_dimension_scaling.py @@ -0,0 +1,395 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import functools +from typing import List, Union +import torch +from sklearn.utils import resample +from .._dataframe import DataFrame +from ..manager.data_manager import DataManager +from ..manager.block_manager import Block +from ._compress_block import compress_blocks +from ._indexer import get_partition_order_by_raw_table, get_partition_order_mappings_by_block_table +from ._promote_types import promote_partial_block_types +from ._set_item import set_item +from fate.arch.tensor import DTensor + + +def hstack(data_frames: List["DataFrame"]) -> "DataFrame": + if len(data_frames) == 1: + return data_frames[0] + + l_df = DataFrame(data_frames[0]._ctx, + data_frames[0].block_table, + data_frames[0].partition_order_mappings, + data_frames[0].data_manager.duplicate()) + + column_set = set(l_df.schema.columns) + for r_df in data_frames[1:]: + other_column_set = set(r_df.schema.columns) + if column_set & other_column_set: + raise ValueError("Hstack does not support duplicate columns") + + set_item(l_df, r_df.schema.columns.tolist(), r_df, 1) + + column_set |= other_column_set + + data_manager = l_df.data_manager + block_table = l_df.block_table + + block_table, data_manager = compress_blocks(block_table, data_manager) + + l_df.block_table = block_table + l_df.data_manager = data_manager + + return l_df + + +def vstack(data_frames: List["DataFrame"]) -> "DataFrame": + frame_0 = data_frames[0] + data_frames = list(filter(lambda df: df.shape[0], data_frames)) + if len(data_frames) <= 1: + return frame_0 if not data_frames else data_frames[0] + + def _align_blocks(blocks, align_fields_loc=None, full_block_migrate_set=None, dst_dm: DataManager = None): + ret_blocks, lines = [], None + for dst_bid, block in enumerate(dst_dm.blocks): + _field_indexes = block.field_indexes + _src_bid = align_fields_loc[_field_indexes[0]][0] + if _src_bid in full_block_migrate_set: + ret_blocks.append(blocks[_src_bid]) + else: + _align_block = [] + lines = len(blocks[0]) if lines is None else lines + for lid in range(lines): + row = [] + for _field_index in _field_indexes: + _src_bid, _offset = align_fields_loc[_field_index] + row.append(blocks[_src_bid][lid][_offset].item() if isinstance(blocks[_src_bid], torch.Tensor) + else blocks[_src_bid][lid][_offset]) + + _align_block.append(row) + + ret_blocks.append(dst_dm.blocks[dst_bid].convert_block(_align_block)) + + return ret_blocks + + l_df = data_frames[0] + data_manager = l_df.data_manager.duplicate() + l_fields_loc = data_manager.get_fields_loc() + l_field_names = data_manager.get_field_name_list() + l_field_types = [data_manager.get_block(_bid).block_type for _bid, _ in l_fields_loc] + l_block_table = l_df.block_table + type_change = False + for r_df in data_frames[1:]: + if set(l_df.schema.columns) != set(r_df.schema.columns): + raise ValueError("vstack of dataframes should have same schemas") + + for idx, field_name in enumerate(l_field_names): + block_type = r_df.data_manager.get_block( + r_df.data_manager.loc_block(field_name, with_offset=False)).block_type + if block_type > l_field_types[idx]: + l_field_types[idx] = block_type + type_change 
= True + + if type_change: + changed_fields, changed_block_types, changed_fields_loc = [], [], [] + changed_block_types = [] + for idx in range(len(l_field_names)): + field_name, block_type, (bid, offset) = l_field_names[idx], l_field_types[idx], l_fields_loc[idx] + if block_type != data_manager.get_block(bid).block_type: + changed_fields.append(field_name) + changed_block_types.append(block_type) + changed_fields_loc.append((bid, offset)) + + narrow_blocks, dst_blocks = data_manager.split_columns(changed_fields, changed_block_types) + l_block_table = promote_partial_block_types(l_block_table, narrow_blocks=narrow_blocks, dst_blocks=dst_blocks, + data_manager=data_manager, dst_fields_loc=changed_fields_loc) + + for r_df in data_frames[1:]: + r_field_names = r_df.data_manager.get_field_name_list() + r_fields_loc = r_df.data_manager.get_fields_loc() + r_field_types = [data_manager.get_block(_bid).block_type for _bid, _ in r_fields_loc] + r_type_change = False if l_field_types != r_field_types else True + r_block_table = r_df.block_table + if l_field_names != r_field_names or r_type_change: + shuffle_r_fields_loc, full_migrate_set = [() for _ in range(len(r_field_names))], set() + for field_name, loc in zip(r_field_names, r_fields_loc): + l_offset = data_manager.get_field_offset(field_name) + shuffle_r_fields_loc[l_offset] = loc + + for bid in range(r_df.data_manager.block_num): + r_field_indexes = r_df.data_manager.get_block(bid).field_indexes + field_indexes = [data_manager.get_field_offset(r_field_names[idx]) for idx in r_field_indexes] + l_bid = data_manager.loc_block(r_field_names[r_field_indexes[0]], with_offset=False) + if field_indexes == data_manager.get_block(l_bid).field_indexes: + full_migrate_set.add(bid) + + _align_func = functools.partial(_align_blocks, align_fields_loc=shuffle_r_fields_loc, + full_block_migrate_set=full_migrate_set, dst_dm=data_manager) + r_block_table = r_block_table.mapValues(_align_func) + + l_block_table = l_block_table.union( + r_block_table, + lambda l_blocks, r_blocks: [ + Block.vstack([l_block, r_block]) for l_block, r_block in zip(l_blocks, r_blocks) + ] + ) + + partition_order_mappings = get_partition_order_mappings_by_block_table(l_block_table, data_manager.block_row_size) + _balance_block_func = functools.partial(_balance_blocks, + partition_order_mappings=partition_order_mappings, + block_row_size=data_manager.block_row_size) + l_block_table = l_block_table.mapPartitions(_balance_block_func, + use_previous_behavior=False) + l_block_table, data_manager = compress_blocks(l_block_table, data_manager) + + return DataFrame( + l_df._ctx, + l_block_table, + partition_order_mappings, + data_manager + ) + + +def drop(df: "DataFrame", index: "DataFrame" = None) -> "DataFrame": + if index.shape[0] == 0: + return DataFrame( + df._ctx, + block_table=df.block_table, + partition_order_mappings=copy.deepcopy(df.partition_order_mappings), + data_manager=df.data_manager.duplicate() + ) + + if index.shape[0] == df.shape[0]: + return df.empty_frame() + + data_manager = df.data_manager.duplicate() + l_flatten_func = functools.partial( + _flatten_partition, + block_num=data_manager.block_num + ) + l_flatten_table = df.block_table.mapPartitions(l_flatten_func, use_previous_behavior=False) + + r_flatten_func = functools.partial(_flatten_partition_without_value) + r_flatten_table = index.block_table.mapPartitions(r_flatten_func, use_previous_behavior=False) + + drop_flatten = l_flatten_table.subtractByKey(r_flatten_table) + partition_order_mappings = 
get_partition_order_by_raw_table( + drop_flatten, data_manager.block_row_size + ) if drop_flatten.count() else dict() + + _convert_to_block_func = functools.partial(to_blocks, + dm=data_manager, + partition_mappings=partition_order_mappings) + + block_table = drop_flatten.mapPartitions(_convert_to_block_func, + use_previous_behavior=False) + + return DataFrame( + df._ctx, + block_table, + partition_order_mappings, + data_manager + ) + + +def sample(df: "DataFrame", n=None, frac: float =None, random_state=None) -> "DataFrame": + """ + only support down sample, n should <= df.shape, or fact = 1 + """ + + if n is not None and frac is not None: + raise ValueError("sample's parameters n and frac should not be set in the same time.") + + if frac is not None: + if frac > 1: + raise ValueError(f"sample's parameter frac={frac} should <= 1.0") + n = max(1, int(df.shape[0] * frac)) + + if n > df.shape[0]: + raise ValueError(f"sample's parameter n={n} > data size={df.shape[0]}") + + if n == 0: + raise ValueError(f"sample's parameter n={n} should >= 1") + + indexer = list(df.get_indexer(target="sample_id").collect()) + sample_indexer = resample(indexer, replace=False, n_samples=n, random_state=random_state) + + sample_indexer = df._ctx.computing.parallelize(sample_indexer, + include_key=True, + partition=df.block_table.partitions) + + sample_frame = df.loc(sample_indexer) + + return sample_frame + + +def retrieval_row(df: "DataFrame", indexer: Union["DTensor", "DataFrame"]): + if isinstance(indexer, DTensor) and indexer.shape[1] != 1: + raise ValueError("Row indexing by DTensor should have only one column filling with True/False") + elif isinstance(indexer, DataFrame): + operable_field_len = len(indexer.data_manager.infer_operable_field_names()) + if operable_field_len != 1: + raise ValueError("Row indexing by DataFrame should have only one column filling with True/False") + + data_manager = df.data_manager.duplicate() + + def _block_counter(kvs, value_type="tensor"): + size = 0 + first_block_id = None + for k, value in kvs: + if first_block_id is None: + first_block_id = k + + if value_type == "tensor": + size += value.sum().item() + else: + size += len(value[0]) + + return first_block_id, size + + if isinstance(indexer, DataFrame): + _block_counter_func = functools.partial(_block_counter, value_type="dataframe") + block_info = sorted([summary[1] for summary in indexer.block_table.applyPartitions(_block_counter_func).collect()]) + else: + block_info = sorted([summary[1] for summary in indexer.shardings._data.applyPartitions(_block_counter).collect()]) + block_order_mappings = dict() + start_index = 0 + acc_block_num = 0 + for block_id, block_size in block_info: + block_num = (block_size + data_manager.block_row_size - 1) // data_manager.block_row_size + block_order_mappings[block_id] = dict( + start_index=start_index, + end_index=start_index + block_size - 1, + start_block_id=acc_block_num, + end_block_id=acc_block_num + block_num - 1 + ) + start_index += block_size + acc_block_num += block_num + + if start_index == 0: + return df.empty_frame() + + if isinstance(indexer, DataFrame): + bid = indexer.data_manager.infer_operable_blocks()[0] + block_table = df.block_table.join(indexer.block_table, lambda v1, v2: (v1, v2[bid])) + else: + block_table = df.block_table.join(indexer.shardings._data, lambda v1, v2: (v1, v2)) + + _balance_block_func = functools.partial(_balance_blocks_with_index, + partition_order_mappings=block_order_mappings, + data_manager=data_manager) + block_table = 
block_table.mapPartitions(_balance_block_func, + use_previous_behavior=False) + block_table, data_manager = compress_blocks(block_table, data_manager) + partition_order_mappings = get_partition_order_mappings_by_block_table(block_table, data_manager.block_row_size) + + return DataFrame( + df._ctx, + block_table, + partition_order_mappings, + data_manager + ) + + +def _flatten_partition_without_value(kvs): + for block_id, blocks in kvs: + for sample_id in blocks[0]: + yield sample_id, [] + + +def _flatten_partition(kvs, block_num=0): + for block_id, blocks in kvs: + flat_blocks = [Block.transform_block_to_list(block) for block in blocks] + lines = len(flat_blocks[0]) + for i in range(lines): + yield flat_blocks[0][i], [flat_blocks[j][i] for j in range(1, block_num)] + + +def _balance_blocks(kvs, partition_order_mappings: dict=None, block_row_size: int=None): + block_id = None + previous_blocks = list() + for _, blocks in kvs: + if block_id is None and len(blocks[0]): + sample_id = blocks[0][0] + block_id = partition_order_mappings[sample_id]["start_block_id"] + + if previous_blocks: + blocks = [Block.vstack([pre_block, block]) for pre_block, block in zip(previous_blocks, blocks)] + previous_blocks = list() + + row_size = len(blocks[0]) + for i in range(0, row_size, block_row_size): + if row_size - i < block_row_size: + previous_blocks = [block[i: row_size] for block in blocks] + else: + yield block_id, [block[i: i + block_row_size] for block in blocks] + block_id += 1 + + if previous_blocks: + yield block_id, previous_blocks + + if block_id is None: + return [] + + +def _balance_blocks_with_index(kvs, partition_order_mappings: dict=None, data_manager: DataManager=None): + block_id = None + block_num = data_manager.block_num + ret_blocks = [[] for _ in range(block_num)] + block_size = 0 + for _, (blocks, t) in kvs: + if block_id is None: + block_id = partition_order_mappings[_]["start_block_id"] + + flat_blocks = [Block.transform_block_to_list(block) for block in blocks] + for i, v in enumerate(t): + v = v.item() + if not v: + continue + + block_size += 1 + for j in range(block_num): + ret_blocks[j].append(flat_blocks[j][i]) + + if block_size == data_manager.block_row_size: + yield block_id, data_manager.convert_to_blocks(ret_blocks) + block_size = 0 + block_id += 1 + ret_blocks = [[] for _ in range(block_num)] + + if block_size: + yield block_id, data_manager.convert_to_blocks(ret_blocks) + + +def to_blocks(kvs, dm: DataManager = None, partition_mappings: dict = None): + ret_blocks = [[] for _ in range(dm.block_num)] + + block_id = None + for lid, (sample_id, value) in enumerate(kvs): + if block_id is None: + block_id = partition_mappings[sample_id]["start_block_id"] + ret_blocks[0].append(sample_id) + for bid, buf in enumerate(value): + ret_blocks[bid + 1].append(buf) + + if (lid + 1) % dm.block_row_size == 0: + yield block_id, dm.convert_to_blocks(ret_blocks) + ret_blocks = [[] for _ in range(dm.block_num)] + block_id += 1 + + if ret_blocks[0]: + yield block_id, dm.convert_to_blocks(ret_blocks) diff --git a/python/fate/arch/dataframe/ops/_encoder.py b/python/fate/arch/dataframe/ops/_encoder.py new file mode 100644 index 0000000000..044c2163b7 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_encoder.py @@ -0,0 +1,247 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools + +import pandas as pd +import numpy as np +import torch +from sklearn.preprocessing import OneHotEncoder +from typing import Union +from ._compress_block import compress_blocks +from .._dataframe import DataFrame +from ..manager import BlockType, DataManager + + +BUCKETIZE_RESULT_TYPE = "int32" + + +def get_dummies(df: "DataFrame", dtype="int32"): + data_manager = df.data_manager + block_indexes = data_manager.infer_operable_blocks() + field_names = data_manager.infer_operable_field_names() + + if len(field_names) != 1: + raise ValueError(f"get_dummies only support single column, but {len(field_names)} columns are found.") + + categories = _get_categories(df.block_table, block_indexes)[0][0] + dst_field_names = ["_".join(map(str, [field_names[0], c])) for c in categories] + dst_data_manager = data_manager.duplicate() + dst_data_manager.pop_blocks(block_indexes) + dst_data_manager.append_columns(dst_field_names, block_types=BlockType.get_block_type(dtype)) + + block_table = _one_hot_encode(df.block_table, block_indexes, dst_data_manager, [[categories]], dtype=dtype) + + return DataFrame( + df._ctx, + block_table, + partition_order_mappings=df.partition_order_mappings, + data_manager=dst_data_manager + ) + + +def _get_categories(block_table, block_indexes): + block_index_set = set(block_indexes) + + def _mapper(blocks): + categories_ = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + continue + + enc = OneHotEncoder() + cate_block = enc.fit(block).categories_ + categories_.append([set(cate) for cate in cate_block]) + + return categories_ + + def _reducer(categories1_, categories2_): + categories_ = [] + for cate_block1, cate_block2 in zip(categories1_, categories2_): + cate_block = [cate1 | cate2 for cate1, cate2 in zip(cate_block1, cate_block2)] + categories_.append(cate_block) + + return categories_ + + categories = block_table.mapValues(_mapper).reduce(_reducer) + + categories = [[sorted(cate) for cate in cate_block] for cate_block in categories] + + return categories + + +def _one_hot_encode(block_table, block_indexes, data_manager, categories, dtype): + categories = [np.array(category) for category in categories] + block_index_set = set(block_indexes) + + def _encode(blocks): + ret_blocks = [] + enc_blocks = [] + idx = 0 + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + continue + + enc = OneHotEncoder(dtype=dtype) + enc.fit([[1]]) # one hot encoder need to fit first. 
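+            # fit on a single placeholder value only to mark the encoder as fitted;
+            # the real categories, aggregated across all partitions by _get_categories,
+            # are assigned to `categories_` on the next line before transform is called.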
+ enc.categories_ = categories[idx] + idx += 1 + enc_blocks.append(enc.transform(block).toarray()) + + ret_blocks.append(data_manager.blocks[-1].convert_block(np.hstack(enc_blocks))) + + return ret_blocks + + return block_table.mapValues(_encode) + + +def bucketize(df: DataFrame, boundaries: Union[pd.DataFrame, dict]): + if isinstance(boundaries, pd.DataFrame): + boundaries = dict([(_name, boundaries[_name].tolist()) for _name in boundaries]) + elif not isinstance(boundaries, dict): + raise ValueError("boundaries should be pd.DataFrame or dict") + + data_manager = df.data_manager.duplicate() + field_names = list(filter(lambda field_name: field_name in boundaries, data_manager.infer_operable_field_names())) + blocks_loc = data_manager.loc_block(field_names) + + _boundaries_list = [] + for name, (_bid, _) in zip(field_names, blocks_loc): + if BlockType.is_tensor(data_manager.blocks[_bid].block_type): + _boundary = torch.tensor(boundaries[name]) + _boundary[-1] = torch.inf + else: + _boundary = np.array(boundaries[name]) + _boundary[-1] = np.inf + + _boundaries_list.append((_bid, _, _boundary)) + + narrow_blocks, dst_blocks = data_manager.split_columns( + field_names, BlockType.get_block_type(BUCKETIZE_RESULT_TYPE) + ) + + def _mapper( + blocks, boundaries_list: list = None, narrow_loc: list = None, dst_bids: list = None, dm: DataManager = None + ): + ret_blocks = [] + for block in blocks: + if isinstance(block, torch.Tensor): + ret_blocks.append(block.clone()) + elif isinstance(block, np.ndarray): + ret_blocks.append(block.copy()) + else: + ret_blocks.append(block) + + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (src_bid, src_offset, boundary) in zip(dst_bids, boundaries_list): + if isinstance(blocks[src_bid], torch.Tensor): + ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary, out_int32=False) + else: + ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary) + + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(ret) + + return ret_blocks + + bucketize_mapper = functools.partial( + _mapper, boundaries_list=_boundaries_list, narrow_loc=narrow_blocks, dst_bids=dst_blocks, dm=data_manager + ) + + block_table = df.block_table.mapValues(bucketize_mapper) + + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) > 1: + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, BlockType.get_block_type(BUCKETIZE_RESULT_TYPE))) + + data_manager.promote_types(to_promote_types) + block_table, data_manager = compress_blocks(block_table, data_manager) + + return DataFrame( + df._ctx, block_table, partition_order_mappings=df.partition_order_mappings, data_manager=data_manager + ) + + +def bucketize(df: DataFrame, boundaries: Union[pd.DataFrame, dict]): + if isinstance(boundaries, pd.DataFrame): + boundaries = dict([(_name, boundaries[_name].tolist()) for _name in boundaries]) + elif not isinstance(boundaries, dict): + raise ValueError("boundaries should be pd.DataFrame or dict") + + data_manager = df.data_manager.duplicate() + field_names = list(filter(lambda field_name: field_name in boundaries, data_manager.infer_operable_field_names())) + blocks_loc = data_manager.loc_block(field_names) + + _boundaries_list = [] + for name, (_bid, _) in zip(field_names, blocks_loc): + if BlockType.is_tensor(data_manager.blocks[_bid].block_type): + _boundary = torch.tensor(boundaries[name], dtype=torch.float64) + 
_boundary[-1] = torch.inf + else: + _boundary = np.array(boundaries[name], dtype=np.float64) + _boundary[-1] = np.inf + + _boundaries_list.append((_bid, _, _boundary)) + + narrow_blocks, dst_blocks = data_manager.split_columns( + field_names, BlockType.get_block_type(BUCKETIZE_RESULT_TYPE) + ) + + def _mapper( + blocks, boundaries_list: list = None, narrow_loc: list = None, dst_bids: list = None, dm: DataManager = None + ): + ret_blocks = [block for block in blocks] + + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (src_bid, src_offset, boundary) in zip(dst_bids, boundaries_list): + if isinstance(blocks[src_bid], torch.Tensor): + ret = torch.bucketize(blocks[src_bid][:, [src_offset]], boundary, out_int32=False) + else: + ret = np.digitize(blocks[src_bid][:, [src_offset]], boundary) + + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(ret) + + return ret_blocks + + bucketize_mapper = functools.partial( + _mapper, boundaries_list=_boundaries_list, narrow_loc=narrow_blocks, dst_bids=dst_blocks, dm=data_manager + ) + + block_table = df.block_table.mapValues(bucketize_mapper) + + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) > 1: + to_promote_types = [] + for _bid in block_indexes: + to_promote_types.append((_bid, data_manager.get_block(_bid).block_type)) + + data_manager.promote_types(to_promote_types) + block_table, data_manager = compress_blocks(block_table, data_manager) + + return DataFrame( + df._ctx, block_table, partition_order_mappings=df.partition_order_mappings, data_manager=data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_field_extract.py b/python/fate/arch/dataframe/ops/_field_extract.py new file mode 100644 index 0000000000..af83c6a4fa --- /dev/null +++ b/python/fate/arch/dataframe/ops/_field_extract.py @@ -0,0 +1,49 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
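+#
+# Usage sketch (illustrative only; `df` is assumed to be an existing DataFrame
+# whose schema contains feature columns "x0" and "x1"):
+#
+#     sub_df = field_extract(df, columns=["x0", "x1"])
+#     # keep the listed feature columns but drop label and weight entirely
+#     features_only = field_extract(df, with_label=False, with_weight=False,
+#                                    columns=["x0", "x1"])
+#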
+from .._dataframe import DataFrame + + +def field_extract(df: "DataFrame", with_sample_id=True, with_match_id=True, with_weight=True, + with_label=True, columns=None): + """ + blocks_loc: list, each element: (src_block_id, dst_block_id, changed=True/False, block_indexes) + """ + def _extract_columns(src_blocks): + extract_blocks = [None] * len(blocks_loc) + + for src_block_id, dst_block_id, is_changed, block_column_indexes in blocks_loc: + block = src_blocks[src_block_id] + if is_changed: + extract_blocks[dst_block_id] = block[:, block_column_indexes] + else: + extract_blocks[dst_block_id] = block + + return extract_blocks + + data_manager, blocks_loc = df.data_manager.derive_new_data_manager( + with_sample_id=with_sample_id, + with_match_id=with_match_id, + with_label=with_label, + with_weight=with_weight, + columns=columns + ) + extract_table = df.block_table.mapValues(_extract_columns) + + return DataFrame( + df._ctx, + extract_table, + partition_order_mappings=df.partition_order_mappings, + data_manager=data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_fillna.py b/python/fate/arch/dataframe/ops/_fillna.py new file mode 100644 index 0000000000..af17490dd4 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_fillna.py @@ -0,0 +1,90 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
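+#
+# Usage sketch (illustrative only; `df` is assumed to be an existing DataFrame
+# with numeric feature columns "x0" and "x1"):
+#
+#     filled = fillna(df, 0)                         # scalar fill for every operable column
+#     filled = fillna(df, {"x0": 0.0, "x1": -1})     # per-column fill values via dict
+#     # a pandas Series keyed by column name is also accepted and converted internally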
+# +import numpy as np +import pandas as pd +import torch +from .._dataframe import DataFrame + + +def fillna(df: DataFrame, value, downcast=None): + data_manager = df.data_manager + block_indexes = data_manager.infer_operable_blocks() + if isinstance(value, (int, float, np.int32, np.int64, np.float32, np.float64)): + block_table = _fillna(df.block_table, value, block_indexes) + elif isinstance(value, (list, pd.Series, dict)): + if isinstance(value, list): + column_names = data_manager.infer_operable_field_names() + if len(value) != column_names: + raise ValueError("fillna's list length should have identical column shape") + value = dict(zip(column_names, value)) + elif isinstance(value, pd.Series): + value = value.to_dict() + value_indexers = dict() + for column_name, fill_value in value.items(): + bid, offset = data_manager.loc_block(column_name) + if bid not in value_indexers: + value_indexers[bid] = dict() + value_indexers[bid][offset] = fill_value + + block_table = _fillna(df.block_table, value_indexers, block_indexes) + + else: + raise ValueError(f"Not support value type={type(value)}") + + return DataFrame( + df._ctx, + block_table, + df.partition_order_mappings, + data_manager.duplicate() + ) + + +def _fillna(block_table, value, block_indexes): + block_index_set = set(block_indexes) + if isinstance(value, (int, float, np.int32, np.int64, np.float32, np.float64)): + def _fill(blocks): + ret_blocks = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + elif isinstance(block, torch.Tensor): + ret_blocks.append(torch.nan_to_num(block, value)) + elif isinstance(block, np.ndarray): + ret_blocks.append(np.nan_to_num(block, value)) + + return ret_blocks + + return block_table.mapValues(_fill) + else: + def _fill_with_dict(blocks): + ret_blocks = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + elif isinstance(block, torch.Tensor): + block = block.clone() + for offset, fill_value in value.get(bid, {}).items(): + block[:, offset] = torch.nan_to_num(block[:, offset], fill_value) + ret_blocks.append(block) + elif isinstance(block, np.ndarray): + block = block.copy() + for offset, fill_value in value.get(bid, {}).items(): + block[:, offset] = np.nan_to_num(block[:, offset], fill_value) + ret_blocks.append(block) + + return ret_blocks + + return block_table.mapValues(_fill_with_dict) diff --git a/python/fate/arch/dataframe/ops/_histogram.py b/python/fate/arch/dataframe/ops/_histogram.py new file mode 100644 index 0000000000..2793d7e46e --- /dev/null +++ b/python/fate/arch/dataframe/ops/_histogram.py @@ -0,0 +1,103 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
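+#
+# Usage sketch (illustrative only; the builder below is a placeholder -- its real
+# construction parameters live in fate.arch.histogram.HistogramBuilder):
+#
+#     # `df` holds non-negative integer bin indexes per feature, `position` holds a
+#     # single integer node-position column, and `targets` maps names (e.g. "g"/"h"
+#     # for gradients/hessians) to per-sample values aligned with `df`.
+#     hist = distributed_hist_stat(df, hist_builder, position, targets={"g": g, "h": h})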
+# +import functools +import typing +from typing import Union + + +from .._dataframe import DataFrame +from ..manager import BlockType, DataManager +from ._compress_block import compress_blocks + +if typing.TYPE_CHECKING: + from fate.arch.histogram import DistributedHistogram, HistogramBuilder + + +def distributed_hist_stat(df: DataFrame, histogram_builder: "HistogramBuilder", position: DataFrame, targets: Union[dict, DataFrame]) -> "DistributedHistogram": + block_table, data_manager = _try_to_compress_table(df.block_table, df.data_manager, force_compress=True) + data_block_id = data_manager.infer_operable_blocks()[0] + position_block_id = position.data_manager.infer_operable_blocks()[0] + + if isinstance(targets, dict): + def _pack_data_with_position(l_blocks, r_blocks, l_block_id=None, r_block_id=None): + return l_blocks[l_block_id], r_blocks[r_block_id], dict() + + def _pack_with_target(l_values, r_value, target_name): + l_values[2][target_name] = r_value + + return l_values + + _pack_func = functools.partial(_pack_data_with_position, + l_block_id=data_block_id, + r_block_id=position_block_id) + + data_with_position = block_table.join(position.block_table, _pack_func) + + for name, target in targets.items(): + _pack_with_target_func = functools.partial(_pack_with_target, target_name=name) + data_with_position = data_with_position.join(target.shardings._data, _pack_with_target_func) + else: + data_with_position = block_table.join( + position.block_table, + lambda l_blocks, r_blocks: (l_blocks[data_block_id], r_blocks[position_block_id]) + ) + + target_data_manager = targets.data_manager + target_field_names = target_data_manager.infer_operable_field_names() + fields_loc = target_data_manager.loc_block(target_field_names, with_offset=True) + + def _pack_with_targets(l_blocks, r_blocks): + target_blocks = dict() + for field_name, (block_id, offset) in zip(target_field_names, fields_loc): + if (block := target_data_manager.get_block(block_id)).is_phe_tensor(): + target_blocks[field_name] = block.convert_to_phe_tensor( + r_blocks[block_id], + shape=(len(r_blocks[0]), 1) + ) + else: + target_blocks[field_name] = r_blocks[block_id][:, [offset]] + + return l_blocks[0], l_blocks[1], target_blocks + + data_with_position = data_with_position.join(targets.block_table, _pack_with_targets) + + return histogram_builder.statistic(data_with_position) + + +def _try_to_compress_table(block_table, data_manager: DataManager, force_compress=False): + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) == 1: + return block_table, data_manager + + block_type = None + for block_id in block_indexes: + _type = data_manager.get_block(block_id).block_type + if not BlockType.is_integer(_type): + raise ValueError("To use hist interface, indexes type should be integer >= 0") + + if not block_type: + block_type = _type + elif block_type < _type: + block_type = _type + + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, block_type)) + + data_manager.promote_types(to_promote_types) + block_table, data_manager = compress_blocks(block_table, data_manager, force_compress=force_compress) + + return block_table, data_manager diff --git a/python/fate/arch/dataframe/ops/_indexer.py b/python/fate/arch/dataframe/ops/_indexer.py new file mode 100644 index 0000000000..bc41d2a1ba --- /dev/null +++ b/python/fate/arch/dataframe/ops/_indexer.py @@ -0,0 +1,390 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +import uuid + +from ..manager import Block, DataManager +from .._dataframe import DataFrame + + +def transform_to_table(block_table, block_index, partition_order_mappings): + def _convert_to_order_index(kvs): + for block_id, blocks in kvs: + for _idx, _id in enumerate(blocks[block_index]): + yield _id, (block_id, _idx) + + return block_table.mapPartitions(_convert_to_order_index, + use_previous_behavior=False) + + +def get_partition_order_mappings_by_block_table(block_table, block_row_size): + def _block_counter(kvs): + partition_key = None + size = 0 + first_block_id = '' + + for k, v in kvs: + if size == 0 and len(v[0]): + partition_key = v[0][0] + + if first_block_id == '': + first_block_id = k + + size += len(v[0]) + + if size == 0: + partition_key = str(first_block_id) + "_" + str(uuid.uuid1()) + + return partition_key, size + + block_info = sorted([summary[1] for summary in block_table.applyPartitions(_block_counter).collect()]) + block_order_mappings = dict() + start_index = 0 + acc_block_num = 0 + for block_key, block_size in block_info: + block_num = (block_size + block_row_size - 1) // block_row_size + block_order_mappings[block_key] = dict( + start_index=start_index, + end_index=start_index + block_size - 1, + start_block_id=acc_block_num, + end_block_id=acc_block_num + block_num - 1 + ) + start_index += block_size + acc_block_num += block_num + + return block_order_mappings + + +def get_partition_order_by_raw_table(table, block_row_size, key_type="sample_id"): + def _get_block_summary(kvs): + try: + if key_type == "sample_id": + key = next(kvs)[0] + else: + key = next(kvs)[1][0] + + block_size = 1 + sum(1 for kv in kvs) + except StopIteration: + key, block_size = 0, 0 + + return {key: block_size} + + block_summary = table.applyPartitions(_get_block_summary).reduce(lambda blk1, blk2: {**blk1, **blk2}) + + start_index, acc_block_num = 0, 0 + block_order_mappings = dict() + + if not block_summary: + return block_order_mappings + + for blk_key, blk_size in sorted(block_summary.items()): + block_num = (blk_size + block_row_size - 1) // block_row_size + block_order_mappings[blk_key] = dict( + start_index=start_index, + end_index=start_index + blk_size - 1, + start_block_id=acc_block_num, + end_block_id=acc_block_num + block_num - 1 + ) + + start_index += blk_size + acc_block_num += block_num + + return block_order_mappings + + +def regenerated_sample_id(block_table, regenerated_sample_id_table, data_manager): + """ + regenerated_sample_id_table: (sample_id, ([new_id_list])) + """ + from ._dimension_scaling import _flatten_partition + + _flatten_func = functools.partial(_flatten_partition, block_num=data_manager.block_num) + raw_table = block_table.mapPartitions(_flatten_func, use_previous_behavior=False) + regenerated_table = raw_table.join(regenerated_sample_id_table, lambda lhs, rhs: (lhs, rhs)) + + def _flat_id(key, value): + content, id_list = value + flat_ret = [] + for _id in id_list: + flat_ret.append((_id, 
content)) + + return flat_ret + + _flat_id_func = functools.partial(_flat_id) + regenerated_table = regenerated_table.flatMap(_flat_id_func) + + return regenerated_table + + +def _merge_list(lhs, rhs): + if not lhs: + return rhs + if not rhs: + return lhs + + l_len = len(lhs) + r_len = len(rhs) + ret = [None] * (l_len + r_len) + i, j, k = 0, 0, 0 + while i < l_len and j < r_len: + if lhs[i][0] < rhs[j][0]: + ret[k] = lhs[i] + i += 1 + else: + ret[k] = rhs[j] + j += 1 + + k += 1 + + while i < l_len: + ret[k] = lhs[i] + i += 1 + k += 1 + + while j < r_len: + ret[k] = rhs[j] + j += 1 + k += 1 + + return ret + + +def flatten_data(df: DataFrame, key_type="block_id", with_sample_id=True): + """ + key_type="block_id": + key=(block_id, block_offset), value=data_row + key_type="sample_id": + key=sample_id, value=data_row + """ + sample_id_index = df.data_manager.loc_block( + df.data_manager.schema.sample_id_name, with_offset=False + ) if (with_sample_id or key_type == "sample_id") else None + + def _flatten(kvs): + for block_id, blocks in kvs: + flat_blocks = [Block.transform_block_to_list(block) for block in blocks] + block_num = len(flat_blocks) + if key_type == "block_id": + for row_id in range(len(blocks[0])): + if with_sample_id: + yield (block_id, row_id), ( + flat_blocks[sample_id_index][row_id], + [flat_blocks[i][row_id] for i in range(block_num)] + ) + else: + yield (block_id, row_id), [flat_blocks[i][row_id] for i in range(block_num)] + else: + for row_id in range(len(blocks[0])): + yield flat_blocks[sample_id_index][row_id], [flat_blocks[i][row_id] for i in range(block_num)] + + if key_type in ["block_id", "sample_id"]: + return df.block_table.mapPartitions(_flatten, use_previous_behavior=False) + else: + raise ValueError(f"Not Implement key_type={key_type} of flatten_data.") + + +def transform_flatten_data_to_df(ctx, flatten_table, data_manager: DataManager, key_type, value_with_sample_id=True): + partition_order_mappings = get_partition_order_by_raw_table(flatten_table, + data_manager.block_row_size, + key_type=key_type) + block_num = data_manager.block_num + + def _convert_to_blocks(kvs): + bid = None + ret_blocks = [[] for _ in range(block_num)] + + lid = 0 + for _, value in kvs: + if value_with_sample_id: + data = value[1] + else: + data = value + lid += 1 + if bid is None: + sample_id = data[0] + bid = partition_order_mappings[sample_id]["start_block_id"] + + for i in range(block_num): + ret_blocks[i].append(data[i]) + + if lid % data_manager.block_row_size == 0: + ret_blocks = [data_manager.blocks[i].convert_block(block) for i, block in enumerate(ret_blocks)] + yield bid, ret_blocks + bid += 1 + ret_blocks = [[] for _ in range(block_num)] + + if lid % data_manager.block_row_size: + ret_blocks = [data_manager.blocks[i].convert_block(block) for i, block in enumerate(ret_blocks)] + yield bid, ret_blocks + + block_table = flatten_table.mapPartitions(_convert_to_blocks, use_previous_behavior=False) + + return DataFrame( + ctx=ctx, + block_table=block_table, + partition_order_mappings=partition_order_mappings, + data_manager=data_manager.duplicate() + ) + + +def loc(df: DataFrame, indexer, target="sample_id", preserve_order=False): + """ + indexer: table, key=sample_id, value=(block_id, block_offset) + """ + if target != "sample_id": + raise ValueError(f"Only target=sample_id is supported, but target={target} is found") + flatten_table = flatten_data(df, key_type="sample_id") + if not preserve_order: + flatten_table = flatten_table.join(indexer, lambda v1, v2: v1) + if not 
flatten_table.count(): + return df.empty_frame() + return transform_flatten_data_to_df(df._ctx, flatten_table, df.data_manager, + key_type="sample_id", value_with_sample_id=False) + else: + flatten_table_with_dst_indexer = flatten_table.join(indexer, lambda v1, v2: (v2[0], (v2[1], v1))) + if not flatten_table_with_dst_indexer.count(): + return df.empty_frame() + + def _aggregate(kvs): + values = [value for key, value in kvs] + values.sort() + i = 0 + l = len(values) + while i < l: + j = i + 1 + while j < l and values[j][0] == values[i][0]: + j += 1 + + yield values[i][0], [values[k][1] for k in range(i, j)] + + i = j + + data_manager = df.data_manager.duplicate() + block_num = data_manager.block_num + + def _to_blocks(values): + block_size = len(values) + ret_blocks = [[None] * block_size for _ in range(block_num)] + + for row_id, row_data in values: + for j in range(block_num): + ret_blocks[j][row_id] = row_data[j] + + for idx, block_schema in enumerate(data_manager.blocks): + ret_blocks[idx] = block_schema.convert_block(ret_blocks[idx]) + + return ret_blocks + + agg_data = flatten_table_with_dst_indexer.mapReducePartitions(_aggregate, lambda v1, v2: v1 + v2) + block_table = agg_data.mapValues(_to_blocks) + + partition_order_mappings = get_partition_order_mappings_by_block_table( + block_table, + block_row_size=data_manager.block_row_size + ) + + return DataFrame( + df._ctx, + block_table=block_table, + partition_order_mappings=partition_order_mappings, + data_manager=data_manager.duplicate() + ) + + +def loc_with_sample_id_replacement(df: DataFrame, indexer): + """ + indexer: table, + row: (key=random_key, + value=(sample_id, (src_block_id, src_offset)) + """ + if indexer.count() == 0: + return df.empty_frame() + + data_manager = df.data_manager + partition_order_mappings = get_partition_order_by_raw_table(indexer, + data_manager.block_row_size, + key_type="block_id") + + def _aggregate(kvs): + bid, offset = None, 0 + flat_ret = [] + for k, values in kvs: + sample_id, (src_block_id, src_offset) = values + if bid is None: + bid = partition_order_mappings[sample_id]["start_block_id"] + + flat_ret.append((src_block_id, src_offset, sample_id, bid, offset)) + + offset += 1 + if offset == data_manager.block_row_size: + offset = 0 + bid += 1 + + flat_ret.sort() + i = 0 + l = len(flat_ret) + while i < l: + j = i + while j < l and flat_ret[i][0] == flat_ret[j][0]: + j += 1 + + agg_ret = [flat_ret[k][1:] for k in range(i, j)] + yield flat_ret[i][0], agg_ret + + i = j + + sample_id_index = data_manager.loc_block(data_manager.schema.sample_id_name, with_offset=False) + block_num = data_manager.block_num + + def _convert_to_row(kvs): + ret_dict = {} + for block_id, (blocks, block_indexer) in kvs: + flat_blocks = [Block.transform_block_to_list(block) for block in blocks] + for src_row_id, sample_id, dst_block_id, dst_row_id in block_indexer: + if dst_block_id not in ret_dict: + ret_dict[dst_block_id] = [] + + row_data = [flat_blocks[i][src_row_id] for i in range(block_num)] + row_data[sample_id_index] = sample_id + + ret_dict[dst_block_id].append( + (dst_row_id, row_data) + ) + + for dst_block_id, value_list in ret_dict.items(): + yield dst_block_id, sorted(value_list) + + agg_indexer = indexer.mapReducePartitions(_aggregate, lambda l1, l2: l1 + l2) + block_table = df.block_table.join(agg_indexer, lambda v1, v2: (v1, v2)) + block_table = block_table.mapReducePartitions(_convert_to_row, _merge_list) + + def _convert_to_frame_block(blocks, data_manager): + convert_blocks = [] + for idx, block_schema in 
enumerate(data_manager.blocks): + block_content = [row_data[1][idx] for row_data in blocks] + convert_blocks.append(block_schema.convert_block(block_content)) + + return convert_blocks + + _convert_to_frame_block_func = functools.partial(_convert_to_frame_block, data_manager=data_manager) + block_table = block_table.mapValues(_convert_to_frame_block_func) + + return DataFrame( + ctx=df._ctx, + block_table=block_table, + partition_order_mappings=partition_order_mappings, + data_manager=data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_isin.py b/python/fate/arch/dataframe/ops/_isin.py new file mode 100644 index 0000000000..5e9cdae2d9 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_isin.py @@ -0,0 +1,115 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pandas as pd +import torch +from ._compress_block import compress_blocks +from .._dataframe import DataFrame + + +def isin(df: DataFrame, values): + """ + support types are: scalar、list、series、dict + note: torch.isin and np.isin does not support nan, so use torch.isnan and np.isnan. + value type of set and list are not same, e.g.: {1.0}/[1.0] may lead to different result, + so does ont change to set now + """ + if isinstance(values, (list, dict, pd.Series)): + data_manager = df.data_manager + block_indexes = data_manager.infer_operable_blocks() + if isinstance(values, pd.Series): + values = values.to_dict() + if isinstance(values, dict): + value_indexers = dict() + for column_name, in_value in values.items(): + bid, offset = data_manager.loc_block(column_name) + if bid not in value_indexers: + value_indexers[bid] = dict() + value_indexers[bid][offset] = in_value + block_table = _isin(df.block_table, value_indexers, block_indexes) + else: + block_table = _isin(df.block_table, values, block_indexes) + else: + raise ValueError(f"isin only support type in [list, dict, pandas.Series], but {type(values)} was found") + + dst_data_manager = data_manager.duplicate() + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, torch.bool)) + + dst_data_manager.promote_types(to_promote_types) + dst_block_table, dst_data_manager = compress_blocks(block_table, dst_data_manager) + + return type(df) ( + df._ctx, + dst_block_table, + df.partition_order_mappings, + dst_data_manager + ) + + +def _isin(block_table, values, block_indexes): + block_index_set = set(block_indexes) + def _has_nan_value(v_list): + for v in v_list: + if np.isnan(v): + return True + + return False + + if isinstance(values, list): + def _is_in_list(blocks): + ret_blocks = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + elif isinstance(block, torch.Tensor): + ret_block = torch.isin(block, torch.Tensor(values)) + if _has_nan_value(values): + ret_block |= torch.isnan(block) + ret_blocks.append(ret_block) + elif isinstance(block, np.ndarray): + ret_block = np.isin(block, values) + if _has_nan_value(values): + 
ret_block |= np.isnan(block) + ret_blocks.append(torch.tensor(ret_block, dtype=torch.bool)) + + block_table = block_table.mapValues(_is_in_list) + else: + def _is_in_dict(blocks): + ret_blocks = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + continue + elif isinstance(block, torch.Tensor): + ret_block = torch.zeros(block.shape, dtype=torch.bool) + for offset, in_values in values.get(bid, {}).items(): + ret_block[:, offset] = torch.isin(block[:, offset], torch.Tensor(in_values)) + if _has_nan_value(in_values): + ret_block[:, offset] |= torch.isnan(block[:, offset]) + ret_blocks.append(ret_block) + elif isinstance(block, np.ndarray): + ret_block = np.zeros(block.shape, dtype=np.bool_) + for offset, in_values in values.get(bid, {}).items(): + ret_block[:, offset] = np.isin(block[:, offset], in_values) + if _has_nan_value(in_values): + ret_block[:, offset] = np.isnan(block[:, offset]) + ret_blocks.append(torch.tensor(ret_block, dtype=torch.bool)) + + block_table = block_table.mapValues(_is_in_dict) + + return block_table \ No newline at end of file diff --git a/python/fate/arch/dataframe/ops/_missing.py b/python/fate/arch/dataframe/ops/_missing.py new file mode 100644 index 0000000000..c9a23c6cf3 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_missing.py @@ -0,0 +1,58 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from ._compress_block import compress_blocks +from .._dataframe import DataFrame +from ..manager import BlockType + + +def isna(df: "DataFrame"): + data_manager = df.data_manager + block_indexes = data_manager.infer_operable_blocks() + + block_table = _isna(df.block_table, block_indexes) + dst_data_manager = data_manager.duplicate() + to_promote_types = [] + for bid in block_indexes: + to_promote_types.append((bid, BlockType.get_block_type(torch.bool))) + + dst_data_manager.promote_types(to_promote_types) + dst_block_table, dst_data_manager = compress_blocks(block_table, dst_data_manager) + + return DataFrame( + df._ctx, + dst_block_table, + df.partition_order_mappings, + dst_data_manager + ) + + +def _isna(block_table, block_indexes): + block_index_set = set(block_indexes) + + def _isna_judgement(blocks): + ret_blocks = [] + for bid, block in enumerate(blocks): + if bid not in block_index_set: + ret_blocks.append(block) + else: + ret_blocks.append(torch.isnan(block) if isinstance(block, torch.Tensor) else np.isnan(block)) + + return ret_blocks + + return block_table.mapValues( + _isna_judgement + ) diff --git a/python/fate/arch/dataframe/ops/_predict_result_transformaton.py b/python/fate/arch/dataframe/ops/_predict_result_transformaton.py deleted file mode 100644 index a70f40ca6c..0000000000 --- a/python/fate/arch/dataframe/ops/_predict_result_transformaton.py +++ /dev/null @@ -1,72 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools - -import pandas as pd -import torch - - -def transform_to_predict_result( - ctx, predict_score, data_type="train", task_type="binary", classes=None, threshold=0.5 -): - """ """ - transform_header = _predict_header_transform(task_type) - if task_type == "regression": - ... - elif task_type == "binary": - if predict_score.is_distributed: - predict_score = predict_score.storage.blocks.mapValues(lambda t: t.to_local().data) - else: - predict_score_local = predict_score.storage.data - predict_score = ctx.computing.parallelize([predict_score_local], include_key=False, partition=1) - - to_predict_result_func = functools.partial( - _predict_score_to_binary_result, - header=transform_header, - threshold=threshold, - classes=classes, - data_type=data_type, - ) - predict_result = predict_score.mapValues(to_predict_result_func) - - return predict_result, transform_header - - elif task_type == "multi": - ... - - -def _predict_header_transform(task_type): - if task_type in ["regression", "binary", "multi"]: - return ["predict_result", "predict_score", "predict_detail", "type"] - elif task_type == "cluster": - ... - else: - ... - - -def _predict_score_to_binary_result(score_block, header, threshold=0.5, classes=None, data_type="train"): - df = pd.DataFrame(score_block.tolist()) - if classes is None: - classes = [0, 1] - - def _convert(score_series): - score = score_series[0] - result = 1 if score > threshold else 0 - return classes[result], score, {classes[result]: score, classes[1 - result]: 1 - score}, data_type - - df = df.apply(_convert, axis=1, result_type="expand") - df.columns = header - - return df diff --git a/python/fate/arch/dataframe/ops/_promote_types.py b/python/fate/arch/dataframe/ops/_promote_types.py new file mode 100644 index 0000000000..e2e8bc433b --- /dev/null +++ b/python/fate/arch/dataframe/ops/_promote_types.py @@ -0,0 +1,66 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
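+#
+# Usage sketch (illustrative only; the block id and BlockType below are placeholders,
+# with BlockType assumed importable from ..manager). Note that the data manager is
+# modified in place, so callers usually pass a duplicate():
+#
+#     # widen block 2 to float64, e.g. before assigning float values into a column
+#     # that currently lives in an int32 block
+#     block_table, data_manager = promote_types(
+#         df.block_table, df.data_manager.duplicate(), to_promote_blocks=[(2, BlockType.float64)]
+#     )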
+# +import functools +import torch +from ..manager import DataManager +from ..manager.block_manager import Block +from typing import List, Tuple + + +def promote_types(block_table, data_manager: DataManager, to_promote_blocks): + data_manager.promote_types(to_promote_blocks) + to_promote_block_dict = dict((bid, block_type) for bid, block_type in to_promote_blocks) + block_table = block_table.mapValues( + lambda blocks: [ + blocks[bid] if bid not in to_promote_block_dict + else Block.get_block_by_type(to_promote_block_dict[bid]).convert_block(blocks[bid].tolist()) + for bid in range(len(blocks)) + ] + ) + + return block_table, data_manager + + +def promote_partial_block_types(block_table, narrow_blocks, dst_blocks, dst_fields_loc, + data_manager: DataManager, inplace=True): + def _mapper(blocks, narrow_loc: list = None, dst_bids: list = None, + dst_loc: List[Tuple[str, str]] = None, dm: DataManager = None, inp: bool = True): + ret_blocks = [] + for block in blocks: + if inp: + if isinstance(block, torch.Tensor): + ret_blocks.append(block.clone()) + else: + ret_blocks.append(block.copy()) + else: + ret_blocks.append(block) + + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (src_bid, src_offset) in zip(dst_bids, dst_loc): + block_values = blocks[src_bid][:, [src_offset]] + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(block_values) + + return ret_blocks + + _mapper_func = functools.partial(_mapper, narrow_loc=narrow_blocks, dst_bids=dst_blocks, + dst_loc=dst_fields_loc, dm=data_manager, inp=inplace) + + return block_table.mapValues(_mapper_func) diff --git a/python/fate/arch/dataframe/ops/_quantile.py b/python/fate/arch/dataframe/ops/_quantile.py new file mode 100644 index 0000000000..6226284e3d --- /dev/null +++ b/python/fate/arch/dataframe/ops/_quantile.py @@ -0,0 +1,74 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
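+#
+# Usage sketch (illustrative only; `df` is assumed to be an existing DataFrame):
+#
+#     summary = quantile(df, q=[0.25, 0.5, 0.75], relative_error=1e-4)
+#     # -> pandas DataFrame indexed by the requested quantiles, one column per operable field
+#
+#     splits = qcut(df, q=4)   # q equal-width split points between each column's min and max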
+# +import functools +import pandas as pd +from .._dataframe import DataFrame +from fate.arch.tensor.inside import GKSummary + + +def quantile(df: DataFrame, q, relative_error: float): + if isinstance(q, float): + q = [q] + elif not isinstance(q, list): + q = list(q) + + data_manager = df.data_manager + column_names = data_manager.infer_operable_field_names() + blocks_loc = [data_manager.loc_block(name) for name in column_names] + + def _mapper(blocks, columns_loc=None, error=None): + column_size = len(columns_loc) + gk_summary_obj_list = [GKSummary(error) for _ in range(column_size)] + + for idx, (bid, offset) in enumerate(columns_loc): + gk_summary_obj_list[idx] += blocks[bid][:, offset] + + return gk_summary_obj_list + + def _reducer(l_gk_summary_obj_list, r_gk_summary_obj_list): + rets = [] + for l_gk_summary_obj, r_gk_summary_obj in zip(l_gk_summary_obj_list, r_gk_summary_obj_list): + rets.append(l_gk_summary_obj + r_gk_summary_obj) + + return rets + + gk_summary_func = functools.partial(_mapper, columns_loc=blocks_loc, error=relative_error) + ret_gk_summary_obj_list = df.block_table.mapValues(gk_summary_func).reduce(_reducer) + + quantile_rets = dict() + for column_name, gk_summary_obj in zip(column_names, ret_gk_summary_obj_list): + query_ret = gk_summary_obj.queries(q) + quantile_rets[column_name] = query_ret + + quantile_df = pd.DataFrame(quantile_rets, index=q) + + return quantile_df + + +def qcut(df: DataFrame, q: int): + assert isinstance(q, int), f"to use qcut, {q} should be positive integer" + max_ret = df.max() + min_ret = df.min() + + dist = (max_ret - min_ret) / q + + cut_ret = [] + for i in range(1, q): + cut_ret.append(min_ret + i * dist) + + cut_ret.append(max_ret) + + return pd.DataFrame(cut_ret, index=range(1, q + 1, 1)) diff --git a/python/fate/arch/dataframe/ops/_replace.py b/python/fate/arch/dataframe/ops/_replace.py new file mode 100644 index 0000000000..f279db71fb --- /dev/null +++ b/python/fate/arch/dataframe/ops/_replace.py @@ -0,0 +1,99 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
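+#
+# Usage sketch (illustrative only; column names and values are placeholders):
+#
+#     # per-column {old_value: new_value} mappings; values not listed pass through
+#     # unchanged, and the destination block type is widened if a replacement value needs it
+#     replaced = replace(df, {"x0": {-1: 0}, "x1": {999: 0.0}})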
+# +import functools +import numpy as np +import torch +from ._compress_block import compress_blocks +from .._dataframe import DataFrame +from ..manager import BlockType, DataManager + + +def replace(df: "DataFrame", to_replace: dict): + data_manager = df.data_manager.duplicate() + field_names = list(filter(lambda field_name: field_name in to_replace, data_manager.infer_operable_field_names())) + blocks_loc = data_manager.loc_block(field_names) + + dst_block_types = [] + _to_replace_list = [] + for name, (_bid, _) in zip(field_names, blocks_loc): + block_type = data_manager.get_block(_bid).block_type + for k, v in to_replace[name].items(): + v_type = BlockType.get_block_type(v) + + if block_type < v_type: + block_type = v_type + + dst_block_types.append(block_type) + _to_replace_list.append((_bid, _, to_replace[name])) + + narrow_blocks, dst_blocks = data_manager.split_columns(field_names, dst_block_types) + + def _mapper(blocks, to_replace_list: list = None, narrow_loc: list = None, + dst_bids: list = None, dm: DataManager = None): + ret_blocks = [] + for block in blocks: + if isinstance(block, torch.Tensor): + ret_blocks.append(block.clone()) + elif isinstance(block, np.ndarray): + ret_blocks.append(block.copy()) + else: + ret_blocks.append(block) + + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (src_bid, src_offset, _to_replace_dict) in zip(dst_bids, to_replace_list): + row_values = blocks[src_bid][:, src_offset] + replace_ret = [] + is_torch = torch.is_tensor(row_values) + for idx, value in enumerate(row_values): + if is_torch: + value = value.item() + if value not in _to_replace_dict: + replace_ret.append([value]) + else: + replace_ret.append([_to_replace_dict[value]]) + + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(replace_ret) + + return ret_blocks + + replace_mapper = functools.partial(_mapper, + to_replace_list=_to_replace_list, + narrow_loc=narrow_blocks, + dst_bids=dst_blocks, + dm=data_manager) + + block_table = df.block_table.mapValues(replace_mapper) + + block_indexes = data_manager.infer_operable_blocks() + if len(block_indexes) > 1: + to_promote_types = [] + for _bid in block_indexes: + to_promote_types.append((_bid, data_manager.get_block(_bid).block_type)) + + data_manager.promote_types(to_promote_types) + block_table, data_manager = compress_blocks(block_table, data_manager) + + return DataFrame( + df._ctx, + block_table, + partition_order_mappings=df.partition_order_mappings, + data_manager=data_manager + ) diff --git a/python/fate/arch/dataframe/ops/_set_item.py b/python/fate/arch/dataframe/ops/_set_item.py new file mode 100644 index 0000000000..f7091c4010 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_set_item.py @@ -0,0 +1,277 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
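+#
+# Usage sketch (illustrative only; these helpers are driven by DataFrame internals
+# such as hstack, rather than being called directly by end users):
+#
+#     set_item(df, ["x_new"], 1.0, 1)                  # state=1: all keys are new columns
+#     set_item(df, ["x0"], other_df, 2)                # state=2: keys already exist, overwrite
+#     set_label_or_weight(df, label_df, key_type="label")
+#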
+import functools +import numpy as np +from .._dataframe import DataFrame +from ..manager.block_manager import BlockType +from ..manager.data_manager import DataManager +from fate.arch.tensor import DTensor +from fate.arch.tensor.phe._tensor import PHETensor + + +def set_item(df: "DataFrame", keys, items, state): + """ + state: 1 - keys are all new + 2 - keys are all old + """ + if state == 1: + _set_new_item(df, keys, items) + else: + _set_old_item(df, keys, items) + + +def set_label_or_weight(df: "DataFrame", item: "DataFrame", key_type="label"): + if not isinstance(item, DataFrame): + raise ValueError(f"To set label or weight, make sure rhs type={type(df)}") + + data_manager = df.data_manager + other_data_manager = item.data_manager + other_field_names = other_data_manager.infer_operable_field_names() + other_block_id = other_data_manager.loc_block(other_field_names[0], with_offset=False) + + if len(other_field_names) > 1: + raise ValueError(f"Too many columns of rhs, only one is supported") + + other_block_type = other_data_manager.blocks[other_block_id].block_type + if (name := getattr(df.schema, f"{key_type}_name")) is not None: + block_id = data_manager.loc_block(name, with_offset=False) + block_table = df.block_table.join(item.block_table, + lambda blocks1, blocks2: + [block if bid != block_id else blocks2[other_block_id] + for bid, block in enumerate(blocks1)] + ) + if data_manager.blocks[block_id].block_type < other_block_type: + data_manager.blocks[block_id].convert_block_type(other_block_type) + else: + data_manager.add_label_or_weight(key_type=key_type, + name=other_field_names[0], + block_type=other_block_type) + + block_table = df.block_table.join(item.block_table, + lambda blocks1, blocks2: blocks1 + [blocks2[other_block_id]]) + + df.block_table = block_table + df.data_manager = data_manager + + +def _set_new_item(df: "DataFrame", keys, items): + def _append_single(blocks, item, col_len, bid=None, dm: DataManager=None): + lines = len(blocks[0]) + ret_blocks = [block for block in blocks] + ret_blocks.append(dm.blocks[bid].convert_block([[item for _ in range(col_len)] for idx in range(lines)])) + + return ret_blocks + + def _append_multi(blocks, item_list, bid_list=None, dm: DataManager=None): + lines = len(blocks[0]) + ret_blocks = [block for block in blocks] + for bid, item in zip(bid_list, item_list): + ret_blocks.append(dm.blocks[bid].convert_block([[item] for _ in range(lines)])) + + return ret_blocks + + def _append_df(l_blocks, r_blocks, r_blocks_loc=None, dm=None): + ret_blocks = [block for block in l_blocks] + l_bid = len(ret_blocks) + for bid, offset in r_blocks_loc: + if dm.blocks[l_bid].is_phe_tensor(): + ret_blocks.append(r_blocks[bid]) + elif r_blocks[bid].shape[1] == 1: + ret_blocks.append(r_blocks[bid]) + else: + ret_blocks.append(r_blocks[bid][:, [offset]]) + l_bid += 1 + + return ret_blocks + + def _append_tensor(l_blocks, r_tensor, bid_list=None, dm: DataManager = None): + ret_blocks = [block for block in l_blocks] + for offset, bid in enumerate(bid_list): + ret_blocks.append(dm.blocks[bid].convert_block(r_tensor[:, offset: offset+1])) + + return ret_blocks + + def _append_phe_tensor(l_blocks, r_tensor): + ret_blocks = [block for block in l_blocks] + ret_blocks.append(r_tensor._data) + + return ret_blocks + + data_manager = df.data_manager.duplicate() + if isinstance(items, (bool, int, float, str, np.int32, np.float32, np.int64, np.float64, np.bool_)): + bids = data_manager.append_columns(keys, BlockType.get_block_type(items)) + _append_func = 
functools.partial(_append_single, item=items, col_len=len(keys), bid=bids[0], dm=data_manager) + block_table = df.block_table.mapValues(_append_func) + + elif isinstance(items, list): + if len(keys) != len(items): + if len(keys) > 1: + raise ValueError("Must have equal len keys and value when setting with an iterable") + bids = data_manager.append_columns(keys, BlockType.get_block_type("object")) + _append_func = functools.partial(_append_single, item=items, col_len=len(keys), + bid=bids[0], dm=data_manager) + else: + bids = data_manager.append_columns(keys, [BlockType.get_block_type(items[i]) for i in range(len(keys))]) + _append_func = functools.partial(_append_multi, item_list=items, bid_list=bids, dm=data_manager) + block_table = df.block_table.mapValues(_append_func) + elif isinstance(items, DataFrame): + other_dm = items.data_manager + operable_fields = other_dm.infer_operable_field_names() + operable_blocks_loc = other_dm.loc_block(operable_fields) + block_types = [other_dm.blocks[bid].block_type for bid, _ in operable_blocks_loc] + if len(keys) != len(operable_fields): + raise ValueError("Setitem with rhs=DataFrame must have equal len keys") + data_manager.append_columns(keys, block_types) + + l = len(keys) + for idx, (other_block_id, _) in enumerate(operable_blocks_loc): + if data_manager.blocks[-l + idx].is_phe_tensor(): + other_block = other_dm.blocks[other_block_id] + data_manager.blocks[-l + idx].set_extra_kwargs(pk=other_block._pk, + evaluator=other_block._evaluator, + coder=other_block._coder, + dtype=other_block._dtype, + device=other_block._device) + + _append_func = functools.partial(_append_df, r_blocks_loc=operable_blocks_loc, dm=data_manager) + block_table = df.block_table.join(items.block_table, _append_func) + elif isinstance(items, DTensor): + meta_data = items.shardings._data.mapValues( + lambda v: (v.pk, v.evaluator, v.coder, v.dtype) if isinstance(v, PHETensor) else None + ).first()[1] + + if isinstance(meta_data, tuple): + block_type = BlockType.phe_tensor + if len(keys) != 1: + raise ValueError("to set item of PHETensor, lhs should has only one columns.") + data_manager.append_columns(keys, block_type) + data_manager.blocks[-1].set_extra_kwargs(pk=meta_data[0], evaluator=meta_data[1], coder=meta_data[2], + dtype=meta_data[3], device=items.device) + _append_func = functools.partial(_append_phe_tensor) + block_table = df.block_table.join(items.shardings._data, _append_func) + else: + block_type = BlockType.get_block_type(items.dtype) + if len(keys) != items.shape[1]: + raise ValueError("Setitem with rhs=DTensor must have equal len keys") + bids = data_manager.append_columns(keys, block_type) + _append_func = functools.partial(_append_tensor, bid_list=bids, dm=data_manager) + block_table = df.block_table.join(items.shardings._data, _append_func) + else: + raise ValueError(f"Seiitem with rhs_type={type(items)} is not supported") + + df.block_table = block_table + df.data_manager = data_manager + + +def _set_old_item(df: "DataFrame", keys, items): + def _replace_single(blocks, item=None, narrow_loc=None, dst_bids=None, dm: DataManager=None): + ret_blocks = [block for block in blocks] + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + lines = len(blocks[0]) + for dst_bid in dst_bids: + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block([[item] for idx in range(lines)]) + + return ret_blocks + + def _replace_multi(blocks, item_list=None, narrow_loc=None, 
dst_bids=None, dm: DataManager = None): + ret_blocks = [block for block in blocks] + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + lines = len(blocks[0]) + for dst_bid, item in zip(dst_bids, item_list): + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block([[item] for idx in range(lines)]) + + return ret_blocks + + def _replace_df(l_blocks, r_blocks, narrow_loc=None, dst_bids=None, r_blocks_loc=None, dm: DataManager=None): + ret_blocks = [block for block in l_blocks] + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for dst_bid, (r_bid, offset) in zip(dst_bids, r_blocks_loc): + ret_blocks[dst_bid] = r_blocks[r_bid][:, [offset]] + + return ret_blocks + + def _replace_tensor(blocks, r_tensor, narrow_loc=None, dst_bids=None, dm: DataManager = None): + ret_blocks = [block for block in blocks] + for i in range(len(ret_blocks), dm.block_num): + ret_blocks.append([]) + + for bid, offsets in narrow_loc: + ret_blocks[bid] = ret_blocks[bid][:, offsets] + + for offset, dst_bid in enumerate(dst_bids): + ret_blocks[dst_bid] = dm.blocks[dst_bid].convert_block(r_tensor[:, offset : offset + 1]) + + return ret_blocks + + data_manager = df.data_manager.duplicate() + if isinstance(items, (bool, int, float, str, np.int32, np.float32, np.int64, np.float64, np.bool_)): + narrow_blocks, dst_blocks = data_manager.split_columns(keys, BlockType.get_block_type(items)) + replace_func = functools.partial(_replace_single, item=items, narrow_loc=narrow_blocks, + dst_bids=dst_blocks, dm=data_manager) + block_table = df.block_table.mapValues(replace_func) + elif isinstance(items, list): + if len(keys) != len(items): + if len(keys) > 1: + raise ValueError("Must have equal len keys and value when setting with an iterable") + narrow_blocks, dst_blocks = data_manager.split_columns(keys, BlockType.get_block_type("object")) + replace_func = functools.partial(_replace_single, item=items[0], narrow_loc=narrow_blocks, + dst_bids=dst_blocks, dm=data_manager) + else: + narrow_blocks, dst_blocks = data_manager.split_columns(keys, + [BlockType.get_block_type(item) for item in items]) + replace_func = functools.partial(_replace_multi, item_list=items, narrow_loc=narrow_blocks, + dst_bids=dst_blocks, dm=data_manager) + + block_table = df.block_table.mapValues(replace_func) + elif isinstance(items, DataFrame): + other_dm = items.data_manager + operable_fields = other_dm.infer_operable_field_names() + operable_blocks_loc = other_dm.loc_block(operable_fields) + block_types = [other_dm.blocks[bid].block_type for bid, _ in operable_blocks_loc] + if len(keys) != len(operable_fields): + raise ValueError("Setitem with rhs=DataFrame must have equal len keys") + narrow_blocks, dst_blocks = data_manager.split_columns(keys, block_types) + replace_func = functools.partial(_replace_df, narrow_loc=narrow_blocks, dst_bids=dst_blocks, + r_blocks_loc=operable_blocks_loc, dm=data_manager) + block_table = df.block_table.join(items.block_table, replace_func) + elif isinstance(items, DTensor): + if len(keys) != items.shape[1]: + raise ValueError("Setitem with rhs=DTensor must have equal len keys") + block_type = BlockType.get_block_type(items.dtype) + narrow_blocks, dst_blocks = data_manager.split_columns(keys, block_type) + replace_func = functools.partial(_replace_tensor, narrow_loc=narrow_blocks, + dst_bids=dst_blocks, 
dm=data_manager) + block_table = df.block_table.join(items.shardings._data, replace_func) + + else: + raise ValueError(f"Seiitem with rhs_type={type(items)} is not supported") + + df.block_table = block_table + df.data_manager = data_manager diff --git a/python/fate/arch/dataframe/ops/_stat.py b/python/fate/arch/dataframe/ops/_stat.py index 0615519766..474f8c196a 100644 --- a/python/fate/arch/dataframe/ops/_stat.py +++ b/python/fate/arch/dataframe/ops/_stat.py @@ -12,19 +12,251 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools + +import numpy as np import pandas as pd +import torch +from .._dataframe import DataFrame +from ..manager import DataManager + + +FLOATING_POINT_ZERO = 1e-14 + + +def min(df: "DataFrame") -> "pd.Series": + data_manager = df.data_manager + operable_blocks = data_manager.infer_operable_blocks() + + def _mapper(blocks, op_bids): + ret = [] + for bid in op_bids: + if isinstance(blocks[bid], torch.Tensor): + ret.append(blocks[bid].min(axis=0).values) + else: + ret.append(blocks[bid].min(axis=0)) + + return ret + + def _reducer(blocks1, blocks2): + ret = [] + for block1, block2 in zip(blocks1, blocks2): + if isinstance(block1, torch.Tensor): + ret.append(torch.minimum(block1, block2)) + else: + ret.append(np.minimum(block1, block2)) + + return ret + + mapper_func = functools.partial( + _mapper, + op_bids=operable_blocks + ) + + reduce_ret = df.block_table.mapValues(mapper_func).reduce(_reducer) + return _post_process(reduce_ret, operable_blocks, data_manager) + + +def max(df: "DataFrame") -> "pd.Series": + data_manager = df.data_manager + operable_blocks = data_manager.infer_operable_blocks() + + def _mapper(blocks, op_bids): + ret = [] + for bid in op_bids: + if isinstance(blocks[bid], torch.Tensor): + ret.append(blocks[bid].max(axis=0).values) + else: + ret.append(blocks[bid].max(axis=0)) + + return ret + + def _reducer(blocks1, blocks2): + ret = [] + for block1, block2 in zip(blocks1, blocks2): + if isinstance(block1, torch.Tensor): + ret.append(torch.maximum(block1, block2)) + else: + ret.append(np.maximum(block1, block2)) + + return ret + + mapper_func = functools.partial( + _mapper, + op_bids=operable_blocks + ) + + reduce_ret = df.block_table.mapValues(mapper_func).reduce(_reducer) + return _post_process(reduce_ret, operable_blocks, data_manager) + + +def sum(df: DataFrame) -> "pd.Series": + data_manager = df.data_manager + operable_blocks = data_manager.infer_operable_blocks() + + def _mapper(blocks, op_bids): + ret = [] + for bid in op_bids: + ret.append(blocks[bid].sum(axis=0)) + + return ret + + def _reducer(blocks1, blocks2): + return [block1 + block2 for block1, block2 in zip(blocks1, blocks2)] + + mapper_func = functools.partial( + _mapper, + op_bids=operable_blocks + ) + + reduce_ret = df.block_table.mapValues(mapper_func).reduce(_reducer) + return _post_process(reduce_ret, operable_blocks, data_manager) + + +def mean(df: "DataFrame") -> "pd.Series": + return sum(df) / df.shape[0] + + +def var(df: "DataFrame", ddof=1) -> "pd.Series": + data_manager = df.data_manager + operable_blocks = data_manager.infer_operable_blocks() + n = df.shape[0] + def _mapper(blocks, op_bids): + ret = [] + for bid in op_bids: + block = blocks[bid] + if isinstance(block, torch.Tensor): + ret.append( + ( + torch.sum(torch.square(block), dim=0, keepdim=True), + torch.sum(block, dim=0, keepdim=True), + ) + ) + else: + ret.append( + ( + 
np.sum(np.square(block), axis=0), + np.sum(block, axis=0) + ) + ) -def stat_method(df, stat_func, *args, index=None, **kwargs) -> pd.Series: - if "axis" not in kwargs: - kwargs["axis"] = 0 - stat_ret = getattr(df, stat_func)(*args, **kwargs) - dtype = str(stat_ret.dtype.to_torch_dtype()).split(".", -1)[-1] - stat_ret = stat_ret.tolist() - if not kwargs.get("axis", 0): - if index: - return pd.Series(stat_ret, index=index, dtype=dtype) + return ret + + def _reducer(blocks1, block2): + ret = [] + for block1, block2 in zip(blocks1, block2): + if isinstance(block1, torch.Tensor): + ret.append((torch.add(block1[0], block2[0]), torch.add(block1[1], block2[1]))) + else: + ret.append((np.add(block1[0], block2[0]), np.add(block1[1], block2[1]))) + + return ret + + mapper_func = functools.partial( + _mapper, + op_bids=operable_blocks + ) + reduce_ret = df.block_table.mapValues(mapper_func).reduce(_reducer) + + ret_blocks = [] + for (lhs, rhs) in reduce_ret: + if isinstance(lhs, torch.Tensor): + rhs = torch.mul(torch.square(torch.div(rhs, n)), n) + ret_blocks.append(torch.div(torch.sub(lhs, rhs), n - ddof)) else: - return pd.Series(stat_ret, dtype=dtype) + rhs = np.mul(np.square(np.div(rhs, n)), n) + ret_blocks.append(np.div(np.sub(lhs, rhs), n - ddof)) + + return _post_process(ret_blocks, operable_blocks, data_manager) + + +def std(df: "DataFrame", ddof=1) -> "pd.Series": + return var(df, ddof=ddof) ** 0.5 + + +def skew(df: "DataFrame", unbiased=False): + data_manager = df.data_manager + n = df.shape[0] + + if unbiased and n < 3: + field_names = data_manager.infer_operable_field_names() + return pd.Series([np.nan for _ in range(len(field_names))], index=field_names) + + _mean = mean(df) + m1 = df - _mean + m2 = (m1 ** 2).mean() + m3 = (m1 ** 3).mean() + + """ + if abs(value) in m2 < eps=1e-14, we regard it as 0, but eps=1e-14 should be global instead of this file. 
+ """ + non_zero_mask = abs(m2) >= FLOATING_POINT_ZERO + m3[~non_zero_mask] = 0 + m2[~non_zero_mask] = 1 + + if unbiased: + return (n * (n - 1)) ** 0.5 / (n - 2) * (m3 / m2 ** 1.5) else: - return pd.Series(stat_ret, dtype=dtype) + return m3 / m2 ** 1.5 + + +def kurt(df: "DataFrame", unbiased=False): + data_manager = df.data_manager + n = df.shape[0] + if unbiased and n < 4: + field_names = data_manager.infer_operable_field_names() + return pd.Series([np.nan for _ in range(len(field_names))], index=field_names) + + _mean = mean(df) + m1 = df - _mean + m2 = m1 ** 2 + m4 = m2 ** 2 + m2 = m2.mean() + m4 = m4.mean() + + non_zero_mask = abs(m2) >= FLOATING_POINT_ZERO + m4[~non_zero_mask] = 0 + m2[~non_zero_mask] = 1 + + if unbiased: + return (n - 1) / ((n - 2) * (n - 3)) * ((n + 1) * m4 / m2**2 - 3 * (n - 1)) + else: + return m4 / m2 ** 4 - 3 + + +def variation(df: "DataFrame", ddof=1): + return std(df, ddof=ddof) / mean(df) + + +def describe(df: "DataFrame", ddof=1, unbiased=False): + stat_metrics = dict() + stat_metrics["sum"] = sum(df) + stat_metrics["min"] = min(df) + stat_metrics["max"] = max(df) + stat_metrics["mean"] = mean(df) + stat_metrics["std"] = std(df, ddof=ddof) + stat_metrics["var"] = var(df, ddof=ddof) + stat_metrics["variation"] = variation(df, ddof=ddof) + stat_metrics["skew"] = skew(df, unbiased=unbiased) + stat_metrics["kurt"] = kurt(df, unbiased=unbiased) + stat_metrics["na_count"] = df.isna().sum() + + return pd.DataFrame(stat_metrics) + + +def _post_process(reduce_ret, operable_blocks, data_manager: "DataManager") -> "pd.Series": + field_names = data_manager.infer_operable_field_names() + field_indexes = [data_manager.get_field_offset(name) for name in field_names] + field_indexes_loc = dict(zip(field_indexes, range(len(field_indexes)))) + ret = [[] for i in range(len(field_indexes))] + + reduce_ret = [r.reshape(-1).tolist() for r in reduce_ret] + for idx, bid in enumerate(operable_blocks): + field_indexes = data_manager.blocks[bid].field_indexes + for offset, field_index in enumerate(field_indexes): + loc = field_indexes_loc[field_index] + ret[loc] = reduce_ret[idx][offset] + + return pd.Series(ret, index=field_names) + diff --git a/python/fate/arch/dataframe/ops/_transformer.py b/python/fate/arch/dataframe/ops/_transformer.py new file mode 100644 index 0000000000..4414454e72 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_transformer.py @@ -0,0 +1,232 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
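The statistics above are computed block-wise: a map step produces per-column aggregates for each partition and an element-wise reduce combines them, so only small per-column vectors travel between partitions. `var()` in particular accumulates sum(x**2) and sum(x) and combines them as (sum_sq - n * mean**2) / (n - ddof). A small self-contained check of that formula with plain numpy (illustrative only; it does not call the DataFrame API):

import numpy as np

x = np.array([[1.0, 2.0],
              [3.0, 5.0],
              [4.0, 8.0]])
n, ddof = x.shape[0], 1

sum_sq = (x ** 2).sum(axis=0)          # per-column sum of squares (the map/reduce result)
s = x.sum(axis=0)                      # per-column sum (the map/reduce result)
var = (sum_sq - n * (s / n) ** 2) / (n - ddof)

assert np.allclose(var, x.var(axis=0, ddof=ddof))

`skew()` and `kurt()` then reuse `mean()` and element-wise DataFrame arithmetic to form the central moments m2, m3 and m4, guarding against near-zero m2 with the FLOATING_POINT_ZERO threshold.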
+import functools + +import pandas as pd +from typing import List, Tuple +import torch +from fate.arch import tensor +import numpy as np +from ..manager.data_manager import DataManager + + +def transform_to_tensor(block_table, + data_manager: "DataManager", + dtype=None, + partition_order_mappings=None): + def _merge_blocks(src_blocks, bids=None, fields=None): + if len(bids) == 1: + bid = bids[0] + t = src_blocks[bid] + else: + i = 0 + tensors = [] + while i < len(fields): + bid = fields[i][0] + indexes = [fields[i][1]] + j = i + 1 + while j < len(fields) and bid == fields[j][0] and indexes[-1] + 1 == fields[j][1]: + indexes.append(fields[j][1]) + j += 1 + + tensors.append(src_blocks[bid][:, indexes]) + i = j + + t = torch.hstack(tensors) + + if dtype: + t = t.type(getattr(torch, t)) + + return t + + def _convert_to_phe_tensor(blocks, bid: int = None, dm: DataManager = None): + phe_block = dm.get_block(bid) + return phe_block.convert_to_phe_tensor(blocks[bid], shape=(len(blocks[0]), 1)) + + block_indexes = data_manager.infer_operable_blocks() + is_phe_tensor = False + for block_id in block_indexes: + if data_manager.get_block(block_id).is_phe_tensor(): + is_phe_tensor = True + break + if not data_manager.get_block(block_id).is_numeric: + raise ValueError("Transform to distributed tensor should ensure every field is numeric") + + if is_phe_tensor: + if len(block_indexes) > 1: + raise ValueError("To use as_tensor of phe_tensor type, it should be only single column") + block_id = block_indexes[0] + _convert_to_phe_tensor_func = functools.partial(_convert_to_phe_tensor, + bid=block_id, + dm=data_manager) + phe_table = block_table.mapValues(_convert_to_phe_tensor_func) + + shape_table = block_table.mapValues(lambda blocks: (len(blocks[0]), 1)) + shapes = [shape_obj for k, shape_obj in sorted(shape_table.collect())] + + return tensor.DTensor.from_sharding_table(phe_table, + shapes=shapes, + dtype=data_manager.get_block(block_id).dtype, + device=data_manager.get_block(block_id).device) + else: + field_names = data_manager.infer_operable_field_names() + fields_loc = data_manager.loc_block(field_names) + + _merged_func = functools.partial( + _merge_blocks, + bids=block_indexes, + fields=fields_loc + ) + merged_table = block_table.mapValues(_merged_func) + + shape_table = merged_table.mapValues(lambda v: v.shape) + shapes = [shape_obj for k, shape_obj in sorted(shape_table.collect())] + + return tensor.DTensor.from_sharding_table(merged_table, + shapes=shapes) + + +def transform_block_table_to_list(block_table, data_manager): + fields_loc = data_manager.get_fields_loc() + transform_block_to_list_func = functools.partial( + transform_block_to_list, + fields_loc=fields_loc + ) + + return block_table.mapValues(transform_block_to_list_func) + + +def transform_block_to_list(blocks, fields_loc): + + if blocks[0].shape[0] == 0: + return [] + + i = 0 + dst_list = None + lines = 0 + while i < len(fields_loc): + bid = fields_loc[i][0] + if isinstance(blocks[bid], pd.Index): + if not dst_list: + lines = len(blocks[bid]) + dst_list = [[] for i in range(lines)] + + for j in range(lines): + dst_list[j].append(blocks[bid][j]) + + i += 1 + else: + """ + pd.values or tensor + """ + indexes = [fields_loc[i][1]] + j = i + 1 + while j < len(fields_loc) and fields_loc[j] == fields_loc[j - 1]: + indexes.append(fields_loc[j][1]) + j += 1 + + if isinstance(blocks[bid], np.ndarray): + for line_id, row_value in enumerate(blocks[bid][:, indexes]): + dst_list[line_id].extend(row_value.tolist()) + else: + try: + for line_id, 
row_value in enumerate(blocks[bid][:, indexes].tolist()): + dst_list[line_id].extend(row_value) + except Exception as e: + assert 1 == 2, (e, type(blocks[bid]), indexes) + + i = j + + return dst_list + + +def transform_list_to_block_table(table, data_manager): + from ..manager.block_manager import BlockType + + def _to_block(values): + convert_blocks = [] + + lines = len(values) + for block_schema in data_manager.blocks: + if block_schema.block_type == BlockType.index and len(block_schema.field_indexes) == 1: + col_idx = block_schema.field_indexes[0] + block_content = [values[i][col_idx] for i in range(lines)] + else: + block_content = [] + for i in range(lines): + buf = [] + for col_idx in block_schema.field_indexes: + buf.append(values[i][col_idx]) + block_content.append(buf) + + convert_blocks.append(block_schema.convert_block(block_content)) + + return convert_blocks + + return table.mapValues(_to_block) + + +def transform_list_block_to_frame_block(block_table, data_manager): + def _to_frame_block(blocks): + convert_blocks = [] + for idx, block_schema in enumerate(data_manager.blocks): + block_content = [block[idx] for block in blocks] + convert_blocks.append(block_schema.convert_block(block_content)) + + return convert_blocks + + return block_table.mapValues(_to_frame_block) + + +def transform_to_pandas_dataframe(block_table, data_manager): + fields_loc = data_manager.get_fields_loc() + + def _flatten(blocks): + flatten_ret = [] + lines = len(blocks[0]) + + for lid in range(lines): + row = [[] for i in range(len(fields_loc))] + for field_id, (bid, offset) in enumerate(fields_loc): + if isinstance(blocks[bid], np.ndarray): + row[field_id] = blocks[bid][lid][offset] + elif isinstance(blocks[bid], torch.Tensor): + row[field_id] = blocks[bid][lid][offset].item() + else: + row[field_id] = blocks[bid][lid] + + flatten_ret.append(row) + + return flatten_ret + + flatten_table = block_table.mapValues(_flatten) + + flatten_obj = [] + for k, v in flatten_table.collect(): + if not flatten_obj: + flatten_obj = v + else: + flatten_obj.extend(v) + + fields = [data_manager.get_field_name(idx) for idx in range(len(fields_loc))] + pd_df = pd.DataFrame(flatten_obj, columns=fields, dtype=object) + pd_df.set_index(data_manager.schema.sample_id_name) + + for name in fields[1:]: + dtype = data_manager.get_field_type_by_name(name) + if dtype in ["int32", "float32", "int64", "float64"]: + pd_df[name] = pd_df[name].astype(dtype) + + return pd_df diff --git a/python/fate/arch/dataframe/ops/_unary_operator.py b/python/fate/arch/dataframe/ops/_unary_operator.py new file mode 100644 index 0000000000..5083399998 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_unary_operator.py @@ -0,0 +1,35 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
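The new `_unary_operator.py` below (together with the `utils/operators.py` helpers further down) applies an operator only to the operable blocks of each partition and passes the remaining blocks (sample-id columns and the like) through untouched. A rough sketch of that pattern with plain numpy blocks (the block list and index set here are made up for illustration):

import operator
import numpy as np

blocks = [
    np.array(["id_0", "id_1"]),                  # sample-id block, left untouched
    np.array([[True, False], [False, True]]),    # bool feature block
]
operable = {1}

ret = [operator.invert(b) if i in operable else b for i, b in enumerate(blocks)]
# ret[1] is the element-wise negation, i.e. what ~df yields on bool columns

`invert()` additionally checks that every operable block holds bool data before applying `operator.invert`.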
+# +import operator +from .._dataframe import DataFrame +from ..manager import BlockType +from .utils.operators import unary_operate + + +def invert(df: DataFrame): + data_manager = df.data_manager + block_indexes = data_manager.infer_operable_blocks() + for bid in block_indexes: + if data_manager.blocks[bid] != BlockType.bool: + raise ValueError("to use ~df syntax, data types should be bool") + + block_table = unary_operate(df.block_table, operator.invert, block_indexes) + return type(df)( + df.ctx, + block_table, + df.partition_order_mappings, + data_manager.duplicate() + ) diff --git a/python/fate/arch/dataframe/ops/_where.py b/python/fate/arch/dataframe/ops/_where.py new file mode 100644 index 0000000000..e3e3308c51 --- /dev/null +++ b/python/fate/arch/dataframe/ops/_where.py @@ -0,0 +1,125 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import torch +from typing import List +from .._dataframe import DataFrame +from ..manager import BlockType +from ..manager import DataManager + + +def where(df: DataFrame, other: DataFrame): + if df.shape[0] != other.shape[0]: + raise ValueError("Row numbers should be identical.") + + data_manager = df.data_manager + other_data_manager = other.data_manager + + column_names = data_manager.infer_operable_field_names() + other_column_names = other_data_manager.infer_operable_field_names() + + if (set(column_names) & set(other_column_names)) != set(column_names): + raise ValueError("To use df[mask], mask's columns should contains all df's columns") + + if column_names != other_column_names: + other = other[column_names] + other_column_names = column_names + + false_column_set = _get_false_columns(other) + if not false_column_set: + return df + + """ + need type promote ? 
+ """ + need_promoted = False + for name in column_names: + bid = data_manager.loc_block(name, with_offset=False) + if not BlockType.is_float(data_manager.get_block(bid).block_type): + need_promoted = True + break + + if not need_promoted: + block_table = _where_float_type(df.block_table, other.block_table, + data_manager, other.data_manager, column_names) + return DataFrame( + df._ctx, + block_table, + df.partition_order_mappings, + data_manager.duplicate() + ) + + +def _get_false_columns(df: DataFrame): + block_table = df.block_table + data_manager = df.data_manager + block_index_set = set(data_manager.infer_operable_blocks()) + + false_table = block_table.mapValues( + lambda blocks: [ + block.all(axis=0) if bid in block_index_set else [] + for bid, block in enumerate(blocks) + ] + ) + + false_values = false_table.reduce( + lambda blocks1, blocks2: + [ + block1 & block2 if bid in block_index_set else [] + for bid, (block1, block2) in enumerate(zip(blocks1, blocks2)) + ] + ) + + false_columns = set() + column_names = data_manager.infer_operable_field_names() + for name in column_names: + _bid, _offset = data_manager.loc_block(name) + if isinstance(false_values[_bid], torch.Tensor): + if not false_values[_bid][_offset].item(): + false_columns.add(name) + elif isinstance(false_values[_bid], np.ndarray): + if not false_values[_bid][_offset]: + false_columns.add(name) + + return false_columns + + +def _where_float_type(l_block_table, r_block_table, + l_data_manager: "DataManager", + r_data_manager: "DataManager", + column_names: List[str]): + l_loc_info = [l_data_manager.loc_block(name) for name in column_names] + r_loc_info = [r_data_manager.loc_block(name) for name in column_names] + + def __convert_na(l_blocks, r_blocks): + ret_blocks = [] + for block in l_blocks: + if isinstance(block, torch.Tensor): + ret_blocks.append(block.clone()) + elif isinstance(block, np.ndarray): + ret_blocks.append(np.copy(block)) + else: + ret_blocks.append(block) + + for (l_bid, l_offset), (r_bid, r_offset) in zip(l_loc_info, r_loc_info): + if isinstance(ret_blocks[l_bid], torch.Tensor): + ret_blocks[l_bid][:, l_offset][~r_blocks[r_bid][:, r_offset]] = torch.nan + else: + ret_blocks[l_bid][:, l_offset][~r_blocks[r_bid][:, r_offset]] = np.nan + + return ret_blocks + + return l_block_table.join(r_block_table, __convert_na) diff --git a/python/fate/arch/dataframe/ops/utils/__init__.py b/python/fate/arch/dataframe/ops/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/arch/dataframe/ops/utils/operators.py b/python/fate/arch/dataframe/ops/utils/operators.py new file mode 100644 index 0000000000..0fffd99268 --- /dev/null +++ b/python/fate/arch/dataframe/ops/utils/operators.py @@ -0,0 +1,48 @@ +import numpy as np +from fate.arch.computing import is_table + + +def binary_operate(lhs, rhs, op, block_indexes, rhs_block_id=None): + block_index_set = set(block_indexes) + if isinstance(rhs, list): + op_ret = lhs.mapValues( + lambda blocks: + [ + op(blocks[bid], rhs[bid]) if bid in block_index_set else blocks[bid] + for bid in range(len(blocks)) + ] + ) + elif isinstance(rhs, (bool, int, float, np.int32, np.float32, np.int64, np.float64, np.bool_)): + op_ret = lhs.mapValues( + lambda blocks: + [ + op(blocks[bid], rhs) if bid in block_index_set else blocks[bid] + for bid in range(len(blocks)) + ] + ) + elif is_table(rhs): + op_ret = lhs.join(rhs, + lambda blocks1, blocks2: + [ + op(blocks1[bid], blocks2[rhs_block_id]) if bid in block_index_set else blocks1[bid] + for bid in 
range(len(blocks1)) + ] + ) + else: + raise ValueError(f"Not implement type between dataframe nad {type(rhs)}") + + return op_ret + + +def unary_operate(lhs, op, block_indexes): + block_index_set = set(block_indexes) + op_ret = lhs.mapValues( + lambda blocks: + [ + op(blocks[bid]) if bid in block_index_set + else blocks[bid] + for bid in range(len(blocks)) + ] + ) + + return op_ret diff --git a/python/fate/arch/dataframe/ops/utils/series_align.py b/python/fate/arch/dataframe/ops/utils/series_align.py new file mode 100644 index 0000000000..3ffe71b499 --- /dev/null +++ b/python/fate/arch/dataframe/ops/utils/series_align.py @@ -0,0 +1,32 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +from typing import List + + +def series_to_ndarray(series_obj: "pd.Series", fields_to_align: List[str]=None): + if isinstance(series_obj.index, pd.RangeIndex) or not fields_to_align: + return series_obj.values + else: + if len(series_obj) != len(fields_to_align): + raise ValueError(f"Can't not align fields, src={fields_to_align}, dst={series_obj}") + + indexer = series_obj.index.get_indexer(fields_to_align) + + return series_obj[indexer].values + + +def series_to_list(series_obj: "pd.Series", fields_to_align: List[str]=None): + return series_to_ndarray(series_obj, fields_to_align).tolist() diff --git a/python/fate/arch/dataframe/storage/_index.py b/python/fate/arch/dataframe/storage/_index.py deleted file mode 100644 index 088b8e84f1..0000000000 --- a/python/fate/arch/dataframe/storage/_index.py +++ /dev/null @@ -1,221 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import copy -import functools - -import pandas as pd -from fate.arch.computing import is_table - - -class Index(object): - def __init__(self, ctx, distributed_index, block_partition_mapping, global_ranks): - self._ctx = ctx - self._index_table = distributed_index - self._block_partition_mapping = block_partition_mapping - self._global_ranks = global_ranks - self._count = None - - @property - def global_ranks(self): - return self._global_ranks - - @property - def block_partition_mapping(self): - return self._block_partition_mapping - - @property - def values(self): - return self._index_table - - def count(self): - if self._count is not None: - return self._count - - self._count = self._index_table.count() - - return self._count - - def __len__(self): - return self.count() - - def tolist(self): - indexes_with_partition_id = sorted(self._index_table.collect(), key=lambda kv: kv[1]) - id_list = [k for k, v in indexes_with_partition_id] - - return id_list - - def to_local(self): - """ - index_table: id, (partition_id, block_index) - """ - index_table = self._index_table.mapValues( - lambda order_tuple: (0, self._global_ranks[order_tuple[0]]["start_index"] + order_tuple[1]) - ) - - global_ranks = [dict(start_index=0, end_index=self.count(), block_id=0)] - block_partition_mapping = copy.deepcopy(self._block_partition_mapping) - for block_id in self._block_partition_mapping: - if block_id != 0: - block_partition_mapping.pop(block_id) - - return Index(self._ctx, index_table, block_partition_mapping, global_ranks) - - def __getitem__(self, items): - if isinstance(items, int): - items = [items] - - # NOTE: it will not repartition automatically, user should call it in DataFrame if need - # TODO: make sure that items is non-overlapped in this version - if isinstance(items, list): - items_set = set(items) - index_table = self._index_table.filter(lambda kv: kv[0] in items_set) - - elif is_table(items): - index_table = self._index_table.join(items, lambda v1, v2: v2[1]) - else: - raise ValueError(f"get item does not support {type(items)}") - - agg_table = self.aggregate(index_table) - global_ranks = self.regenerate_global_ranks(agg_table, self._global_ranks) - block_partition_mapping = self.regenerate_block_partition_mapping(agg_table, global_ranks) - - def _flat_partition(k, values): - _flat_ret = [] - for idx, (_id, block_index) in enumerate(values): - _flat_ret.append((_id, (k, idx))) - - return _flat_ret - - _flat_func = functools.partial(_flat_partition) - index_table = agg_table.flatMap(_flat_func) - - return Index(self._ctx, index_table, block_partition_mapping, global_ranks) - - def get_indexer(self, ids, with_partition_id=True): - if isinstance(ids, list): - - def _filter_id(key, value, ids_set=None): - return key in ids_set - - filter_func = functools.partial(_filter_id, ids_set=set(ids)) - indexer = self._index_table.filter(filter_func) - indexer = indexer.mapValues(lambda v: [v, v]) - - elif is_table(ids): - """ """ - if with_partition_id: - indexer = self._index_table.join(ids, lambda v1, v2: [v1, v2]) - else: - indexer = self._index_table.join(ids, lambda v1, v2: v1) - indexer = indexer.mapValues(lambda v: [v, v]) - else: - raise ValueError(f"get_indexer's args type is {type(ids)}, is not supported") - - return indexer - - def change_indexes_to_indexer(self, indexes): - def _filter(k, v, index_set=None, global_ranks=None): - partition_id, block_index = v - return global_ranks[partition_id]["start_index"] + block_index in index_set - - filter_func = functools.partial(_filter, 
index_set=set(indexes), global_ranks=self._global_ranks) - indexer = self._index_table.filter(filter_func, use_previous_behavior=False) - indexer = indexer.mapValues(lambda v: [v, v]) - return indexer - - @classmethod - def aggregate(cls, table): - """ - agg_table: key=partition_id, value=(id, block_index), block_index may be not continuous - """ - - def _aggregate_ids(kvs): - aggregate_ret = dict() - - for k, v in kvs: - partition_id, block_index = v - if partition_id not in aggregate_ret: - aggregate_ret[partition_id] = [] - - aggregate_ret[partition_id].append((k, block_index)) - - return list(aggregate_ret.items()) - - agg_table = table.mapReducePartitions(_aggregate_ids, lambda l1, l2: l1 + l2) - agg_table = agg_table.mapValues( - lambda id_list: sorted(id_list, key=lambda block_index_with_key: block_index_with_key[1]) - ) - - return agg_table - - @classmethod - def aggregate_indexer(cls, indexer): - """ - key=id, value=[(old_partition_id, old_block_index), (new_partition_id, new_block_index)] - => - key=old_partition_id, value=[old_block_index, (new_partition_id, new_block_index)] - """ - - def _aggregate(kvs): - aggregate_ret = dict() - for k, values in kvs: - old_msg, new_msg = values - if old_msg[0] not in aggregate_ret: - aggregate_ret[old_msg[0]] = [] - - aggregate_ret[old_msg[0]].append([old_msg[1], new_msg]) - - return list(aggregate_ret.items()) - - agg_indexer = indexer.mapReducePartitions(_aggregate, lambda l1, l2: l1 + l2) - agg_indexer = agg_indexer.mapValues(lambda v: sorted(v)) - return agg_indexer - - @classmethod - def regenerate_global_ranks(cls, agg_table, old_global_ranks): - """ - input should be agg_table instead of index_table - """ - block_counts = sorted(list(agg_table.mapValues(lambda v: len(v)).collect())) - global_ranks = [] - - idx = 0 - for block_id, block_count in block_counts: - if global_ranks and global_ranks[-1]["block_id"] + 1 != block_id: - last_bid = global_ranks[-1]["block_id"] - for bid in range(last_bid + 1, block_id): - global_ranks.append(dict(start_index=idx, end_index=idx - 1, block_id=bid)) - - global_ranks.append(dict(start_index=idx, end_index=idx + block_count - 1, block_id=block_id)) - idx += block_count - - if len(global_ranks) < len(old_global_ranks): - last_bid = len(global_ranks) - for bid in range(last_bid, len(old_global_ranks)): - global_ranks.append(dict(start_index=idx, end_index=idx - 1, block_id=bid)) - - return global_ranks - - @classmethod - def regenerate_block_partition_mapping(cls, agg_table, global_ranks): - """ - input should be agg_table instead of index_table - """ - blocks = agg_table.mapValues(lambda v: v[0]).collect() - block_partition_mapping = dict() - for partition_id, (key, block_index) in blocks: - block_partition_mapping[key] = global_ranks[partition_id] - - return block_partition_mapping diff --git a/python/fate/arch/dataframe/storage/_value_store.py b/python/fate/arch/dataframe/storage/_value_store.py deleted file mode 100644 index dfbbb5d513..0000000000 --- a/python/fate/arch/dataframe/storage/_value_store.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools - -import pandas as pd - - -class ValueStore(object): - def __init__(self, ctx, distributed_table, header): - self._ctx = ctx - self._header = header - self._data = distributed_table - self._dtypes = None - - def to_local(self, keep_table=False): - if self._data.partitions == 1 and keep_table: - return self - - frames = [frame for partition_id, frame in sorted(self._data.collect())] - concat_frame = pd.concat(frames) - - if not keep_table: - return concat_frame - else: - table = self._ctx.computing.parallelize([(0, concat_frame)], include_key=True, partition=1) - - return ValueStore(self._ctx, table, self._header) - - def __getattr__(self, attr): - if attr not in self._header: - raise ValueError(f"ValueStore does not has attribute: {attr}") - - return ValueStore(self._ctx, self._data.mapValues(lambda df: df[attr]), [attr]) - - def tolist(self): - return self.to_local().tolist() - - @property - def dtypes(self): - if self._dtypes is None: - self._dtypes = self._data.first()[1].dtypes - - return self._dtypes - - @property - def values(self): - return self._data diff --git a/python/fate/arch/dataframe/utils/__init__.py b/python/fate/arch/dataframe/utils/__init__.py index 9b670c6d5b..8c6565ac8f 100644 --- a/python/fate/arch/dataframe/utils/__init__.py +++ b/python/fate/arch/dataframe/utils/__init__.py @@ -13,3 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from ._dataloader import DataLoader +from ._dataloader import BatchEncoding +from ._k_fold import KFold +from ._sample import federated_sample +from ._sample import local_sample diff --git a/python/fate/components/spec/artifact.py b/python/fate/arch/dataframe/utils/_auto_column_name_generated.py similarity index 78% rename from python/fate/components/spec/artifact.py rename to python/fate/arch/dataframe/utils/_auto_column_name_generated.py index 681c63915c..c8e45b8b6e 100644 --- a/python/fate/components/spec/artifact.py +++ b/python/fate/arch/dataframe/utils/_auto_column_name_generated.py @@ -12,12 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +DEFAULT_COL_NAME_PREFIX = "default_col_" -import pydantic - - -class ArtifactSpec(pydantic.BaseModel): - name: str - uri: str - metadata: Optional[dict] = None +def generated_default_column_names(column_length): + return [DEFAULT_COL_NAME_PREFIX + str(i) for i in range(column_length)] diff --git a/python/fate/arch/dataframe/utils/_dataloader.py b/python/fate/arch/dataframe/utils/_dataloader.py index 1ba7bd40cd..940727b9dc 100644 --- a/python/fate/arch/dataframe/utils/_dataloader.py +++ b/python/fate/arch/dataframe/utils/_dataloader.py @@ -14,12 +14,6 @@ # limitations under the License. 
import random -import numpy as np -import pandas as pd -import torch -from fate.arch.context.io.data import df -from fate.arch.dataframe import PandasReader, TorchDataSetReader - class DataLoader(object): def __init__( @@ -33,16 +27,19 @@ def __init__( batch_size=-1, shuffle=False, batch_strategy="full", - random_seed=None, + random_state=None, ): self._ctx = ctx self._dataset = dataset self._batch_size = batch_size if dataset: - self._batch_size = min(batch_size, len(dataset)) + if batch_size is None: + self._batch_size = len(dataset) + else: + self._batch_size = min(batch_size, len(dataset)) self._shuffle = shuffle self._batch_strategy = batch_strategy - self._random_seed = random_seed + self._random_state = random_state self._need_align = need_align self._mode = mode self._role = role @@ -51,19 +48,6 @@ def __init__( self._init_settings() def _init_settings(self): - if isinstance(self._dataset, df.Dataframe): - self._dataset = self._dataset.data - """ - if isinstance(self._dataset, pd.DataFrame): - self._dataset = PandasReader().to_frame(self._ctx, self._dataset) - elif isinstance(self._dataset, (np.ndarray, list)): - self._dataset = pd.DataFrame(np.ndarray) - self._dataset = PandasReader().to_frame(self._ctx, self._dataset) - elif isinstance(self._dataset, torch.utils.data.Dataset): - # Note: torch dataset items' order should be X, y - self._dataset = TorchDataSetReader().to_frame(self._ctx, self._dataset) - """ - if self._batch_strategy == "full": self._batch_generator = FullBatchDataLoader( self._dataset, @@ -72,50 +56,37 @@ def _init_settings(self): role=self._role, batch_size=self._batch_size, shuffle=self._shuffle, - random_seed=self._random_seed, + random_state=self._random_state, need_align=self._need_align, sync_arbiter=self._sync_arbiter, ) else: raise ValueError(f"batch strategy {self._batch_strategy} is not support") - def next_batch(self, with_index=True): - batch = next(self._batch_generator) - if with_index: - return batch - else: - return batch[1:] - @staticmethod def batch_num(self): return self._batch_generator.batch_num def __next__(self): for batch in self._batch_generator: - if len(batch[1:]) == 1: - yield batch[1] - else: - yield batch[1:] + yield batch def __iter__(self): for batch in self._batch_generator: - if len(batch[1:]) == 1: - yield batch[1] - else: - yield batch[1:] + yield batch class FullBatchDataLoader(object): - def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_seed, need_align, sync_arbiter): + def __init__(self, dataset, ctx, mode, role, batch_size, shuffle, random_state, need_align, sync_arbiter): self._dataset = dataset self._ctx = ctx self._mode = mode self._role = role self._batch_size = batch_size - if self._batch_size < 0 and self._role != "arbiter": + if self._batch_size is None and self._role != "arbiter": self._batch_size = len(self._dataset) self._shuffle = shuffle - self._random_seed = random_seed + self._random_state = random_state self._need_align = need_align self._sync_arbiter = sync_arbiter @@ -135,7 +106,7 @@ def _prepare(self): elif self._mode == "local": self._batch_num = (len(self._dataset) + self._batch_size - 1) // self._batch_size elif self._mode == "hetero": - # TODO: index should be align first + # NOTE: index should be align first, using after doing psi if self._role != "arbiter": self._batch_num = (len(self._dataset) + self._batch_size - 1) // self._batch_size if self._role == "guest" and self._sync_arbiter: @@ -147,58 +118,75 @@ def _prepare(self): return if self._batch_size == len(self._dataset): - 
self._batch_splits.append(self._dataset) + self._batch_splits.append(BatchEncoding(self._dataset, batch_id=0)) else: if self._mode in ["homo", "local"] or self._role == "guest": - indexes = self._dataset.index.tolist() - + indexer = sorted(list(self._dataset.get_indexer(target="sample_id").collect())) if self._shuffle: - random.seed = self._random_seed - random.shuffle(indexes) + random.seed = self._random_state + random.shuffle(indexer) + + for i, iter_ctx in self._ctx.sub_ctx("dataloader_batch").ctxs_range(self._batch_num): + batch_indexer = indexer[self._batch_size * i: self._batch_size * (i + 1)] + batch_indexer = self._ctx.computing.parallelize(batch_indexer, + include_key=True, + partition=self._dataset.block_table.partitions) - for i, iter_ctx in self._ctx.range(self._batch_num): - batch_indexes = indexes[self._batch_size * i : self._batch_size * (i + 1)] + sub_frame = self._dataset.loc(batch_indexer, preserve_order=False) - sub_frame = self._dataset.loc(batch_indexes) + if self._mode == "hetero" and self._role == "guest": + iter_ctx.hosts.put("batch_indexes", sub_frame.get_indexer(target="sample_id")) - if self._role == "guest": - iter_ctx.hosts.put("batch_indexes", batch_indexes) + self._batch_splits.append(BatchEncoding(sub_frame, batch_id=i)) - self._batch_splits.append(sub_frame) elif self._mode == "hetero" and self._role == "host": - for i, iter_ctx in self._ctx.range(self._batch_num): + for i, iter_ctx in self._ctx.sub_ctx("dataloader_batch").ctxs_range(self._batch_num): batch_indexes = iter_ctx.guest.get("batch_indexes") - sub_frame = self._dataset.loc(batch_indexes) - self._batch_splits.append(sub_frame) + sub_frame = self._dataset.loc(batch_indexes, preserve_order=True) + self._batch_splits.append(BatchEncoding(sub_frame, batch_id=i)) def __next__(self): if self._role == "arbiter": for batch_id in range(self._batch_num): - yield batch_id, batch_id + yield BatchEncoding(batch_id=batch_id) return for batch in self._batch_splits: - if batch.label and batch.weight: - yield batch.index, batch.values, batch.label, batch.weight - elif batch.label: - yield batch.index, batch.values, batch.label - else: - yield batch.index, batch.values + yield batch def __iter__(self): - if self._role == "arbiter": - for batch_id in range(self._batch_num): - yield batch_id, batch_id - return - - for batch in self._batch_splits: - if batch.label and batch.weight: - yield batch.index, batch.values, batch.label, batch.weight - elif batch.label: - yield batch.index, batch.values, batch.label - else: - yield batch.index, batch.values + return self.__next__() @property def batch_num(self): return self._batch_num + + +class BatchEncoding(object): + def __init__(self, batch_df=None, batch_id=None): + if batch_df: + self._x = batch_df.values.as_tensor() + self._label = batch_df.label.as_tensor() if batch_df.label else None + self._weight = batch_df.weight.as_tensor() if batch_df.weight else None + else: + self._x = None + self._label = None + self._weight = None + + self._batch_id = batch_id + + @property + def x(self): + return self._x + + @property + def label(self): + return self._label + + @property + def weight(self): + return self._weight + + @property + def batch_id(self): + return self._batch_id diff --git a/python/fate/interface/_gc.py b/python/fate/arch/dataframe/utils/_id_generator.py similarity index 68% rename from python/fate/interface/_gc.py rename to python/fate/arch/dataframe/utils/_id_generator.py index ee64bef1ee..4eac65377f 100644 --- a/python/fate/interface/_gc.py +++ 
b/python/fate/arch/dataframe/utils/_id_generator.py @@ -12,12 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol +# +import time +import random +import hashlib + +def generate_sample_id(n, prefix): + return [hashlib.sha256(bytes(prefix + str(i), encoding='utf-8')).hexdigest() for i in range(n)] -class GarbageCollector(Protocol): - def register_clean_action(self, name: str, tag: str, obj, method: str, kwargs): - ... - def clean(self, name: str, tag: str): - ... +def generate_sample_id_prefix(): + return str(time.time()) + str(random.randint(1000000, 9999999)) diff --git a/python/fate/arch/dataframe/utils/_k_fold.py b/python/fate/arch/dataframe/utils/_k_fold.py new file mode 100644 index 0000000000..9d0b58f741 --- /dev/null +++ b/python/fate/arch/dataframe/utils/_k_fold.py @@ -0,0 +1,87 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .._dataframe import DataFrame +from sklearn.model_selection import KFold as sk_KFold + + +class KFold(object): + def __init__(self, + ctx, + mode="hetero", + role="guest", + n_splits=5, + shuffle=False, + random_state=None): + self._ctx = ctx + self._mode = mode + self._role = role + self._n_splits = n_splits + self._shuffle = shuffle + self._random_state = random_state + + self._check_param() + + def split(self, df: DataFrame): + if self._mode == "hetero": + return self._hetero_split(df) + else: + return self._homo_split(df, return_indexer=False) + + def _hetero_split(self, df: DataFrame): + if self._role == "guest": + homo_splits = self._homo_split(df, return_indexer=True) + for _, iter_ctx in self._ctx.sub_ctx("KFold").ctxs_range(self._n_splits): + train_frame, test_frame, train_indexer, test_indexer = next(homo_splits) + + iter_ctx.hosts.put("fold_indexes", (train_indexer, test_indexer)) + + yield train_frame, test_frame + else: + for _, iter_ctx in self._ctx.sub_ctx("KFold").ctxs_range(self._n_splits): + train_indexer, test_indexer = iter_ctx.guest.get("fold_indexes") + train_frame = df.loc(train_indexer, preserve_order=True) + test_frame = df.loc(test_indexer, preserve_order=True) + + yield train_frame, test_frame + + def _homo_split(self, df: DataFrame, return_indexer): + kf = sk_KFold(n_splits=self._n_splits, shuffle=self._shuffle, random_state=self._random_state) + indexer = list(df.get_indexer(target="sample_id").collect()) + + for train, test in kf.split(indexer): + train_indexer = [indexer[idx] for idx in train] + test_indexer = [indexer[idx] for idx in test] + + train_indexer = self._ctx.computing.parallelize(train_indexer, + include_key=True, + partition=df.block_table.partitions) + + test_indexer = self._ctx.computing.parallelize(test_indexer, + include_key=True, + partition=df.block_table.partitions) + + train_frame = df.loc(train_indexer) + test_frame = df.loc(test_indexer) + + if return_indexer: + yield 
train_frame, test_frame, \ + train_frame.get_indexer(target="sample_id"), test_frame.get_indexer(target="sample_id") + else: + yield train_frame, test_frame + + def _check_param(self): + if not isinstance(self._n_splits, int) or self._n_splits < 2: + raise ValueError("n_splits should be positive integer >= 2") diff --git a/python/fate/arch/dataframe/utils/_sample.py b/python/fate/arch/dataframe/utils/_sample.py new file mode 100644 index 0000000000..ce67ea00b2 --- /dev/null +++ b/python/fate/arch/dataframe/utils/_sample.py @@ -0,0 +1,229 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +from typing import Union, Dict, Any + +from sklearn.utils import resample + +from ._id_generator import generate_sample_id, generate_sample_id_prefix +from .._dataframe import DataFrame + +REGENERATED_TAG = "regenerated_index" +SAMPLE_INDEX_TAG = "sample_index" +REGENERATED_IDS = "regenerated_ids" + + +def local_sample( + ctx, + df: DataFrame, + n: int=None, + frac: Union[float, Dict[Any, float]] = None, + replace: bool = True, + random_state=None +): + return _sample_guest(ctx, df, n, frac, replace, random_state, sync=False) + + +def federated_sample( + ctx, + df: DataFrame, + n: int = None, + frac: Union[float, Dict[Any, float]] = None, + replace: bool = True, + random_state=None, + role: str = "guest"): + if role == "guest": + return _sample_guest(ctx, df, n, frac, replace, random_state, sync=True) + else: + return _federated_sample_host(ctx, df) + + +def _sample_guest( + ctx, + df: DataFrame, + n: int = None, + frac: Union[float, Dict[Any, float]] = None, + replace: bool = True, + random_state=None, + sync=True, +): + if n is not None and frac is not None: + raise ValueError("sample's parameters n and frac should not be set in the same time.") + + if frac is not None: + if isinstance(frac, float): + if frac > 1: + raise ValueError(f"sample's parameter frac={frac} should <= 1.0") + n = max(1, int(frac * df.shape[0])) + else: + for k, f in frac.items(): + if f > 1 and replace is False: + raise ValueError(f"sample's parameter frac's label={k}, fraction={f} " + f"should <= 1.0 if replace=False") + + if n is not None: + if n > df.shape[0] and replace is False: + raise ValueError(f"sample's parameter n={n} should <= data_size={df.shape[0]} if replace=False") + + if replace: + choices = resample(list(range(df.shape[0])), replace=True, n_samples=n, random_state=random_state) + indexer = list(df.get_indexer(target="sample_id").collect()) + regenerated_sample_id_prefix = generate_sample_id_prefix() + regenerated_ids = generate_sample_id(n, regenerated_sample_id_prefix) + choice_with_regenerated_ids = _agg_choices(ctx, + indexer, + choices, + regenerated_ids, + df.block_table.partitions) + + if sync: + ctx.hosts.put(REGENERATED_TAG, True) + ctx.hosts.put(REGENERATED_IDS, choice_with_regenerated_ids) + + regenerated_raw_table = _regenerated_sample_ids(df, choice_with_regenerated_ids) + sample_df = _convert_raw_table_to_df(df._ctx, 
regenerated_raw_table, df.data_manager) + if sync: + sample_indexer = sample_df.get_indexer(target="sample_id") + ctx.hosts.put(SAMPLE_INDEX_TAG, sample_indexer) + + else: + sample_df = df.sample(n=n, random_state=random_state) + if sync: + sample_indexer = sample_df.get_indexer(target="sample_id") + ctx.hosts.put(REGENERATED_TAG, False) + ctx.hosts.put(SAMPLE_INDEX_TAG, sample_indexer) + else: + up_sample = False + for label, f in frac.items(): + if f > 1.0: + up_sample = True + + if up_sample: + regenerated_sample_id_prefix = generate_sample_id_prefix() + choice_with_regenerated_ids = None + for label, f in frac.items(): + label_df = df.iloc(df.label == label) + label_n = max(1, int(label_df.shape[0] * f)) + choices = resample(list(range(label_df.shape[0])), replace=True, + n_samples=label_n, random_state=random_state) + label_indexer = list(label_df.get_indexer(target="sample_id").collect()) + regenerated_ids = generate_sample_id(label_n, regenerated_sample_id_prefix) + label_choice_with_regenerated_ids = _agg_choices(ctx, label_indexer, choices, + regenerated_ids, df.block_table.partitions) + if choice_with_regenerated_ids is None: + choice_with_regenerated_ids = label_choice_with_regenerated_ids + else: + choice_with_regenerated_ids = choice_with_regenerated_ids.union(label_choice_with_regenerated_ids) + + if sync: + ctx.hosts.put(REGENERATED_TAG, True) + ctx.hosts.put(REGENERATED_IDS, choice_with_regenerated_ids) + regenerated_raw_table = _regenerated_sample_ids(df, choice_with_regenerated_ids) + sample_df = _convert_raw_table_to_df(df._ctx, regenerated_raw_table, df.data_manager) + if sync: + sample_indexer = sample_df.get_indexer(target="sample_id") + ctx.hosts.put(SAMPLE_INDEX_TAG, sample_indexer) + else: + sample_df = None + for label, f in frac.items(): + label_df = df.iloc(df.label == label) + label_n = max(1, int(label_df.shape[0] * f)) + sample_label_df = label_df.sample(n=label_n, random_state=random_state) + + if sample_df is None: + sample_df = sample_label_df + else: + sample_df = DataFrame.vstack([sample_df, sample_label_df]) + + if sync: + sample_indexer = sample_df.get_indexer(target="sample_id") + ctx.hosts.put(REGENERATED_TAG, False) + ctx.hosts.put(SAMPLE_INDEX_TAG, sample_indexer) + + return sample_df + + +def _federated_sample_host( + ctx, + df: DataFrame +): + regenerated_tag = ctx.guest.get(REGENERATED_TAG) + if regenerated_tag is False: + sample_indexer = ctx.guest.get(SAMPLE_INDEX_TAG) + sample_df = df.loc(sample_indexer, preserve_order=True) + else: + regenerated_ids = ctx.guest.get(REGENERATED_IDS) + regenerated_raw_table = _regenerated_sample_ids(df, regenerated_ids) + sample_df = _convert_raw_table_to_df(df._ctx, regenerated_raw_table, df.data_manager) + + sample_indexer = ctx.guest.get(SAMPLE_INDEX_TAG) + sample_df = sample_df.loc(sample_indexer, preserve_order=True) + + return sample_df + + +def _regenerated_sample_ids(df, regenerated_ids): + from ..ops._indexer import regenerated_sample_id + regenerated_raw_table = regenerated_sample_id(df.block_table, regenerated_ids, df.data_manager) + + return regenerated_raw_table + + +def _convert_raw_table_to_df( + ctx, + table, + data_manager +): + from ..ops._indexer import get_partition_order_by_raw_table + from ..ops._dimension_scaling import to_blocks + partition_order_mapping = get_partition_order_by_raw_table(table, data_manager.block_row_size) + to_block_func = functools.partial(to_blocks, dm=data_manager, partition_mappings=partition_order_mapping) + block_table = table.mapPartitions(to_block_func, + 
use_previous_behavior=False) + + return DataFrame( + ctx, + block_table, + partition_order_mapping, + data_manager + ) + + +def _agg_choices(ctx, + indexer, + choices, + regenerated_ids, + partition): + """ + indexer: (sample_id, (partition_id, block_offset)) + """ + choice_dict = dict() + choice_indexer = [] + for idx, choice in enumerate(choices): + if choice not in choice_dict: + current_l = len(choice_dict) + choice_dict[choice] = current_l + choice_indexer.append([]) + + choice_indexer[choice_dict[choice]].append(regenerated_ids[idx]) + + for choice, idx in choice_dict.items(): + choice_regenerated_sample_ids = choice_indexer[idx] + choice_indexer[idx] = (indexer[choice][0], choice_regenerated_sample_ids) + + return ctx.computing.parallelize(choice_indexer, + include_key=True, + partition=partition) diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index a20d642676..8e924d12a0 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -12,6 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._type import FederationDataType, FederationEngine +from ._type import FederationDataType -__all__ = ["FederationEngine", "FederationDataType"] +__all__ = ["FederationDataType"] diff --git a/python/fate/arch/federation/_federation.py b/python/fate/arch/federation/_federation.py index 94ec2acb16..354c2b1bb8 100644 --- a/python/fate/arch/federation/_federation.py +++ b/python/fate/arch/federation/_federation.py @@ -22,9 +22,8 @@ from pickle import dumps as p_dumps from pickle import loads as p_loads -from fate.interface import FederationEngine, PartyMeta +from fate.arch.abc import CTableABC, FederationEngine, PartyMeta -from ..computing import CTableABC from ..federation import FederationDataType from ..federation._datastream import Datastream from ._gc import GarbageCollector diff --git a/python/fate/arch/federation/_gc.py b/python/fate/arch/federation/_gc.py index df47ee528b..27a47e42ab 100644 --- a/python/fate/arch/federation/_gc.py +++ b/python/fate/arch/federation/_gc.py @@ -18,7 +18,7 @@ import typing from collections import deque -from fate.interface import GarbageCollector as GarbageCollectorInterface +from fate.arch.abc import GarbageCollector as GarbageCollectorInterface LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/_type.py b/python/fate/arch/federation/_type.py index d4e7239195..fec2f95f3f 100644 --- a/python/fate/arch/federation/_type.py +++ b/python/fate/arch/federation/_type.py @@ -15,13 +15,6 @@ # -class FederationEngine(object): - EGGROLL = "EGGROLL" - RABBITMQ = "RABBITMQ" - STANDALONE = "STANDALONE" - PULSAR = "PULSAR" - - class FederationDataType(object): OBJECT = "obj" TABLE = "Table" diff --git a/python/fate/arch/federation/eggroll/_federation.py b/python/fate/arch/federation/eggroll/_federation.py index 30ccb05103..d210d0de6a 100644 --- a/python/fate/arch/federation/eggroll/_federation.py +++ b/python/fate/arch/federation/eggroll/_federation.py @@ -23,7 +23,7 @@ from eggroll.roll_pair.roll_pair import RollPair from eggroll.roll_site.roll_site import RollSiteContext -from fate.interface import FederationEngine, PartyMeta +from fate.arch.abc import FederationEngine, PartyMeta from ...computing.eggroll import Table from .._gc import GarbageCollector diff --git a/python/fate/arch/federation/osx/_federation.py 
b/python/fate/arch/federation/osx/_federation.py index 3969496a18..02b85e597d 100644 --- a/python/fate/arch/federation/osx/_federation.py +++ b/python/fate/arch/federation/osx/_federation.py @@ -17,8 +17,8 @@ import typing from logging import getLogger +from fate.arch.abc import PartyMeta from fate.arch.federation.osx import osx_pb2 -from fate.interface import PartyMeta from .._federation import FederationBase from .._nretry import nretry @@ -146,7 +146,7 @@ def _query_receive_topic(self, channel_info): LOGGER.debug(f"_query_receive_topic, channel_info={channel_info}") topic = channel_info._receive_topic if topic not in self._topic_ip_map: - LOGGER.info("query topic miss cache") + LOGGER.info(f"query topic {topic} miss cache ") response = channel_info.query() if response.code == "0": topic_info = osx_pb2.TopicInfo() diff --git a/python/fate/arch/federation/osx/_mq_channel.py b/python/fate/arch/federation/osx/_mq_channel.py index cacef6d8b3..b1a90cfcca 100644 --- a/python/fate/arch/federation/osx/_mq_channel.py +++ b/python/fate/arch/federation/osx/_mq_channel.py @@ -35,6 +35,7 @@ def __init__( self._namespace = namespace self._send_topic = send_topic self._receive_topic = receive_topic + self._index = 1 self._src_party_id = src_party_id self._src_role = src_role self._dst_party_id = dst_party_id @@ -51,7 +52,7 @@ def consume(self, offset=-1): LOGGER.debug(f"consume, offset={offset}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._send_topic, + MessageTopic=self._receive_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id, @@ -64,9 +65,8 @@ def consume(self, offset=-1): inbound = osx_pb2.Inbound(metadata=meta) LOGGER.debug(f"consume, inbound={inbound}, mq={self}") result = self._stub.invoke(inbound) - LOGGER.debug(f"consume, result={result}, mq={self}") - print(result) - print(result.code) + LOGGER.debug(f"consume, result={result.code}, mq={self}") + return result @nretry @@ -91,10 +91,10 @@ def query(self): @nretry def produce(self, body, properties): - LOGGER.debug(f"produce body={body}, properties={properties}, mq={self}") + # LOGGER.debug(f"produce body={body}, properties={properties}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._receive_topic, + MessageTopic=self._send_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id, @@ -105,9 +105,13 @@ def produce(self, body, properties): ) msg = osx_pb2.Message(head=bytes(json.dumps(properties), encoding="utf-8"), body=body) inbound = osx_pb2.Inbound(metadata=meta, payload=msg.SerializeToString()) - LOGGER.debug(f"produce inbound={inbound}, mq={self}") + # LOGGER.debug(f"produce inbound={inbound}, mq={self}") result = self._stub.invoke(inbound) - LOGGER.debug(f"produce result={result}, mq={self}") + + LOGGER.debug(f"produce {self._receive_topic} index {self._index} result={result.code}, mq={self}") + if result.code!="0": + raise RuntimeError(f"produce msg error ,code : {result.code} msg : {result.message}") + self._index+=1 return result @nretry @@ -115,7 +119,7 @@ def ack(self, offset): LOGGER.debug(f"ack offset={offset}, mq={self}") self._get_or_create_channel() meta = dict( - MessageTopic=self._send_topic, + MessageTopic=self._receive_topic, TechProviderCode="FATE", SourceNodeID=self._src_party_id, TargetNodeID=self._dst_party_id, @@ -126,7 +130,7 @@ def ack(self, offset): MessageOffSet=offset, ) inbound = osx_pb2.Inbound(metadata=meta) - LOGGER.debug(f"ack inbound={inbound}, mq={self}") + # 
LOGGER.debug(f"ack inbound={inbound}, mq={self}") result = self._stub.invoke(inbound) LOGGER.debug(f"ack result={result}, mq={self}") return result diff --git a/python/fate/arch/federation/pulsar/_federation.py b/python/fate/arch/federation/pulsar/_federation.py index 6f2bd76f6b..3db33204a7 100644 --- a/python/fate/arch/federation/pulsar/_federation.py +++ b/python/fate/arch/federation/pulsar/_federation.py @@ -17,7 +17,7 @@ import logging from typing import List, Optional -from fate.interface import PartyMeta +from fate.arch.abc import PartyMeta from .._federation import FederationBase from ._mq_channel import ( diff --git a/python/fate/arch/federation/rabbitmq/_federation.py b/python/fate/arch/federation/rabbitmq/_federation.py index c371908b7e..5222b8ac56 100644 --- a/python/fate/arch/federation/rabbitmq/_federation.py +++ b/python/fate/arch/federation/rabbitmq/_federation.py @@ -18,7 +18,7 @@ from logging import getLogger from typing import List, Optional -from fate.interface import PartyMeta +from fate.arch.abc import PartyMeta from .._federation import FederationBase from .._parties import Party diff --git a/python/fate/arch/federation/standalone/_federation.py b/python/fate/arch/federation/standalone/_federation.py index 266c2136d4..454195f732 100644 --- a/python/fate/arch/federation/standalone/_federation.py +++ b/python/fate/arch/federation/standalone/_federation.py @@ -15,7 +15,7 @@ import logging from typing import List, Tuple -from fate.interface import FederationEngine, PartyMeta +from fate.arch.abc import FederationEngine, PartyMeta from ..._standalone import Federation as RawFederation from ..._standalone import Table as RawTable diff --git a/python/fate/arch/histogram/__init__.py b/python/fate/arch/histogram/__init__.py new file mode 100644 index 0000000000..b299e04935 --- /dev/null +++ b/python/fate/arch/histogram/__init__.py @@ -0,0 +1,3 @@ +from ._histogram_distributed import DistributedHistogram +from ._histogram_local import Histogram +from ._histogram_sbt import HistogramBuilder diff --git a/python/fate/arch/histogram/_histogram_distributed.py b/python/fate/arch/histogram/_histogram_distributed.py new file mode 100644 index 0000000000..46fccdd603 --- /dev/null +++ b/python/fate/arch/histogram/_histogram_distributed.py @@ -0,0 +1,273 @@ +import logging +import typing +from typing import List, MutableMapping, Tuple, Optional + +import torch + +from fate.arch.abc import CTableABC +from ._histogram_local import Histogram +from ._histogram_splits import HistogramSplits +from .indexer import HistogramIndexer, Shuffler + +logger = logging.getLogger(__name__) + + +def _decrypt_func(sk_map, coder_map, squeezed, unpacker_map): + def _decrypt(split: HistogramSplits): + split = split.decrypt(sk_map) + if unpacker_map is not None: + split.i_unpack_decode(unpacker_map, squeezed) + return split + else: + split.i_decode(coder_map) + return split + + return _decrypt + + +class DistributedHistogram: + def __init__( + self, + splits: CTableABC[int, HistogramSplits], + k, + node_size, + node_data_size, + global_seed, + seed=None, + squeezed=False, + shuffled=False, + ): + self._splits = splits + self._k = k + self._node_size = node_size + self._node_data_size = node_data_size + self._squeezed = squeezed + self._shuffled = shuffled + self._seed = seed + self._global_seed = global_seed + + def __getstate__(self): + """ + Get the state for pickle. 
+ + remove sensitive data before sending to other parties, such as: + - global_seed + """ + return self._splits, self._k, self._node_size, self._node_data_size, self._squeezed, self._shuffled + + def __setstate__(self, state): + self._splits, self._k, self._node_size, self._node_data_size, self._squeezed, self._shuffled = state + + def i_squeeze(self, squeeze_map: MutableMapping[str, typing.Tuple[int, int]]): + """ + Squeeze the histogram values. + + Args: + squeeze_map: name -> (pack_num, offset_bit) + """ + self._splits = self._splits.mapValues(lambda split: split.i_squeeze(squeeze_map)) + self._squeezed = True + + def i_shuffle_splits(self): + """ + Shuffle the histogram splits values. + """ + seed = self._seed + if seed is None: + return + self._splits = self._splits.mapValues(lambda split: split.i_shuffle(seed, False)) + self._shuffled = True + + def shuffle_splits(self): + """ + Shuffle the histogram splits values, return a new DistributedHistogram. + """ + seed = self._seed + if seed is None: + return self + splits = self._splits.mapValues(lambda split: split.shuffle(seed, False)) + return DistributedHistogram( + splits, self._k, self._node_size, self._node_data_size, self._global_seed, self._seed, self._squeezed, True + ) + + def compute_child(self, weak_child: "DistributedHistogram", mapping: List[Tuple[int, int, int, int]]): + """ + Compute the child histogram. + + Args: + weak_child: the splits of one child + mapping: the mapping from parent to child, + the mapping is a list of (parent_pos, weak_child_pos, target_weak_child_pos, target_strong_child_pos) + which means in logic: + output[target_weak_child_pos] = weak_child[weak_child_pos] + output[target_strong_child_pos] = self[parent_pos] - weak_child[weak_child_pos] + Examples: + # tree structure: + # -1 + # 0 1 + # 2 3 4 5 <-- parent nodes, node #4 is leaf node + # 6 *7 *8 9 *10 11 <-- child nodes, node #7, #8, #10 are weak child nodes + # + # pos parent_node weak_child_node output_node + # 0 #2 #7 #6 + # 1 #3 #8 #7 + # 2 #4 #10 #8 + # 3 #5 #9 + # 4 #10 + # 5 #11 + >>> parent = DistributedHistogram(...) # data for nodes stored in order [#2, #3, #4, #5] + >>> weak_child = DistributedHistogram(...) # data for nodes stored in order [#7, #8, #10] + >>> mapping = [ + >>> (0, 0, 1, 0), # pos for (#2, #7, #7, #6) + >>> (1, 1, 2, 3), # pos for (#3, #8, #8, #9) + >>> (3, 2, 4, 5) # pos for (#5, #10, #10, #11) + >>> ] + >>> child = parent.compute_child(weak_child, mapping) # data for nodes stored in order [#6, #7, #8, #9, #10, #11] + """ + # assert self._node_size == weak_child._node_size, 'node size not match, {} != {}'.format( + # self._node_size, weak_child._node_size + # ) + assert self._node_data_size == weak_child._node_data_size + splits = self._splits.join(weak_child._splits, lambda x, y: x.compute_child_splits(y, mapping)) + return DistributedHistogram( + splits, + weak_child._k, + len(mapping) * 2, + weak_child._node_data_size, + weak_child._global_seed, + weak_child._seed, + ) + + def i_sub_on_key(self, from_key: str, to_key: str): + """ + Subtract the histogram splits values on the given key. + + Args: + from_key: the start key + to_key: the end key + """ + self._splits = self._splits.mapValues(lambda split: split.i_sub_on_key(from_key, to_key)) + + def recover_feature_bins( + self, feature_bin_sizes, split_points: typing.Dict[int, int] + ) -> typing.Dict[int, typing.Tuple[int, int]]: + """ + Recover the feature bins from the split points. 
+ + Args: + feature_bin_sizes: the feature bin sizes + split_points: nid -> split data index + + Returns: + nid -> (fid, bid) + """ + + if self._shuffled: + split_points = self._recover_from_split_shuffle(split_points) + + split_points = self._recover_from_global_shuffle(split_points) + return self._recover_histogram_position(split_points, feature_bin_sizes) + + def _recover_histogram_position( + self, split_points: typing.Dict[int, int], feature_bin_sizes + ) -> typing.Dict[int, typing.Tuple[int, int]]: + fid_bid = {} + indexer = HistogramIndexer(self._node_size, feature_bin_sizes) + for nid, index in split_points.items(): + _, fid, bid = indexer.get_reverse_position(index) + fid_bid[nid] = (fid, bid) + return fid_bid + + def _recover_from_global_shuffle(self, split_points: MutableMapping[int, int]): + if self._global_seed is None: + return split_points + shuffler = Shuffler(self._node_size, self._node_data_size, self._global_seed) + points = list(split_points.items()) + real_indexes = shuffler.get_reverse_indexes(step=1, indexes=[p[1] for p in points]) + out = {} + for (nid, _), index in zip(points, real_indexes): + out[nid] = index + return out + + def _recover_from_split_shuffle(self, split_points): + splits_info = list(self._splits_into_k(self._node_data_size, self._k)) + _size_mapping = {} + out = {} + for nid in split_points: + index = split_points[nid] + for split_info in splits_info: + if split_info[0] <= index < split_info[1]: + if split_info[1] - split_info[0] not in _size_mapping: + _size_mapping[split_info[1] - split_info[0]] = [(split_info, nid, index - split_info[0])] + else: + _size_mapping[split_info[1] - split_info[0]].append((split_info, nid, index - split_info[0])) + for size in _size_mapping: + shuffler = Shuffler(self._node_size, size, self._seed) + fixed_size_splits = _size_mapping[size] + for (split_info, nid, _), i in zip( + fixed_size_splits, shuffler.get_reverse_indexes(step=1, indexes=[p[2] for p in fixed_size_splits]) + ): + out[nid] = split_info[0] + i + return out + + @staticmethod + def _splits_into_k(n, k: int): + d, r = divmod(n, k) + start = 0 + for _ in range(k): + end = start + d + (r > 0) + yield start, end + start = end + r -= 1 + + def decrypt( + self, + sk_map: MutableMapping[str, typing.Any], + coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]], + unpacker_map: Optional[MutableMapping[str, typing.Tuple[typing.Any, int, int, int, int, int]]] = None, + ): + out = list(self._splits.mapValues(_decrypt_func(sk_map, coder_map, self._squeezed, unpacker_map)).collect()) + out.sort(key=lambda x: x[0]) + data = HistogramSplits.cat([split for _, split in out]) + return Histogram(HistogramIndexer(self._node_size, [self._node_data_size]), data) + + # def decrypt_(self, sk_map: MutableMapping[str, typing.Any]): + # """ + # Decrypt the histogram values. + # + # Args: + # sk_map: name -> sk + # """ + # table = self._table.mapValues(lambda split: split.i_decrypt(sk_map)) + # return DistributedHistogram(table, self._node_size, self._node_data_size, self._squeezed) + # + # def unpack_decode(self, coder_map: MutableMapping[str, typing.Tuple[typing.Any, int, int, int, int]]): + # """ + # Unpack and decode the histogram values. 
+ # + # Args: + # coder_map: name -> (coder, pack_num, offset_bit, precision, total_num) + # """ + # table = self._table.mapValues(lambda split: split.i_unpack_decode(coder_map, self._squeezed)) + # return DistributedHistogram(table, self._node_size, self._node_data_size) + # + # def decode(self, coder_map: MutableMapping[str, typing.Tuple[typing.Any, torch.dtype]]): + # """ + # Decode the histogram values. + # + # Args: + # coder_map: name -> (coder, dtype) + # """ + # table = self._table.mapValues(lambda split: split.i_decode(coder_map)) + # return DistributedHistogram(table, self._node_size, self._node_data_size) + # + # def union(self) -> Histogram: + # """ + # Union the splits into one histogram. + # """ + # out = list(self._table.collect()) + # out.sort(key=lambda x: x[0]) + # return self.cat([split for _, split in out]) + # def cat(self, hists: typing.List["HistogramSplits"]) -> "Histogram": + # data = HistogramSplits.cat(hists) + # return Histogram(HistogramIndexer(self._node_size, [self._node_data_size]), data) diff --git a/python/fate/arch/histogram/_histogram_local.py b/python/fate/arch/histogram/_histogram_local.py new file mode 100644 index 0000000000..930a3ec81e --- /dev/null +++ b/python/fate/arch/histogram/_histogram_local.py @@ -0,0 +1,99 @@ +import logging +import typing + +from ._histogram_splits import HistogramSplits +from .indexer import HistogramIndexer, Shuffler +from .values import HistogramValuesContainer + +logger = logging.getLogger(__name__) + + +class Histogram: + def __init__(self, indexer: "HistogramIndexer", values: HistogramValuesContainer): + self._indexer = indexer + self._data = values + + def __str__(self): + return self._data.show(self._indexer) + + def to_dict(self, feature_names: typing.List[str] = None): + """ + Convert the histogram to a dict. 
+ + the dict is structured as: + { + node_id: { + name: { + feature_id: { + bid: value + } + } + } + } + """ + histogram_dict = self._data.to_structured_dict(self._indexer) + if feature_names is not None: + histogram_dict_with_names = {} + for nid, node_data in histogram_dict.items(): + histogram_dict_with_names[nid] = {} + for name, feature_data in node_data.items(): + histogram_dict_with_names[nid][name] = {} + for fid, bid_data in feature_data.items(): + histogram_dict_with_names[nid][name][feature_names[fid]] = bid_data + return histogram_dict_with_names + else: + return histogram_dict + + @classmethod + def create(cls, num_node, feature_bin_sizes, values_schema: dict): + indexer = HistogramIndexer(num_node, feature_bin_sizes) + size = indexer.total_data_size() + return cls(indexer, HistogramValuesContainer.create(values_schema, size)) + + def i_update(self, fids, nids, targets, node_mapping): + if node_mapping is None: + positions = self._indexer.get_positions( + nids.flatten().detach().numpy().tolist(), fids.detach().numpy().tolist() + ) + if len(positions) == 0: + return self + self._data.i_update(targets, positions) + else: + positions, masks = self._indexer.get_positions_with_node_mapping( + nids.flatten().detach().numpy().tolist(), fids.detach().numpy().tolist(), node_mapping + ) + if len(positions) == 0: + return self + self._data.i_update_with_masks(targets, positions, masks) + return self + + def iadd(self, hist: "Histogram"): + self._data.iadd(hist._data) + return self + + def decrypt(self, sk_map: dict): + return Histogram(self._indexer, self._data.decrypt(sk_map)) + + def decode(self, coder_map: dict): + return Histogram(self._indexer, self._data.decode(coder_map)) + + def i_shuffle(self, seed, reverse=False): + shuffler = Shuffler(self._indexer.get_node_size(), self._indexer.get_node_axis_stride(), seed) + self._data.i_shuffle(shuffler, reverse=reverse) + return self + + def i_cumsum_bins(self): + self._data.i_cumsum_bins(self._indexer.global_flatten_bin_sizes()) + return self + + def reshape(self, feature_bin_sizes): + indexer = self._indexer.reshape(feature_bin_sizes) + return Histogram(indexer, self._data) + + def extract_data(self): + return self._data.extract_data(self._indexer) + + def to_splits(self, k) -> typing.Iterator[typing.Tuple[(int, "HistogramSplits")]]: + for pid, (start, end), indexes in self._indexer.splits_into_k(k): + data = self._data.intervals_slice(indexes) + yield pid, HistogramSplits(pid, self._indexer.node_size, start, end, data) diff --git a/python/fate/arch/histogram/_histogram_sbt.py b/python/fate/arch/histogram/_histogram_sbt.py new file mode 100644 index 0000000000..e528ff3654 --- /dev/null +++ b/python/fate/arch/histogram/_histogram_sbt.py @@ -0,0 +1,71 @@ +from ._histogram_distributed import DistributedHistogram +from ._histogram_local import Histogram + + +class HistogramBuilder: + def __init__( + self, + num_node, + feature_bin_sizes, + value_schemas, + global_seed=None, + seed=None, + node_mapping=None, + k=None, + enable_cumsum=True, + ): + self._num_node = num_node + self._feature_bin_sizes = feature_bin_sizes + self._node_data_size = sum(feature_bin_sizes) + self._value_schemas = value_schemas + self._global_seed = global_seed + self._seed = seed + self._node_mapping = node_mapping + self._enable_cumsum = enable_cumsum + self._k = k + + def __str__(self): + return f"<{self.__class__.__name__} node_size={self._num_node}, feature_bin_sizes={self._feature_bin_sizes}, node_data_size={self._node_data_size}, seed={self._global_seed}>" + + 
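    # Editorial sketch (not part of the original source): one plausible way to drive this
    # builder, based on the value-schema handling in HistogramValuesContainer.create and on
    # statistic() below. The pk/evaluator/coder objects come from whichever PHE backend is
    # in use and are assumptions here, not defined in this file.
    #
    #   value_schemas = {
    #       "gh": {"type": "ciphertext", "pk": pk, "evaluator": evaluator,
    #              "coder": coder, "dtype": torch.float64, "stride": 2},
    #       "cnt": {"type": "plaintext", "dtype": torch.int64, "stride": 1},
    #   }
    #   builder = HistogramBuilder(num_node=2, feature_bin_sizes=[3, 2],
    #                              value_schemas=value_schemas, global_seed=42)
    #   hist = builder.statistic(data)  # data: table of (feature_ids, node_ids, targets) rows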
def statistic(self, data) -> "DistributedHistogram": + """ + Update the histogram with the data. + Args: + data: table with the following schema: + Returns: + ShuffledHistogram, the shuffled histogram + """ + if self._k is None: + self._k = data.partitions**2 + mapper = get_partition_hist_build_mapper( + self._num_node, + self._feature_bin_sizes, + self._value_schemas, + self._global_seed, + self._k, + self._node_mapping, + self._enable_cumsum, + ) + table = data.mapReducePartitions(mapper, lambda x, y: x.iadd(y)) + data = DistributedHistogram( + table, self._k, self._num_node, self._node_data_size, global_seed=self._global_seed, seed=self._seed + ) + return data + + +def get_partition_hist_build_mapper( + num_node, feature_bin_sizes, value_schemas, global_seed, k, node_mapping, enable_cumsum +): + def _partition_hist_build_mapper(part): + hist = Histogram.create(num_node, feature_bin_sizes, value_schemas) + for _, raw in part: + feature_ids, node_ids, targets = raw + hist.i_update(feature_ids, node_ids, targets, node_mapping) + if enable_cumsum: + hist.i_cumsum_bins() + if global_seed is not None: + hist.i_shuffle(global_seed) + splits = hist.to_splits(k) + return splits + + return _partition_hist_build_mapper diff --git a/python/fate/arch/histogram/_histogram_splits.py b/python/fate/arch/histogram/_histogram_splits.py new file mode 100644 index 0000000000..a458015db8 --- /dev/null +++ b/python/fate/arch/histogram/_histogram_splits.py @@ -0,0 +1,114 @@ +import logging +import typing +from typing import List, Tuple + +from .values import HistogramValuesContainer +from .indexer import Shuffler + +logger = logging.getLogger(__name__) + + +class HistogramSplits: + def __init__(self, sid, num_node, start, end, data): + self.sid = sid + self.num_node = num_node + self.start = start + self.end = end + self._data: HistogramValuesContainer = data + + def __str__(self): + result = f"{self.__class__.__name__}(start={self.start}, end={self.end}):\n" + result += str(self._data) + return result + + def __repr__(self): + return self.__str__() + + def iadd(self, other: "HistogramSplits"): + self._data.iadd(other._data) + return self + + def i_sub_on_key(self, from_key, to_key): + self._data.i_sub_on_key(from_key, to_key) + return self + + def compute_child_splits( + self: "HistogramSplits", weak_child_splits: "HistogramSplits", mapping: List[Tuple[int, int, int, int]] + ): + assert len(mapping) == weak_child_splits.num_node + assert self.end == weak_child_splits.end + assert self.start == weak_child_splits.start + assert self.sid == weak_child_splits.sid + size = self.end - self.start + positions = [] + for parent_pos, weak_child_pos, target_weak_child_pos, target_strong_child_pos in mapping: + target_weak_child_start = target_weak_child_pos * size + target_weak_child_end = (target_weak_child_pos + 1) * size + target_strong_child_start = target_strong_child_pos * size + target_strong_child_end = (target_strong_child_pos + 1) * size + parent_data_start = parent_pos * size + parent_data_end = (parent_pos + 1) * size + weak_child_data_start = weak_child_pos * size + weak_child_data_end = (weak_child_pos + 1) * size + positions.append( + ( + target_weak_child_start, + target_weak_child_end, + target_strong_child_start, + target_strong_child_end, + parent_data_start, + parent_data_end, + weak_child_data_start, + weak_child_data_end, + ) + ) + data = self._data.compute_child(weak_child_splits._data, positions, size * len(mapping) * 2) + return HistogramSplits(self.sid, 2 * weak_child_splits.num_node, 
self.start, self.end, data) + + def i_decrypt(self, sk_map): + self._data = self._data.decrypt(sk_map) + return self + + def decrypt(self, sk_map): + data = self._data.decrypt(sk_map) + return HistogramSplits(self.sid, self.num_node, self.start, self.end, data) + + def i_decode(self, coder_map): + self._data = self._data.decode(coder_map) + return self + + def i_unpack_decode(self, coder_map, squeezed): + unpacker_map = {} + for name, (coder, gh_pack_num, offset_bit, precision, squeeze_num) in coder_map.items(): + if squeezed: + pack_num = gh_pack_num * squeeze_num + else: + pack_num = gh_pack_num + total_num = (self.end - self.start) * self.num_node * gh_pack_num + unpacker_map[name] = (coder, pack_num, offset_bit, precision, total_num, gh_pack_num) + self._data = self._data.unpack_decode(unpacker_map) + return self + + def i_squeeze(self, squeeze_map): + self._data.i_squeeze(squeeze_map) + return self + + def i_shuffle(self, seed, reverse=False): + shuffler = Shuffler(self.num_node, self.end - self.start, seed) + self._data.i_shuffle(shuffler, reverse=reverse) + return self + + def shuffle(self, seed, reverse=False): + shuffler = Shuffler(self.num_node, self.end - self.start, seed) + data = self._data.shuffle(shuffler, reverse=reverse) + return HistogramSplits(self.sid, self.num_node, self.start, self.end, data) + + @classmethod + def cat(cls, splits: typing.List["HistogramSplits"]) -> "HistogramValuesContainer": + chunks_info = [] + chunks_values = [] + for split in splits: + chunks_info.append((split.num_node, split.end - split.start)) + chunks_values.append(split._data) + data = HistogramValuesContainer.cat(chunks_info, chunks_values) + return data diff --git a/python/fate/arch/histogram/indexer/__init__.py b/python/fate/arch/histogram/indexer/__init__.py new file mode 100644 index 0000000000..2f23500878 --- /dev/null +++ b/python/fate/arch/histogram/indexer/__init__.py @@ -0,0 +1,8 @@ +INDEXER_USE_PYTHON = True + +if INDEXER_USE_PYTHON: + from ._indexer import HistogramIndexer, Shuffler +# else: +# from fate_utils.histogram import HistogramIndexer, Shuffler + +__all__ = ["HistogramIndexer", "Shuffler"] diff --git a/python/fate/arch/histogram/indexer/_indexer.py b/python/fate/arch/histogram/indexer/_indexer.py new file mode 100644 index 0000000000..c7a6f2799d --- /dev/null +++ b/python/fate/arch/histogram/indexer/_indexer.py @@ -0,0 +1,226 @@ +from typing import List, Tuple, Dict + +import numpy as np +import torch + + +class Shuffler: + """ + Shuffler is used to shuffle the data. + + data is stored in a flatten array with the shape is (num_node * node_size * step) and the + shuffle is applied to the `node_size` dimension. 
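    Illustrative example (editorial addition; concrete index values are indicative only):
    with num_node=2, node_size=3 and step=1, the six flattened entries are permuted
    block-wise, one permutation per node, all drawn from the same seed:

        >>> shuffler = Shuffler(num_node=2, node_size=3, seed=0)
        >>> idx = shuffler.get_shuffle_index(step=1)
        >>> # idx looks like e.g. [2, 0, 1, 5, 3, 4]: the second node block repeats the
        >>> # first block's permutation shifted by node_size, since both share the seed.
        >>> rev = shuffler.get_shuffle_index(step=1, reverse=True)
        >>> # rev is the inverse permutation, so [idx[i] for i in rev] == list(range(6)).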
+ """ + + def __init__(self, num_node: int, node_size: int, seed: int): + self.seed = seed + self.num_node = num_node + self.node_size = node_size + + def _get_perm_indexes(self): + return [ + torch.randperm(self.node_size, generator=torch.Generator().manual_seed(self.seed)) + for _ in range(self.num_node) + ] + + def _get_global_perm_index(self): + index = torch.hstack([index + (nid * self.node_size) for nid, index in enumerate(self._get_perm_indexes())]) + return index + + def get_shuffle_index(self, step, reverse=False): + """ + get chunk shuffle index + """ + stepped = torch.arange(0, self.num_node * self.node_size * step).reshape(self.num_node * self.node_size, step) + indexes = stepped[self._get_global_perm_index(), :].flatten() + if reverse: + indexes = torch.argsort(indexes) + return indexes.detach().cpu().tolist() + + def get_reverse_indexes(self, step, indexes): + mapping = self.get_shuffle_index(step, reverse=False) + return [mapping[i] for i in indexes] + + +class HistogramIndexer: + """ + HistogramIndexer is used to index the `VectorizedHistogram` data. + + The data is stored in a flatten array with the shape is (num_node * sum(feature_bin_sizes)) and the + data is indexed by (node_id, feature_id, bin_id) in the flatten array. + At logical level, the data is indexed by (node_id, feature_id, bin_id) likes: + + node_id_0: + feature_id_0: + bin_id_0: data + bin_id_1: data + ... + feature_id_1: + bin_id_0: data + bin_id_1: data + ... + ... + node_id_1: + feature_id_0: + bin_id_0: data + bin_id_1: data + ... + feature_id_1: + bin_id_0: data + bin_id_1: data + ... + ... + ... + notice that the data is stored in the flatten array, so the index is calculated by: + position = node_id * sum(feature_bin_sizes) + feature_bin_sizes[feature_id] + bin_id + + which means the bin_size of each feature is not necessary to be the same but the sum of all feature_bin_sizes + should be the same. + + Notes: This class will be rewritten in rust in the future. 
+ """ + + def __init__(self, node_size: int, feature_bin_sizes: List[int]): + self.node_size = node_size + self.feature_bin_sizes = feature_bin_sizes + self.feature_size = len(feature_bin_sizes) + self.feature_axis_stride = np.cumsum([0] + [feature_bin_sizes[i] for i in range(self.feature_size)]) + self.node_axis_stride = sum(feature_bin_sizes) + + def get_node_size(self): + return self.node_size + + def get_node_axis_stride(self): + return self.node_axis_stride + + def get_position(self, nid: int, fid: int, bid: int): + """ + get data position by node_id, feature_id, bin_id + Args: + nid: node id + fid: feature id + bid: bin id + + Returns: data position + """ + return nid * self.node_axis_stride + self.feature_axis_stride[fid] + bid + + def get_positions_with_node_mapping(self, nids: List[int], bids: List[List[int]], node_mapping: Dict[int, int]): + """ + get data positions by node_ids and bin_ids + Args: + nids: node ids + bids: bin ids + node_mapping: node mapping + + Returns: data positions + """ + assert len(nids) == len(bids), f"nids length {len(nids)} is not equal to bids length {len(bids)}" + positions = [] + masks = [] + for nid, bids in zip(nids, bids): + if nid in node_mapping: + positions.append([self.get_position(node_mapping[nid], fid, bid) for fid, bid in enumerate(bids)]) + masks.append(True) + else: + masks.append(False) + return positions, masks + + def get_positions(self, nids: List[int], bids: List[List[int]]): + """ + get data positions by node_ids and bin_ids + Args: + nids: node ids + bids: bin ids + + Returns: data positions + """ + positions = [] + assert len(nids) == len(bids), f"nids length {len(nids)} is not equal to bids length {len(bids)}" + for nid, bids in zip(nids, bids): + positions.append([self.get_position(nid, fid, bid) for fid, bid in enumerate(bids)]) + return positions + + def get_reverse_position(self, position) -> Tuple[int, int, int]: + """ + get node_id, feature_id, bin_id by data position + Args: + position: data position + + Returns: node_id, feature_id, bin_id + """ + nid = position // self.node_axis_stride + bid = position % self.node_axis_stride + for fid in range(self.feature_size): + if bid < self.feature_axis_stride[fid + 1]: + return nid, fid, bid - self.feature_axis_stride[fid] + + def get_node_intervals(self): + intervals = [] + for nid in range(self.node_size): + intervals.append((nid * self.node_axis_stride, (nid + 1) * self.node_axis_stride)) + return intervals + + def get_feature_position_ranges(self): + """ + get feature position and size + Returns: list of (feature_position, feature_bin_size) + """ + intervals = [] + for nid in range(self.node_size): + node_stride = nid * self.node_axis_stride + for fid in range(self.feature_size): + intervals.append( + (node_stride + self.feature_axis_stride[fid], node_stride + self.feature_axis_stride[fid + 1]) + ) + return intervals + + def total_data_size(self): + return self.node_size * self.node_axis_stride + + def one_node_data_size(self): + return self.node_axis_stride + + def global_flatten_bin_sizes(self): + return self.feature_bin_sizes * self.node_size + + def flatten_in_node(self): + return HistogramIndexer(self.node_size, [self.one_node_data_size()]) + + def squeeze_bins(self): + return HistogramIndexer(self.node_size, [1] * self.feature_size) + + def reshape(self, feature_bin_sizes): + return HistogramIndexer(self.node_size, feature_bin_sizes) + + def unflatten_indexes(self): + indexes = {} + for nid in range(self.node_size): + indexes[nid] = {} + for fid in 
range(self.feature_size): + indexes[nid][fid] = [] + for bid in range(self.feature_bin_sizes[fid]): + indexes[nid][fid].append(self.get_position(nid, fid, bid)) + return indexes + + def splits_into_k(self, k: int): + for pid, (start, end) in enumerate(self._splits_into_k(self.node_axis_stride, k)): + shift = self.node_axis_stride + yield pid, (start, end), [(start + nid * shift, end + nid * shift) for nid in range(self.node_size)] + + @staticmethod + def _splits_into_k(n, k: int): + d, r = divmod(n, k) + start = 0 + for _ in range(k): + end = start + d + (r > 0) + yield start, end + start = end + r -= 1 + + @staticmethod + def _find_split(n, k: int, i): + d, r = divmod(n, k) + if i < (d + 1) * r: + return i // (d + 1) + return r + (i - (d + 1) * r) // d diff --git a/python/fate/arch/histogram/values/__init__.py b/python/fate/arch/histogram/values/__init__.py new file mode 100644 index 0000000000..60f8fbb75a --- /dev/null +++ b/python/fate/arch/histogram/values/__init__.py @@ -0,0 +1,3 @@ +from ._values import HistogramValuesContainer + +__all__ = ["HistogramValuesContainer"] diff --git a/python/fate/arch/histogram/values/_cipher.py b/python/fate/arch/histogram/values/_cipher.py new file mode 100644 index 0000000000..9675af1ec0 --- /dev/null +++ b/python/fate/arch/histogram/values/_cipher.py @@ -0,0 +1,139 @@ +import logging +import typing +from typing import List, Tuple + +import torch + +from ._encoded import HistogramEncodedValues +from ._value import HistogramValues +from ..indexer import Shuffler + +logger = logging.getLogger(__name__) + + +class HistogramEncryptedValues(HistogramValues): + def __init__(self, pk, evaluator, data, coder, dtype: torch.dtype, size: int, stride: int): + self.stride = stride + self.data = data + self.pk = pk + self.coder = coder + self.dtype = dtype + self.size = size + self.evaluator = evaluator + + @classmethod + def zeros(cls, pk, evaluator, size: int, coder, dtype, stride: int): + return cls(pk, evaluator, evaluator.zeros(size * stride, dtype), coder, dtype, size, stride) + + def i_update(self, value, positions): + from fate.arch.tensor.phe import PHETensor + + if isinstance(value, PHETensor): + value = value.data + + return self.evaluator.i_update(self.pk, self.data, value, positions, self.stride) + + def i_update_with_masks(self, value, positions, masks): + from fate.arch.tensor.phe import PHETensor + + if isinstance(value, PHETensor): + value = value.data + + return self.evaluator.i_update_with_masks(self.pk, self.data, value, positions, masks, self.stride) + + def iadd(self, other): + self.evaluator.i_add(self.pk, self.data, other.data) + return self + + def slice(self, start, end): + return HistogramEncryptedValues( + self.pk, + self.evaluator, + self.evaluator.slice(self.data, start * self.stride, (end - start) * self.stride), + self.coder, + self.dtype, + self.size, + self.stride, + ) + + def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]) -> "HistogramEncryptedValues": + intervals = [(start * self.stride, end * self.stride) for start, end in intervals] + data = self.evaluator.intervals_slice(self.data, intervals) + return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride) + + def i_shuffle(self, shuffler: "Shuffler", reverse=False): + indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse) + self.evaluator.i_shuffle(self.pk, self.data, indices) + return self + + def shuffle(self, shuffler: "Shuffler", reverse=False): + indices = 
shuffler.get_shuffle_index(step=self.stride, reverse=reverse) + data = self.evaluator.shuffle(self.pk, self.data, indices) + return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride) + + def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]): + """ + sum bins in the given logical intervals + """ + intervals = [(start * self.stride, end * self.stride) for start, end in intervals] + data = self.evaluator.intervals_sum_with_step(self.pk, self.data, intervals, self.stride) + return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride) + + def compute_child( + self, + weak_child: "HistogramEncryptedValues", + positions: List[Tuple[int, int, int, int, int, int, int, int]], + size: int, + ): + data = self.evaluator.zeros(size * self.stride, self.dtype) + for ( + target_weak_child_start, + target_weak_child_end, + target_strong_child_start, + target_strong_child_end, + parent_data_start, + parent_data_end, + weak_child_data_start, + weak_child_data_end, + ) in positions: + s = (parent_data_end - parent_data_start) * self.stride + self.evaluator.i_add( + self.pk, + data, + weak_child.data, + target_weak_child_start * self.stride, + weak_child_data_start * self.stride, + s, + ) + self.evaluator.i_add( + self.pk, data, self.data, target_strong_child_start * self.stride, parent_data_start * self.stride, s + ) + self.evaluator.i_sub( + self.pk, + data, + weak_child.data, + target_strong_child_start * self.stride, + weak_child_data_start * self.stride, + s, + ) + + return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride) + + def decrypt(self, sk): + data = sk.decrypt_to_encoded(self.data) + return HistogramEncodedValues(data, self.size, self.dtype, self.stride) + + def squeeze(self, pack_num, offset_bit): + data = self.evaluator.pack_squeeze(self.data, pack_num, offset_bit, self.pk) + return HistogramEncryptedValues(self.pk, self.evaluator, data, self.coder, self.dtype, self.size, self.stride) + + def i_chunking_cumsum(self, chunk_sizes: typing.List[int]): + chunk_sizes = [num * self.stride for num in chunk_sizes] + self.evaluator.chunking_cumsum_with_step(self.pk, self.data, chunk_sizes, self.stride) + return self + + def __str__(self): + return f"" + + def extract_node_data(self, node_data_size, node_size): + raise NotImplementedError diff --git a/python/fate/arch/histogram/values/_encoded.py b/python/fate/arch/histogram/values/_encoded.py new file mode 100644 index 0000000000..fbc7a7586f --- /dev/null +++ b/python/fate/arch/histogram/values/_encoded.py @@ -0,0 +1,52 @@ +import logging + +import torch + +from ._plain import HistogramPlainValues +from ._value import HistogramValues + +logger = logging.getLogger(__name__) + + +class HistogramEncodedValues(HistogramValues): + def __init__(self, data, size: int, dtype: torch.dtype, stride: int): + self.data = data + self.size = size + self.dtype = dtype + self.stride = stride + + def decode_f64(self, coder): + return HistogramPlainValues(coder.decode_f64_vec(self.data), self.dtype, self.size, self.stride) + + def decode_i64(self, coder): + return HistogramPlainValues(coder.decode_i64_vec(self.data), self.dtype, self.size, self.stride) + + def decode_f32(self, coder): + return HistogramPlainValues(coder.decode_f32_vec(self.data), self.dtype, self.size, self.stride) + + def decode_i32(self, coder): + return HistogramPlainValues(coder.decode_i32_vec(self.data), self.dtype, self.size, 
self.stride) + + def decode(self, coder, dtype): + if dtype is None: + dtype = self.dtype + if dtype == torch.float64: + return self.decode_f64(coder) + elif dtype == torch.float32: + return self.decode_f32(coder) + elif dtype == torch.int64: + return self.decode_i64(coder) + elif dtype == torch.int32: + return self.decode_i32(coder) + else: + raise NotImplementedError + + def unpack(self, coder, pack_num, offset_bit, precision, total_num, stride): + data = coder.unpack_floats(self.data, offset_bit, pack_num, precision, total_num) + return HistogramPlainValues(data, self.dtype, self.size, stride) + + def slice(self, start, end): + if hasattr(self.data, "slice"): + return self.data.slice(start * self.stride, end * self.stride) + else: + return "" diff --git a/python/fate/arch/histogram/values/_plain.py b/python/fate/arch/histogram/values/_plain.py new file mode 100644 index 0000000000..c3eb7cf27a --- /dev/null +++ b/python/fate/arch/histogram/values/_plain.py @@ -0,0 +1,158 @@ +import logging +import typing +from typing import List, Tuple + +import torch + +from ._value import HistogramValues +from ..indexer import Shuffler + +logger = logging.getLogger(__name__) + + +class HistogramPlainValues(HistogramValues): + def __init__(self, data, dtype: torch.dtype, size: int, stride: int): + self.data = data + self.dtype = dtype + self.size = size + self.stride = stride + + def __str__(self): + return f"" + + def __repr__(self): + return str(self) + + @classmethod + def zeros(cls, size, stride, dtype=torch.float64): + return cls(torch.zeros(size * stride, dtype=dtype), dtype, size, stride) + + def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]): + size = sum(e - s for s, e in intervals) + result = torch.zeros(size * self.stride, dtype=self.data.dtype) + start = 0 + for s, e in intervals: + end = start + (e - s) * self.stride + result[start:end] = self.data[s * self.stride : e * self.stride] + start = end + return HistogramPlainValues(result, self.dtype, size, self.stride) + + def iadd_slice(self, value, sa, sb, size): + size = size * self.stride + value = value.view(-1) + self.data[sa : sa + size] += value[sb : sb + size] + + def slice(self, start, end): + return HistogramPlainValues( + self.data[start * self.stride : end * self.stride], self.dtype, end - start, self.stride + ) + + def iadd(self, other): + self.data += other.data + + def i_update(self, value, positions): + if self.stride == 1: + index = torch.LongTensor(positions) + value = value.view(-1, 1).expand(-1, index.shape[1]).flatten() + index = index.flatten() + data = self.data + else: + index = torch.LongTensor(positions) + data = self.data.view(-1, self.stride) + value = ( + value.view(-1, self.stride) + .unsqueeze(1) + .expand(-1, index.shape[1], self.stride) + .reshape(-1, self.stride) + ) + index = index.flatten().unsqueeze(1).expand(-1, self.stride) + if self.data.dtype != value.dtype: + logger.warning(f"update value dtype {value.dtype} is not equal to data dtype {self.data.dtype}") + value = value.to(data.dtype) + data.scatter_add_(0, index, value) + + def i_update_with_masks(self, value, positions, masks): + if self.stride == 1: + value = value[masks] + index = torch.LongTensor(positions) + value = value.view(-1, 1).expand(-1, index.shape[1]).flatten() + index = index.flatten() + data = self.data + else: + index = torch.LongTensor(positions) + data = self.data.view(-1, self.stride) + value = value.view(-1, self.stride)[masks] + value = value.unsqueeze(1).expand(-1, index.shape[1], self.stride).reshape(-1, 
self.stride) + index = index.flatten().unsqueeze(1).expand(-1, self.stride) + if self.data.dtype != value.dtype: + logger.warning(f"update value dtype {value.dtype} is not equal to data dtype {self.data.dtype}") + value = value.to(data.dtype) + data.scatter_add_(0, index, value) + + def i_shuffle(self, shuffler: "Shuffler", reverse=False): + indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse) + self.data = self.data[indices] + + def shuffle(self, shuffler: "Shuffler", reverse=False): + indices = shuffler.get_shuffle_index(step=self.stride, reverse=reverse) + data = self.data[indices] + return HistogramPlainValues(data, self.dtype, self.size, self.stride) + + def i_chunking_cumsum(self, chunk_sizes: typing.List[int]): + data_view = self.data.view(-1, self.stride) + start = 0 + for num in chunk_sizes: + data_view[start : start + num, :] = data_view[start : start + num, :].cumsum(dim=0) + start += num + + def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]): + size = len(intervals) + result = torch.zeros(size * self.stride, dtype=self.data.dtype) + data_view = self.data.view(-1, self.stride) + for i, (start, end) in enumerate(intervals): + result[i * self.stride : (i + 1) * self.stride] = data_view[start:end, :].sum(dim=0) + return HistogramPlainValues(result, self.dtype, size, self.stride) + + def compute_child( + self, weak_child: "HistogramPlainValues", positions: List[Tuple[int, int, int, int, int, int, int, int]], size + ): + data = torch.zeros(size * self.stride, dtype=self.data.dtype) + data_view = data.view(-1, self.stride) + + parent_data_view = self.data.view(-1, self.stride) + weak_child_data_view = weak_child.data.view(-1, self.stride) + + for ( + target_weak_child_start, + target_weak_child_end, + target_strong_child_start, + target_strong_child_end, + parent_data_start, + parent_data_end, + weak_child_data_start, + weak_child_data_end, + ) in positions: + # copy data from weak child to correct position + data_view[target_weak_child_start:target_weak_child_end] = weak_child_data_view[ + weak_child_data_start:weak_child_data_end + ] + # compute strong child data + data_view[target_strong_child_start:target_strong_child_end] = ( + parent_data_view[parent_data_start:parent_data_end] + - weak_child_data_view[weak_child_data_start:weak_child_data_end] + ) + return HistogramPlainValues(data, self.dtype, size, self.stride) + + @classmethod + def cat(cls, chunks_info: List[Tuple[int, int]], values: List["HistogramPlainValues"]): + data = [] + for (num_chunk, chunk_size), value in zip(chunks_info, values): + data.append(value.data.reshape(num_chunk, chunk_size, value.stride)) + data = torch.cat(data, dim=1) + size = data.shape[0] + dtype = data.dtype + data = data.flatten() + return cls(data, dtype, size, values[0].stride) + + def extract_node_data(self, node_data_size, node_size): + return list(self.data.reshape(node_size, node_data_size, self.stride)) diff --git a/python/fate/arch/histogram/values/_value.py b/python/fate/arch/histogram/values/_value.py new file mode 100644 index 0000000000..5713807a9f --- /dev/null +++ b/python/fate/arch/histogram/values/_value.py @@ -0,0 +1,60 @@ +import logging +import typing +from typing import List, Tuple + +from ..indexer import Shuffler + +logger = logging.getLogger(__name__) + + +class HistogramValues: + def iadd_slice(self, value, sa, sb, size): + raise NotImplementedError + + def i_update(self, value, positions): + raise NotImplementedError + + def i_update_with_masks(self, value, positions, masks): + raise 
NotImplementedError + + def iadd(self, other): + raise NotImplementedError + + def chunking_sum(self, intervals: typing.List[typing.Tuple[int, int]]): + raise NotImplementedError + + def compute_child(self, weak_child: "HistogramValues", positions: List[Tuple[int, int, int, int, int, int]], size): + ... + + def intervals_slice(self, intervals: typing.List[typing.Tuple[int, int]]): + raise NotImplementedError + + def i_shuffle(self, shuffler: "Shuffler", reverse=False): + raise NotImplementedError + + def shuffle(self, shuffler: "Shuffler", reverse=False): + raise NotImplementedError + + def slice(self, start, end): + raise NotImplementedError + + def decrypt(self, sk): + raise NotImplementedError + + def squeeze(self, pack_num, offset_bit): + raise NotImplementedError + + def unpack(self, coder, pack_num, offset_bit, precision, total_num, stride): + raise NotImplementedError + + def i_chunking_cumsum(self, chunk_sizes: typing.List[int]): + raise NotImplementedError + + def decode(self, coder, dtype): + raise NotImplementedError + + def cat(self, chunks_info, values): + raise NotImplementedError(f"{self.__class__.__name__}.cat") + + def extract_node_data(self, node_data_size, node_size): + raise NotImplementedError diff --git a/python/fate/arch/histogram/values/_values.py b/python/fate/arch/histogram/values/_values.py new file mode 100644 index 0000000000..2c2322a73e --- /dev/null +++ b/python/fate/arch/histogram/values/_values.py @@ -0,0 +1,210 @@ +import typing +from typing import MutableMapping +from ._value import HistogramValues +from ._plain import HistogramPlainValues +from ._cipher import HistogramEncryptedValues + +if typing.TYPE_CHECKING: + from ..indexer import HistogramIndexer + + +class HistogramValuesContainer(object): + def __init__(self, data: MutableMapping[str, HistogramValues]): + self._data = data + + @classmethod + def create(cls, values_schema: dict, size): + values_mapping = {} + for name, items in values_schema.items(): + stride = items.get("stride", 1) + if items["type"] == "ciphertext": + pk = items["pk"] + evaluator = items["evaluator"] + coder = items.get("coder") + dtype = items.get("dtype") + values_mapping[name] = HistogramEncryptedValues.zeros(pk, evaluator, size, coder, dtype, stride) + elif items["type"] == "plaintext": + import torch + + dtype = items.get("dtype", torch.float64) + values_mapping[name] = HistogramPlainValues.zeros(size, stride=stride, dtype=dtype) + else: + raise NotImplementedError + return HistogramValuesContainer(values_mapping) + + def __str__(self): + result = "" + for name, value in self._data.items(): + result += f"{name}: {value}\n" + return result + + def i_update(self, targets, positions): + for name, value in targets.items(): + self._data[name].i_update(value, positions) + return self + + def i_update_with_masks(self, targets, positions, masks): + for name, value in targets.items(): + self._data[name].i_update_with_masks(value, positions, masks) + return self + + def iadd(self, other: "HistogramValuesContainer"): + for name, values in other._data.items(): + if name in self._data: + self._data[name].iadd(values) + else: + self._data[name] = values + return self + + def i_sub_on_key(self, from_key, to_key): + left_value = self._data[from_key] + right_value = self._data[to_key] + if isinstance(left_value, HistogramEncryptedValues): + if isinstance(right_value, HistogramEncryptedValues): + assert left_value.stride == right_value.stride + left_value.data = left_value.evaluator.sub(left_value.pk, left_value.data, right_value.data) + 
elif isinstance(right_value, HistogramPlainValues): + assert left_value.stride == right_value.stride + if left_value.coder is None: + raise ValueError(f"coder is None, please set coder for i_sub_on_key({from_key}, {to_key})") + left_value.data = left_value.evaluator.sub_plain( + left_value.data, right_value.data, left_value.pk, left_value.coder + ) + else: + raise NotImplementedError + elif isinstance(left_value, HistogramPlainValues): + if isinstance(right_value, HistogramEncryptedValues): + assert left_value.stride == right_value.stride + if right_value.coder is None: + raise ValueError(f"coder is None, please set coder for i_sub_on_key({from_key}, {to_key})") + data = right_value.evaluator.rsub_plain( + right_value.data, left_value.data, right_value.pk, right_value.coder + ) + self._data[from_key] = HistogramEncryptedValues( + right_value.pk, + right_value.evaluator, + data, + right_value.coder, + right_value.dtype, + right_value.size, + right_value.stride, + ) + elif isinstance(right_value, HistogramPlainValues): + assert left_value.stride == right_value.stride + left_value.data = left_value.data - right_value.data + else: + raise NotImplementedError + else: + raise NotImplementedError + + def decrypt(self, sk_map: dict): + values_mapping = {} + for name, values in self._data.items(): + if name in sk_map: + values_mapping[name] = values.decrypt(sk_map[name]) + else: + values_mapping[name] = values + return HistogramValuesContainer(values_mapping) + + def decode(self, coder_map: dict): + values_mapping = {} + for name, values in self._data.items(): + if name in coder_map: + coder, dtype = coder_map[name] + values_mapping[name] = values.decode(coder, dtype) + else: + values_mapping[name] = values + return HistogramValuesContainer(values_mapping) + + def unpack_decode(self, unpacker_map: dict): + values_mapping = {} + for name, values in self._data.items(): + if name in unpacker_map: + unpacker, pack_num, offset_bit, precision, total_num, stride = unpacker_map[name] + values_mapping[name] = values.unpack(unpacker, pack_num, offset_bit, precision, total_num, stride) + else: + values_mapping[name] = values + return HistogramValuesContainer(values_mapping) + + def i_squeeze(self, squeeze_map): + for name, value in self._data.items(): + if name in squeeze_map: + pack_num, offset_bit = squeeze_map[name] + self._data[name] = value.squeeze(pack_num, offset_bit) + + def i_shuffle(self, shuffler, reverse=False): + for name, values in self._data.items(): + values.i_shuffle(shuffler, reverse=reverse) + + def shuffle(self, shuffler, reverse=False): + data = {} + for name, values in self._data.items(): + data[name] = values.shuffle(shuffler, reverse=reverse) + return HistogramValuesContainer(data) + + def i_cumsum_bins(self, intervals: list): + for name, values in self._data.items(): + values.i_chunking_cumsum(intervals) + + def intervals_slice(self, intervals: list): + result = {} + for name, values in self._data.items(): + result[name] = values.intervals_slice(intervals) + return HistogramValuesContainer(result) + + def extract_data(self, indexer: "HistogramIndexer"): + data = {} + for name, value_container in self._data.items(): + node_data_list = value_container.extract_node_data(indexer.node_axis_stride, indexer.node_size) + for nid, node_data in enumerate(node_data_list): + if nid not in data: + data[nid] = {} + data[nid][name] = node_data + return data + + def compute_child(self, weak_child: "HistogramValuesContainer", positions: list, size): + result = {} + for name, values in 
self._data.items(): + result[name] = values.compute_child(weak_child._data[name], positions, size) + return HistogramValuesContainer(result) + + @classmethod + def cat(cls, chunks_info, chunks_values: typing.List["HistogramValuesContainer"]) -> "HistogramValuesContainer": + data = {} + for chunk_values in chunks_values: + for name, value in chunk_values._data.items(): + if name not in data: + data[name] = [value] + else: + data[name].append(value) + for name, values in data.items(): + data[name] = values[0].cat(chunks_info, values) + return HistogramValuesContainer(data) + + def show(self, indexer): + result = "" + indexes = indexer.unflatten_indexes() + for nid, fids in indexes.items(): + result += f"node-{nid}:\n" + for fid, bids in fids.items(): + result += f"\tfeature-{fid}:\n" + for start in bids: + for name, value_container in self._data.items(): + values = value_container.slice(start, start + 1) + result += f"\t\t{name}: {values}" + result += "\n" + return result + + def to_structured_dict(self, indexer): + indexes = indexer.unflatten_indexes() + result = {} + for nid, fids in indexes.items(): + result[nid] = {} + for name, value_container in self._data.items(): + result[nid][name] = {} + for fid, bids in fids.items(): + result[nid][name][fid] = {} + for bid, start in enumerate(bids): + values = value_container.slice(start, start + 1) + result[nid][name][fid][bid] = values + return result diff --git a/python/fate/arch/protocol/__init__.py b/python/fate/arch/protocol/__init__.py new file mode 100644 index 0000000000..8cee78061e --- /dev/null +++ b/python/fate/arch/protocol/__init__.py @@ -0,0 +1,3 @@ +from .secure_aggregation import SecureAggregatorClient, SecureAggregatorServer + +__all__ = ["SecureAggregatorClient", "SecureAggregatorServer"] diff --git a/python/fate/arch/protocol/diffie_hellman/__init__.py b/python/fate/arch/protocol/diffie_hellman/__init__.py new file mode 100644 index 0000000000..42e7ee3b8c --- /dev/null +++ b/python/fate/arch/protocol/diffie_hellman/__init__.py @@ -0,0 +1,3 @@ +from fate_utils.secure_aggregation_helper import DiffieHellman + +__all__ = ["DiffieHellman"] diff --git a/python/fate/arch/protocol/phe/__init__.py b/python/fate/arch/protocol/phe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/arch/protocol/phe/mock.py b/python/fate/arch/protocol/phe/mock.py new file mode 100644 index 0000000000..2297f92382 --- /dev/null +++ b/python/fate/arch/protocol/phe/mock.py @@ -0,0 +1,443 @@ +from typing import List, Optional, Tuple + +import torch + +from .type import TensorEvaluator + +V = torch.Tensor + + +class EV: + def __init__(self, data): + self.data = data + + def __str__(self): + return f"" + + def __repr__(self): + return str(self) + + def tolist(self): + return [EV(x.clone().detach()) for x in self.data] + + +class FV: + def __init__(self, data): + self.data = data + + +class SK: + def __init__(self): + ... + + def decrypt_to_encoded(self, vec: EV) -> FV: + return FV(vec.data) + + +class PK: + def __init__(self): + ... + + def encrypt_encoded(self, vec: FV, obfuscate: bool) -> EV: + return EV(vec.data) + + def encrypt_encoded_scalar(self, val, obfuscate) -> EV: + return EV(val) + + +class Coder: + def __init__(self): + ... 
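    # Editorial sketch (not part of the original module): in this mock backend a
    # "ciphertext" is just a torch tensor wrapped in EV/FV, so keygen() further below
    # returns no-op keys and a minimal round trip is plain tensor plumbing:
    #
    #   sk, pk, coder = keygen(key_size=1024)   # key_size is ignored by the mock
    #   enc = pk.encrypt_encoded(coder.encode_f64_vec(torch.tensor([1.0, 2.0])), obfuscate=False)
    #   dec = coder.decode_f64_vec(sk.decrypt_to_encoded(enc))   # tensor([1., 2.])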
+ + def pack_floats(self, float_tensor: V, offset_bit: int, pack_num: int, precision: int) -> FV: + return float_tensor + + def unpack_floats(self, packed: FV, offset_bit: int, pack_num: int, precision: int, total_num: int) -> V: + return torch.tensor(self.coder.unpack_floats(packed, offset_bit, pack_num, precision, total_num)) + + def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV: + return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack) + + def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor: + return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num)) + + def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: + if dtype is None: + dtype = tensor.dtype + return self.encode_vec(tensor.flatten(), dtype=dtype) + + def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V: + data = self.decode_vec(tensor, dtype) + if shape is not None: + data = data.reshape(shape) + if device is not None: + data = data.to(device.to_torch_device()) + return data + + def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: + if dtype is None: + dtype = vec.dtype + else: + if dtype != vec.dtype: + vec = vec.to(dtype=dtype) + if dtype == torch.float64: + return self.encode_f64_vec(vec) + if dtype == torch.float32: + return self.encode_f32_vec(vec) + if dtype == torch.int64: + return self.encode_i64_vec(vec) + if dtype == torch.int32: + return self.encode_i32_vec(vec) + raise NotImplementedError(f"{vec.dtype} not supported") + + def decode_vec(self, vec: FV, dtype: torch.dtype) -> V: + if dtype == torch.float64: + return self.decode_f64_vec(vec) + if dtype == torch.float32: + return self.decode_f32_vec(vec) + if dtype == torch.int64: + return self.decode_i64_vec(vec) + if dtype == torch.int32: + return self.decode_i32_vec(vec) + raise NotImplementedError(f"{dtype} not supported") + + def encode(self, val, dtype=None) -> FV: + if not isinstance(val, torch.Tensor): + val = torch.tensor(val) + assert val.ndim == 0, "only scalar supported" + if dtype is None: + dtype = val.dtype + val = val.item() + if dtype == torch.float64: + return self.encode_f64(val) + if dtype == torch.float32: + return self.encode_f32(val) + if dtype == torch.int64: + return self.encode_i64(val) + if dtype == torch.int32: + return self.encode_i32(val) + raise NotImplementedError(f"{dtype} not supported, val={val}, type={type(val)}") + + def encode_f64(self, val: float) -> FV: + return torch.tensor(val, dtype=torch.float64) + + def decode_f64(self, val): + return float(val.item()) + + def encode_i64(self, val: int): + return torch.tensor(val, dtype=torch.int64) + + def decode_i64(self, val): + return int(val.item()) + + def encode_f32(self, val: float): + return torch.tensor(val, dtype=torch.float32) + + def decode_f32(self, val): + return float(val.item()) + + def encode_i32(self, val: int): + return torch.tensor(val, dtype=torch.int32) + + def decode_i32(self, val): + return int(val.item()) + + def encode_f64_vec(self, vec: torch.Tensor): + return FV(vec.detach().flatten()) + + def decode_f64_vec(self, vec): + return vec.data.type(torch.float64) + + def encode_i64_vec(self, vec: torch.Tensor): + return FV(vec.detach().flatten()) + + def decode_i64_vec(self, vec): + return vec.data.type(torch.int64) + + def encode_f32_vec(self, vec: torch.Tensor): + return FV(vec.detach().flatten()) + + def decode_f32_vec(self, vec): + return 
vec.data.type(torch.float32) + + def encode_i32_vec(self, vec: torch.Tensor): + return FV(vec.detach().flatten()) + + def decode_i32_vec(self, vec): + return vec.data.type(torch.int32) + + +def keygen(key_size): + return SK(), PK(), Coder() + + +class evaluator(TensorEvaluator[EV, V, PK, Coder]): + @staticmethod + def add(a: EV, b: EV, pk: PK): + return EV(torch.add(a.data, b.data)) + + @staticmethod + def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + return EV(torch.add(a.data, pk.encrypt_encoded(coder.encode_tensor(b), obfuscate=False).data)) + + @staticmethod + def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + return EV(torch.add(a.data, pk.encrypt_encoded_scalar(coder.encode(b), obfuscate=False).data)) + + @staticmethod + def sub(a: EV, b: EV, pk: PK): + return EV(torch.sub(a.data, b.data)) + + @staticmethod + def sub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + data = torch.sub(a.data, pk.encrypt_encoded(coder.encode_tensor(b), obfuscate=False).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def sub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + data = torch.sub(a.data, pk.encrypt_encoded_scalar(coder.encode(b), obfuscate=False).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def rsub(a: EV, b: EV, pk: PK): + return EV(torch.rsub(a.data, b.data)) + + @staticmethod + def rsub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + data = torch.rsub(a.data, pk.encrypt_encoded(coder.encode_tensor(b), obfuscate=False).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def rsub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + data = torch.rsub(a.data, pk.encrypt_encoded_scalar(coder.encode(b), obfuscate=False).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def mul_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + data = torch.mul(a.data, coder.encode_tensor(b).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + data = torch.mul(a.data, coder.encode(b).data) + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + left = a.data.reshape(a_shape) + right = b.data.reshape(b_shape) + target_type = torch.promote_types(a.data.dtype, b.data.dtype) + if left.dtype != target_type: + left = left.to(dtype=target_type) + if right.dtype != target_type: + right = right.to(dtype=target_type) + data = torch.matmul(left, right).flatten() + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + right = a.data.reshape(a_shape) + left = b.data.reshape(b_shape) + target_type = torch.promote_types(a.data.dtype, b.data.dtype) + if left.dtype != target_type: + left = left.to(dtype=target_type) + if right.dtype != target_type: + right = right.to(dtype=target_type) + data = torch.matmul(left, right).flatten() + if output_dtype is not None: + data = data.to(dtype=output_dtype) + return EV(data) + + @staticmethod + def zeros(size, dtype) -> EV: + return EV(torch.zeros(size, 
dtype=dtype)) + + @staticmethod + def i_add(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if size is None: + size = min(a.data.numel() - sa, b.data.numel() - sb) + a.data[sa : sa + size] += b.data[sb : sb + size] + + @staticmethod + def i_sub(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if size is None: + size = min(a.data.numel() - sa, b.data.numel() - sb) + a.data[sa : sa + size] -= b.data[sb : sb + size] + + @staticmethod + def slice(a: EV, start: int, size: int) -> EV: + """ + slice a[start:start+size] + Args: + a: the vector to slice + start: the start index + size: the size to slice + + Returns: + the sliced vector + """ + return EV(a.data[start : start + size]) + + @staticmethod + def i_shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> None: + """ + inplace shuffle, a = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + shuffled = a.data[indices] + a.data.copy_(shuffled) + + @staticmethod + def shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> EV: + """ + inplace shuffle, a = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + shuffled = a.data[indices] + return EV(shuffled) + + @staticmethod + def i_update(pk: PK, a: EV, b: EV, positions, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + if stride == 1: + index = torch.LongTensor(positions) + value = b.data.view(-1, 1).expand(-1, index.shape[1]).flatten() + index = index.flatten() + data = a.data + else: + index = torch.LongTensor(positions) + data = a.data.view(-1, stride) + value = b.data.view(-1, stride).unsqueeze(1).expand(-1, index.shape[1], stride).reshape(-1, stride) + index = index.flatten().unsqueeze(1).expand(-1, stride) + try: + data.scatter_add_(0, index, value) + except Exception as e: + raise ValueError(f"data: {data.dtype}, value: {value.dtype}") from e + + @staticmethod + def i_update_with_masks(pk: PK, a: EV, b: EV, positions, masks, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + if stride == 1: + b = b.data[masks] + index = torch.LongTensor(positions) + value = b.data.view(-1, 1).expand(-1, index.shape[1]).flatten() + index = index.flatten() + data = a.data + else: + index = torch.LongTensor(positions) + data = a.data.view(-1, stride) + value = b.data.view(-1, stride)[masks] + value = value.unsqueeze(1).expand(-1, index.shape[1], stride).reshape(-1, stride) + index = index.flatten().unsqueeze(1).expand(-1, stride) + data.scatter_add_(0, index, value) + + @staticmethod + def intervals_slice(a: EV, intervals: 
List[Tuple[int, int]]) -> EV: + """ + slice in the given intervals + + for example: + intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7,...] + then the result is [a0, a1, a2, a3, a6, a7, a8, a9, a10, a11] + """ + slices = [] + for start, end in intervals: + slices.append(a.data[start:end]) + return EV(torch.cat(slices)) + + @staticmethod + def cat(list: List[EV]) -> EV: + """ + concatenate the list of vectors + Args: + list: the list of vectors + + Returns: the concatenated vector + """ + + if list[0].data.dim() == 0: + return EV(torch.tensor([x.data for x in list])) + return EV(torch.cat([x.data for x in list])) + + @staticmethod + def chunking_cumsum_with_step(pk: PK, a: EV, chunk_sizes: List[int], step: int): + """ + chunking cumsum with step size + + for example: + if step=2, chunk_sizes=[4, 2, 6], a = [a0, a1, a2, a3, a4, a5, a6, a7,...a11] + then the result is [a0, a1, a0+a2, a1+a3, a4, a5, a6, a7, a6+a8, a7+a9, a6+a8+a10, a7+a9+a11] + Args: + pk: the public key + a: the vector to cumsum + chunk_sizes: the chunk sizes, must sum to a.size + step: the step size, cumsum with skip step-1 elements + Returns: + the cumsum result + """ + data_view = a.data.view(-1, step) + start = 0 + for num in chunk_sizes: + num = num // step + data_view[start : start + num, :] = data_view[start : start + num, :].cumsum(dim=0) + start += num + + @staticmethod + def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: + return a.pack_squeeze(pack_num, shift_bit, pk.pk) diff --git a/python/fate/arch/protocol/phe/ou.py b/python/fate/arch/protocol/phe/ou.py new file mode 100644 index 0000000000..5223c8985a --- /dev/null +++ b/python/fate/arch/protocol/phe/ou.py @@ -0,0 +1,415 @@ +from typing import List, Optional, Tuple + +import torch +from fate_utils.ou import PK as _PK +from fate_utils.ou import SK as _SK +from fate_utils.ou import Coder as _Coder +from fate_utils.ou import Evaluator as _Evaluator +from fate_utils.ou import CiphertextVector, PlaintextVector +from fate_utils.ou import keygen as _keygen + +from .type import TensorEvaluator + +V = torch.Tensor +EV = CiphertextVector +FV = PlaintextVector + + +class SK: + def __init__(self, sk: _SK): + self.sk = sk + + def decrypt_to_encoded(self, vec: EV) -> FV: + return self.sk.decrypt_to_encoded(vec) + + +class PK: + def __init__(self, pk: _PK): + self.pk = pk + + def encrypt_encoded(self, vec: FV, obfuscate: bool) -> EV: + return self.pk.encrypt_encoded(vec, obfuscate) + + def encrypt_encoded_scalar(self, val, obfuscate) -> EV: + return self.pk.encrypt_encoded_scalar(val, obfuscate) + + +class Coder: + def __init__(self, coder: _Coder): + self.coder = coder + + def pack_floats(self, float_tensor: V, offset_bit: int, pack_num: int, precision: int) -> FV: + return self.coder.pack_floats(float_tensor.detach().tolist(), offset_bit, pack_num, precision) + + def unpack_floats(self, packed: FV, offset_bit: int, pack_num: int, precision: int, total_num: int) -> V: + return torch.tensor(self.coder.unpack_floats(packed, offset_bit, pack_num, precision, total_num)) + + def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV: + return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack) + + def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor: + return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num)) + + def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: + return 
self.encode_vec(tensor.flatten(), dtype=tensor.dtype) + + def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V: + data = self.decode_vec(tensor, dtype) + if shape is not None: + data = data.reshape(shape) + if device is not None: + data = data.to(device.to_torch_device()) + return data + + def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: + if dtype is None: + dtype = vec.dtype + else: + if dtype != vec.dtype: + vec = vec.to(dtype=dtype) + # if dtype == torch.float64: + # return self.encode_f64_vec(vec) + # if dtype == torch.float32: + # return self.encode_f32_vec(vec) + if dtype == torch.int64: + return self.encode_i64_vec(vec) + if dtype == torch.int32: + return self.encode_i32_vec(vec) + raise NotImplementedError(f"{vec.dtype} not supported") + + def decode_vec(self, vec: FV, dtype: torch.dtype) -> V: + # if dtype == torch.float64: + # return self.decode_f64_vec(vec) + # if dtype == torch.float32: + # return self.decode_f32_vec(vec) + if dtype == torch.int64: + return self.decode_i64_vec(vec) + if dtype == torch.int32: + return self.decode_i32_vec(vec) + raise NotImplementedError(f"{dtype} not supported") + + def encode(self, val, dtype=None) -> FV: + if isinstance(val, torch.Tensor): + assert val.ndim == 0, "only scalar supported" + dtype = val.dtype + val = val.item() + # if dtype == torch.float64: + # return self.encode_f64(val) + # if dtype == torch.float32: + # return self.encode_f32(val) + if dtype == torch.int64: + return self.encode_i64(val) + if dtype == torch.int32: + return self.encode_i32(val) + raise NotImplementedError(f"{dtype} not supported") + + # def encode_f64(self, val: float): + # return self.coder.encode_f64(val) + # + # def decode_f64(self, val): + # return self.coder.decode_f64(val) + + def encode_i64(self, val: int): + return self.coder.encode_u64(val) + + def decode_i64(self, val): + return self.coder.decode_u64(val) + + # def encode_f32(self, val: float): + # return self.coder.encode_f32(val) + # + # def decode_f32(self, val): + # return self.coder.decode_f32(val) + + def encode_i32(self, val: int): + return self.coder.encode_u32(val) + + def decode_i32(self, val): + return self.coder.decode_u32(val) + + # def encode_f64_vec(self, vec: torch.Tensor): + # vec = vec.detach().flatten() + # return self.coder.encode_f64_vec(vec.detach().numpy()) + # + # def decode_f64_vec(self, vec): + # return torch.tensor(self.coder.decode_f64_vec(vec)) + + def encode_i64_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_u64_vec(vec.detach().numpy().astype("uint64")) + + def decode_i64_vec(self, vec): + return torch.tensor(self.coder.decode_u64_vec(vec)) + + # def encode_f32_vec(self, vec: torch.Tensor): + # vec = vec.detach().flatten() + # return self.coder.encode_f32_vec(vec.detach().numpy()) + # + # def decode_f32_vec(self, vec): + # return torch.tensor(self.coder.decode_f32_vec(vec)) + + def encode_i32_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_u32_vec(vec.detach().numpy().astype("uint32")) + + def decode_i32_vec(self, vec): + return torch.tensor(self.coder.decode_u32_vec(vec)) + + +def keygen(key_size): + sk, pk, coder = _keygen(key_size) + return SK(sk), PK(pk), Coder(coder) + + +class evaluator(TensorEvaluator[EV, V, PK, Coder]): + @staticmethod + def add(a: EV, b: EV, pk: PK): + return a.add(pk.pk, b) + + @staticmethod + def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = 
b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.add(pk.pk, encrypted) + + @staticmethod + def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.add_scalar(pk.pk, encrypted) + + @staticmethod + def sub(a: EV, b: EV, pk: PK): + return a.sub(pk.pk, b) + + @staticmethod + def sub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.sub(pk.pk, encrypted) + + @staticmethod + def sub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.sub_scalar(pk.pk, encrypted) + + @staticmethod + def rsub(a: EV, b: EV, pk: PK): + return a.rsub(pk.pk, b) + + @staticmethod + def rsub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.rsub(pk.pk, encrypted) + + @staticmethod + def rsub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.rsub_scalar(pk.pk, encrypted) + + @staticmethod + def mul_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.mul(pk.pk, encoded) + + @staticmethod + def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + return a.mul_scalar(pk.pk, encoded) + + @staticmethod + def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + # TODO: move this to python side so other protocols can use it without matmul support? 
+ return a.matmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.rmatmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def zeros(size, dtype) -> EV: + return CiphertextVector.zeros(size) + + @staticmethod + def i_add(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.iadd_vec_self(sa, sb, size, pk.pk) + else: + a.iadd_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def i_sub(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace sub, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.isub_vec_self(sa, sb, size, pk.pk) + else: + a.isub_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def slice(a: EV, start: int, size: int) -> EV: + """ + slice a[start:start+size] + Args: + a: the vector to slice + start: the start index + size: the size to slice + + Returns: + the sliced vector + """ + return a.slice(start, size) + + @staticmethod + def i_shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> None: + """ + inplace shuffle, a = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + a.i_shuffle(indices) + + @staticmethod + def shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> EV: + """ + shuffle, out = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + return a.shuffle(indices) + + @staticmethod + def i_update(pk: PK, a: EV, b: EV, positions, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate(b, positions, stride, pk.pk) + + @staticmethod + def i_update_with_masks(pk: PK, a: EV, b: EV, positions, masks, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate_with_masks(b, positions, masks, stride, pk.pk) + + @staticmethod + def intervals_slice(a: EV, intervals: List[Tuple[int, int]]) -> EV: + """ + slice in the given intervals + + for example: + intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7,...] 
+ then the result is [a0, a1, a2, a3, a6, a7, a8, a9, a10, a11] + """ + return a.intervals_slice(intervals) + + @staticmethod + def cat(list: List[EV]) -> EV: + """ + concatenate the list of vectors + Args: + list: the list of vectors + + Returns: the concatenated vector + """ + return _Evaluator.cat(list) + + @staticmethod + def chunking_cumsum_with_step(pk: PK, a: EV, chunk_sizes: List[int], step: int): + """ + chunking cumsum with step size + + for example: + if step=2, chunk_sizes=[4, 2, 6], a = [a0, a1, a2, a3, a4, a5, a6, a7,...a11] + then the result is [a0, a1, a0+a2, a1+a3, a4, a5, a6, a7, a6+a8, a7+a9, a6+a8+a10, a7+a9+a11] + Args: + pk: the public key + a: the vector to cumsum + chunk_sizes: the chunk sizes, must sum to a.size + step: the step size, cumsum with skip step-1 elements + Returns: + the cumsum result + """ + return a.chunking_cumsum_with_step(pk.pk, chunk_sizes, step) + + @staticmethod + def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: + return a.pack_squeeze(pack_num, shift_bit, pk.pk) + + +def test_pack_float(): + offset_bit = 32 + precision = 16 + coder = Coder(_Coder()) + vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + packed = coder.pack_floats(vec, offset_bit, 2, precision) + unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) + assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) + + +def test_pack_squeeze(): + offset_bit = 32 + precision = 16 + pack_num = 2 + pack_packed_num = 2 + vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) + sk, pk, coder = keygen(1024) + a = coder.pack_floats(vec1, offset_bit, pack_num, precision) + ea = pk.encrypt_encoded(a, obfuscate=False) + b = coder.pack_floats(vec2, offset_bit, pack_num, precision) + eb = pk.encrypt_encoded(b, obfuscate=False) + ec = evaluator.add(ea, eb, pk) + + # pack packed encrypted + ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) + c_pack = sk.decrypt_to_encoded(ec_pack) + c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) + assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/python/fate/arch/protocol/phe/paillier.py b/python/fate/arch/protocol/phe/paillier.py new file mode 100644 index 0000000000..18c29453b5 --- /dev/null +++ b/python/fate/arch/protocol/phe/paillier.py @@ -0,0 +1,415 @@ +from typing import List, Optional, Tuple + +import torch +from fate_utils.paillier import PK as _PK +from fate_utils.paillier import SK as _SK +from fate_utils.paillier import Coder as _Coder +from fate_utils.paillier import Evaluator as _Evaluator +from fate_utils.paillier import CiphertextVector, PlaintextVector +from fate_utils.paillier import keygen as _keygen + +from .type import TensorEvaluator + +V = torch.Tensor +EV = CiphertextVector +FV = PlaintextVector + + +class SK: + def __init__(self, sk: _SK): + self.sk = sk + + def decrypt_to_encoded(self, vec: EV) -> FV: + return self.sk.decrypt_to_encoded(vec) + + +class PK: + def __init__(self, pk: _PK): + self.pk = pk + + def encrypt_encoded(self, vec: FV, obfuscate: bool) -> EV: + return self.pk.encrypt_encoded(vec, obfuscate) + + def encrypt_encoded_scalar(self, val, obfuscate) -> EV: + return self.pk.encrypt_encoded_scalar(val, obfuscate) + + +class Coder: + def __init__(self, coder: _Coder): + self.coder = coder + + def pack_floats(self, float_tensor: V, offset_bit: int, pack_num: int, precision: int) -> FV: + return self.coder.pack_floats(float_tensor.detach().tolist(), offset_bit, pack_num, 
precision) + + def unpack_floats(self, packed: FV, offset_bit: int, pack_num: int, precision: int, total_num: int) -> V: + return torch.tensor(self.coder.unpack_floats(packed, offset_bit, pack_num, precision, total_num)) + + def pack_vec(self, vec: torch.LongTensor, num_shift_bit, num_elem_each_pack) -> FV: + return self.coder.pack_u64_vec(vec.detach().tolist(), num_shift_bit, num_elem_each_pack) + + def unpack_vec(self, vec: FV, num_shift_bit, num_elem_each_pack, total_num) -> torch.LongTensor: + return torch.LongTensor(self.coder.unpack_u64_vec(vec, num_shift_bit, num_elem_each_pack, total_num)) + + def encode_tensor(self, tensor: V, dtype: torch.dtype = None) -> FV: + return self.encode_vec(tensor.flatten(), dtype=tensor.dtype) + + def decode_tensor(self, tensor: FV, dtype: torch.dtype, shape: torch.Size = None, device=None) -> V: + data = self.decode_vec(tensor, dtype) + if shape is not None: + data = data.reshape(shape) + if device is not None: + data = data.to(device.to_torch_device()) + return data + + def encode_vec(self, vec: V, dtype: torch.dtype = None) -> FV: + if dtype is None: + dtype = vec.dtype + else: + if dtype != vec.dtype: + vec = vec.to(dtype=dtype) + if dtype == torch.float64: + return self.encode_f64_vec(vec) + if dtype == torch.float32: + return self.encode_f32_vec(vec) + if dtype == torch.int64: + return self.encode_i64_vec(vec) + if dtype == torch.int32: + return self.encode_i32_vec(vec) + raise NotImplementedError(f"{vec.dtype} not supported") + + def decode_vec(self, vec: FV, dtype: torch.dtype) -> V: + if dtype == torch.float64: + return self.decode_f64_vec(vec) + if dtype == torch.float32: + return self.decode_f32_vec(vec) + if dtype == torch.int64: + return self.decode_i64_vec(vec) + if dtype == torch.int32: + return self.decode_i32_vec(vec) + raise NotImplementedError(f"{dtype} not supported") + + def encode(self, val, dtype=None) -> FV: + if isinstance(val, torch.Tensor): + assert val.ndim == 0, "only scalar supported" + dtype = val.dtype + val = val.item() + if dtype == torch.float64: + return self.encode_f64(val) + if dtype == torch.float32: + return self.encode_f32(val) + if dtype == torch.int64: + return self.encode_i64(val) + if dtype == torch.int32: + return self.encode_i32(val) + raise NotImplementedError(f"{dtype} not supported") + + def encode_f64(self, val: float): + return self.coder.encode_f64(val) + + def decode_f64(self, val): + return self.coder.decode_f64(val) + + def encode_i64(self, val: int): + return self.coder.encode_i64(val) + + def decode_i64(self, val): + return self.coder.decode_i64(val) + + def encode_f32(self, val: float): + return self.coder.encode_f32(val) + + def decode_f32(self, val): + return self.coder.decode_f32(val) + + def encode_i32(self, val: int): + return self.coder.encode_i32(val) + + def decode_i32(self, val): + return self.coder.decode_i32(val) + + def encode_f64_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_f64_vec(vec.detach().numpy()) + + def decode_f64_vec(self, vec): + return torch.tensor(self.coder.decode_f64_vec(vec)) + + def encode_i64_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_i64_vec(vec.detach().numpy()) + + def decode_i64_vec(self, vec): + return torch.tensor(self.coder.decode_i64_vec(vec)) + + def encode_f32_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_f32_vec(vec.detach().numpy()) + + def decode_f32_vec(self, vec): + return torch.tensor(self.coder.decode_f32_vec(vec)) + + 
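Unlike the ou backend above, where the float paths are commented out, the Paillier coder keeps full f32/f64/i32/i64 support. A small round-trip sketch through the vector encodings just shown (this assumes the compiled fate_utils extension is available; keygen is defined just below):

```python
import torch
from fate.arch.protocol.phe import paillier

sk, pk, coder = paillier.keygen(1024)
v = torch.tensor([0.5, -1.25, 3.0], dtype=torch.float64)

encoded = coder.encode_f64_vec(v)        # fixed-point encoding happens inside fate_utils
decoded = coder.decode_f64_vec(encoded)
assert torch.allclose(decoded.to(torch.float64), v)
```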
def encode_i32_vec(self, vec: torch.Tensor): + vec = vec.detach().flatten() + return self.coder.encode_i32_vec(vec.detach().numpy()) + + def decode_i32_vec(self, vec): + return torch.tensor(self.coder.decode_i32_vec(vec)) + + +def keygen(key_size): + sk, pk, coder = _keygen(key_size) + return SK(sk), PK(pk), Coder(coder) + + +class evaluator(TensorEvaluator[EV, V, PK, Coder]): + @staticmethod + def add(a: EV, b: EV, pk: PK): + return a.add(pk.pk, b) + + @staticmethod + def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.add(pk.pk, encrypted) + + @staticmethod + def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.add_scalar(pk.pk, encrypted) + + @staticmethod + def sub(a: EV, b: EV, pk: PK): + return a.sub(pk.pk, b) + + @staticmethod + def sub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.sub(pk.pk, encrypted) + + @staticmethod + def sub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.sub_scalar(pk.pk, encrypted) + + @staticmethod + def rsub(a: EV, b: EV, pk: PK): + return a.rsub(pk.pk, b) + + @staticmethod + def rsub_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded(encoded, obfuscate=False) + return a.rsub(pk.pk, encrypted) + + @staticmethod + def rsub_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.rsub_scalar(pk.pk, encrypted) + + @staticmethod + def mul_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None): + if output_dtype is None: + output_dtype = b.dtype + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.mul(pk.pk, encoded) + + @staticmethod + def mul_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode(b, dtype=output_dtype) + return a.mul_scalar(pk.pk, encoded) + + @staticmethod + def matmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + # TODO: move this to python side so other protocols can use it without matmul support? 
+ return a.matmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def rmatmul(a: EV, b: V, a_shape, b_shape, pk: PK, coder: Coder, output_dtype): + encoded = coder.encode_tensor(b, dtype=output_dtype) + return a.rmatmul(pk.pk, encoded, a_shape, b_shape) + + @staticmethod + def zeros(size, dtype) -> EV: + return CiphertextVector.zeros(size) + + @staticmethod + def i_add(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace add, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.iadd_vec_self(sa, sb, size, pk.pk) + else: + a.iadd_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def i_sub(pk: PK, a: EV, b: EV, sa=0, sb=0, size: Optional[int] = None) -> None: + """ + inplace sub, a[sa:sa+size] += b[sb:sb+size], if size is None, then size = min(a.size - sa, b.size - sb) + Args: + pk: the public key + a: the vector to add to + b: the vector to add + sa: the start index of a + sb: the start index of b + size: the size to add + """ + if a is b: + a.isub_vec_self(sa, sb, size, pk.pk) + else: + a.isub_vec(b, sa, sb, size, pk.pk) + + @staticmethod + def slice(a: EV, start: int, size: int) -> EV: + """ + slice a[start:start+size] + Args: + a: the vector to slice + start: the start index + size: the size to slice + + Returns: + the sliced vector + """ + return a.slice(start, size) + + @staticmethod + def i_shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> None: + """ + inplace shuffle, a = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + a.i_shuffle(indices) + + @staticmethod + def shuffle(pk: PK, a: EV, indices: torch.LongTensor) -> EV: + """ + shuffle, out = a[indices] + Args: + pk: public key, not used + a: the vector to shuffle + indices: the indices to shuffle + """ + return a.shuffle(indices) + + @staticmethod + def i_update(pk: PK, a: EV, b: EV, positions, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate(b, positions, stride, pk.pk) + + @staticmethod + def i_update_with_masks(pk: PK, a: EV, b: EV, positions, masks, stride: int) -> None: + """ + inplace update, a[positions] += b[::stride] + Args: + pk: public key, not used + a: the vector to update + b: the vector to update with + positions: the positions to update + stride: the stride to update + """ + a.iupdate_with_masks(b, positions, masks, stride, pk.pk) + + @staticmethod + def intervals_slice(a: EV, intervals: List[Tuple[int, int]]) -> EV: + """ + slice in the given intervals + + for example: + intervals=[(0, 4), (6, 12)], a = [a0, a1, a2, a3, a4, a5, a6, a7,...] 
+ then the result is [a0, a1, a2, a3, a6, a7, a8, a9, a10, a11] + """ + return a.intervals_slice(intervals) + + @staticmethod + def cat(list: List[EV]) -> EV: + """ + concatenate the list of vectors + Args: + list: the list of vectors + + Returns: the concatenated vector + """ + return _Evaluator.cat(list) + + @staticmethod + def chunking_cumsum_with_step(pk: PK, a: EV, chunk_sizes: List[int], step: int): + """ + chunking cumsum with step size + + for example: + if step=2, chunk_sizes=[4, 2, 6], a = [a0, a1, a2, a3, a4, a5, a6, a7,...a11] + then the result is [a0, a1, a0+a2, a1+a3, a4, a5, a6, a7, a6+a8, a7+a9, a6+a8+a10, a7+a9+a11] + Args: + pk: the public key + a: the vector to cumsum + chunk_sizes: the chunk sizes, must sum to a.size + step: the step size, cumsum with skip step-1 elements + Returns: + the cumsum result + """ + return a.chunking_cumsum_with_step(pk.pk, chunk_sizes, step) + + @staticmethod + def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: + return a.pack_squeeze(pack_num, shift_bit, pk.pk) + + +def test_pack_float(): + offset_bit = 32 + precision = 16 + coder = Coder(_Coder()) + vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + packed = coder.pack_floats(vec, offset_bit, 2, precision) + unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) + assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) + + +def test_pack_squeeze(): + offset_bit = 32 + precision = 16 + pack_num = 2 + pack_packed_num = 2 + vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) + sk, pk, coder = keygen(1024) + a = coder.pack_floats(vec1, offset_bit, pack_num, precision) + ea = pk.encrypt_encoded(a, obfuscate=False) + b = coder.pack_floats(vec2, offset_bit, pack_num, precision) + eb = pk.encrypt_encoded(b, obfuscate=False) + ec = evaluator.add(ea, eb, pk) + + # pack packed encrypted + ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) + c_pack = sk.decrypt_to_encoded(ec_pack) + c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) + assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/python/fate/arch/protocol/phe/type.py b/python/fate/arch/protocol/phe/type.py new file mode 100644 index 0000000000..e2440a4c26 --- /dev/null +++ b/python/fate/arch/protocol/phe/type.py @@ -0,0 +1,22 @@ +from typing import Generic, TypeVar + +EV = TypeVar("EV") +V = TypeVar("V") +PK = TypeVar("PK") +Coder = TypeVar("Coder") + + +class TensorEvaluator(Generic[EV, V, PK, Coder]): + @staticmethod + def add(a: EV, b: EV, pk: PK) -> EV: + ... + + @staticmethod + def add_plain(a: EV, b: V, pk: PK, coder: Coder, output_dtype=None) -> EV: + ... + + @staticmethod + def add_plain_scalar(a: EV, b, pk: PK, coder: Coder, output_dtype) -> EV: + encoded = coder.encode(b, dtype=output_dtype) + encrypted = pk.encrypt_encoded_scalar(encoded, obfuscate=False) + return a.add_scalar(pk.pk, encrypted) diff --git a/python/fate/interface/_data_io.py b/python/fate/arch/protocol/psi/__init__.py similarity index 90% rename from python/fate/interface/_data_io.py rename to python/fate/arch/protocol/psi/__init__.py index 8f31d64549..25e42de9c5 100644 --- a/python/fate/interface/_data_io.py +++ b/python/fate/arch/protocol/psi/__init__.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol - +# -class Dataframe(Protocol): - ... 
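Putting the pieces together, the TensorEvaluator surface is identical across the paillier, ou and mock backends; a minimal end-to-end sketch with the Paillier backend (again assuming the compiled fate_utils extension):

```python
import torch
from fate.arch.protocol.phe import paillier

sk, pk, coder = paillier.keygen(2048)
x = torch.tensor([1.0, 2.0], dtype=torch.float64)
y = torch.tensor([0.5, 0.5], dtype=torch.float64)

ct_x = pk.encrypt_encoded(coder.encode_tensor(x), obfuscate=True)
ct_y = pk.encrypt_encoded(coder.encode_tensor(y), obfuscate=True)
ct = paillier.evaluator.add(ct_x, ct_y, pk)       # Enc(x) + Enc(y), no decryption needed
pt = sk.decrypt_to_encoded(ct)
print(coder.decode_tensor(pt, torch.float64))     # ≈ tensor([1.5, 2.5])
```

The same evaluator also exposes plaintext variants (add_plain, mul_plain, matmul, ...) that encode the plaintext operand through the coder before applying it to the ciphertext vector.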
+from ._psi_run import psi_run diff --git a/python/fate/arch/tensor/types/__init__.py b/python/fate/arch/protocol/psi/_psi_run.py similarity index 68% rename from python/fate/arch/tensor/types/__init__.py rename to python/fate/arch/protocol/psi/_psi_run.py index 4753548d9e..a53fa27726 100644 --- a/python/fate/arch/tensor/types/__init__.py +++ b/python/fate/arch/protocol/psi/_psi_run.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +# +from .ecdh._run import psi_ecdh -from ._dstorage import DStorage -from ._dtype import dtype -from ._lstorage import LStorage -from ._shape import DAxis, Shape -Storage = Union[LStorage, DStorage] -__all__ = ["dtype", "Shape", "DAxis", "LStorage", "DStorage", "Storage"] +def psi_run(ctx, df, protocol="ecdh_psi", curve_type="curve25519"): + if protocol == "ecdh_psi": + return psi_ecdh(ctx, df, curve_type=curve_type) + else: + raise ValueError(f"PSI protocol={protocol} does not implemented yet.") diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_ops.py b/python/fate/arch/protocol/psi/ecdh/__init__.py similarity index 95% rename from python/fate/arch/tensor/storage/local/device/cpu/_ops.py rename to python/fate/arch/protocol/psi/ecdh/__init__.py index 8f56bb7e8a..878d3a9c5d 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/_ops.py +++ b/python/fate/arch/protocol/psi/ecdh/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -custom_ops = dict(slice=slice) +# diff --git a/python/fate/arch/protocol/psi/ecdh/_run.py b/python/fate/arch/protocol/psi/ecdh/_run.py new file mode 100644 index 0000000000..f7b1f84cd8 --- /dev/null +++ b/python/fate/arch/protocol/psi/ecdh/_run.py @@ -0,0 +1,185 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
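psi_run(ctx, df, protocol="ecdh_psi") is the entry point and dispatches to psi_ecdh below. The intersection rests on the commutativity of the two Curve25519 signing passes: each party signs its ids once, the peer signs on top, and the doubly-signed values match regardless of order, so both sides can join on them. A small sketch of that property (assuming the compiled fate_utils extension; the results are compared directly, just as the join in guest_run/host_run relies on key equality):

```python
from fate_utils.psi import Curve25519

guest, host = Curve25519(), Curve25519()
ids = [b"alice", b"bob"]

guest_once = guest.encrypt_vec(ids)   # signed with the guest's secret
host_once = host.encrypt_vec(ids)     # signed with the host's secret

# the peer's second pass commutes with the first pass
assert host.diffie_hellman_vec(guest_once) == guest.diffie_hellman_vec(host_once)
```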
+# +import functools +import logging + + +from fate.arch.dataframe import DataFrame +from fate_utils.psi import Curve25519 + + +logger = logging.getLogger(__name__) + +GUEST_FIRST_SIGN = "guest_first_sign" +HOST_FIRST_SIGN = "host_first_sign" +GUEST_SECOND_SIGN = "guest_second_sign" +HOST_INDEXER = "host_indexer" + + +def _encrypt_bytes(values, curve: Curve25519 = None): + id_list = [bytes(_id, "utf8") for _id in values[0]] + return curve.encrypt_vec(id_list) + + +def _diffie_hellman(values, curve: Curve25519 = None): + return curve.diffie_hellman_vec(values) + + +def _flat_block_with_possible_duplicate_keys(block_table, duplicate_allow=False): + """ + row_value: (encrypt_id, [(block_id, _offset)]) + """ + def _mapper(kvs): + for block_id, id_list in kvs: + for _i, _id in enumerate(id_list): + yield _id, [(block_id, _i)] + + def _reducer(v1, v2): + if not v1: + return v2 + if not v2: + return v1 + + if not duplicate_allow: + raise ValueError("duplicate match_id detect") + + return v1 + v2 + + return block_table.mapReducePartitions(_mapper, _reducer) + + +def _flat_block_key(intersect_id): + """ + key=eid, value = ((guest_block_id, guest_block_offset), [(host0_block_id, host0_block_offset)...]) + """ + def _mapper(kvs): + for _, value in kvs: + guest_loc = value[0] + host_loc = value[1] + for guest_block_id, guest_offset in guest_loc: + yield (guest_block_id, guest_offset), host_loc + + return intersect_id.mapPartitions(_mapper, use_previous_behavior=False) + + +def psi_ecdh(ctx, df: DataFrame, curve_type="curve25519", **kwargs): + if curve_type != "curve25519": + raise ValueError(f"Only support curve25519, curve_type={curve_type} is not implemented yet") + + if ctx.is_on_guest: + return guest_run(ctx, df, curve_type, **kwargs) + else: + return host_run(ctx, df, curve_type, **kwargs) + + +def guest_run(ctx, df: DataFrame, curve_type="curve25519", **kwargs): + curve = Curve25519() + match_id = df.match_id.block_table + + encrypt_func = functools.partial(_encrypt_bytes, curve=curve) + guest_first_sign_match_id = match_id.mapValues(encrypt_func) + ctx.hosts.put(GUEST_FIRST_SIGN, guest_first_sign_match_id) + + host_first_sign_match_ids = ctx.hosts.get(HOST_FIRST_SIGN) + host_second_sign_match_ids = [] + + dh_func = functools.partial(_diffie_hellman, curve=curve) + + for i, host_first_sign_match_id in enumerate(host_first_sign_match_ids): + host_second_sign_match_ids.append( + _flat_block_with_possible_duplicate_keys( + host_first_sign_match_ids[i].mapValues(dh_func), + duplicate_allow=False + ) + ) + + guest_second_sign_match_ids = ctx.hosts.get(GUEST_SECOND_SIGN) + + flat_intersect_id = None + for guest_second_sign_id, host_second_sign_match_id in zip(guest_second_sign_match_ids, host_second_sign_match_ids): + intersect_eid = guest_second_sign_id.join( + host_second_sign_match_id, lambda id_list_l, id_list_r: (id_list_l, id_list_r) + ) + intersect_single = _flat_block_key(intersect_eid) + if not flat_intersect_id: + flat_intersect_id = intersect_single + else: + flat_intersect_id = flat_intersect_id.join( + intersect_single, lambda id_list_l, id_list_r: id_list_l + id_list_r + ) + + """ + a. flatmap=> + key=(bid, offset), value=[(host0_bid, host0_offset)...] + b. df => flatMap + key=(bid, offset), value=(sample_id, data) + c. 
(bid, offset), (host + """ + flatten_df = df.flatten(key_type="block_id", with_sample_id=True) + intersect_with_offset_ids = flatten_df.join(flat_intersect_id, lambda vl, vr: (vl, vr)) + """ + host_indexer: key=(block_id, offset), value=(sample_id, (bid, offset)) + """ + for host_id in range(len(guest_second_sign_match_ids)): + host_indexer = intersect_with_offset_ids.mapValues(lambda v: (v[0][0], v[1][host_id])) + ctx.hosts[host_id].put(HOST_INDEXER, host_indexer) + + intersect_guest_data = intersect_with_offset_ids.mapValues(lambda v: v[0]) + + guest_df = DataFrame.from_flatten_data(ctx, intersect_guest_data, df.data_manager, key_type="block_id") + ctx.metrics.log_metrics({"intersect_count": guest_df.shape[0]}, name="intersect_id_count", type="custom") + + """ + the following just for debug, need to be delete + """ + # ids = [v[0] for k, v in sorted(guest_df.block_table.collect())] + # logger.debug(f"intersect ids is: {ids}") + + return guest_df + + +def host_run(ctx, df: DataFrame, curve_type, **kwargs): + curve = Curve25519() + match_id = df.match_id.block_table + + encrypt_func = functools.partial(_encrypt_bytes, curve=curve) + host_first_sign_match_id = match_id.mapValues(encrypt_func) + ctx.guest.put(HOST_FIRST_SIGN, host_first_sign_match_id) + + dh_func = functools.partial(_diffie_hellman, curve=curve) + guest_first_sign_match_id = ctx.guest.get(GUEST_FIRST_SIGN) + guest_second_sign_match_id = guest_first_sign_match_id.mapValues(dh_func) + + guest_second_sign_match_id = _flat_block_with_possible_duplicate_keys(guest_second_sign_match_id, + duplicate_allow=True) + ctx.guest.put(GUEST_SECOND_SIGN, guest_second_sign_match_id) + + """ + host_indexer: key=(block_id, offset), value=(sample_id, (bid, offset)) + """ + host_indexer = ctx.guest.get(HOST_INDEXER) + + host_df = df.loc_with_sample_id_replacement(host_indexer) + + ctx.metrics.log_metrics({"intersect_count": host_df.shape[0]}, name="intersect_id_count", type="custom") + + """ + the following just for debug, need to be delete + """ + # ids = [v[0] for k, v in sorted(host_df.block_table.collect())] + # logger.debug(f"intersect ids is: {ids}") + + return host_df diff --git a/python/fate/arch/protocol/secure_aggregation/__init__.py b/python/fate/arch/protocol/secure_aggregation/__init__.py new file mode 100644 index 0000000000..eadce6cf7f --- /dev/null +++ b/python/fate/arch/protocol/secure_aggregation/__init__.py @@ -0,0 +1,3 @@ +from ._secure_aggregation import SecureAggregatorClient, SecureAggregatorServer + +__all__ = ["SecureAggregatorClient", "SecureAggregatorServer"] diff --git a/python/fate/arch/protocol/secure_aggregation/_secure_aggregation.py b/python/fate/arch/protocol/secure_aggregation/_secure_aggregation.py new file mode 100644 index 0000000000..0bf77444bf --- /dev/null +++ b/python/fate/arch/protocol/secure_aggregation/_secure_aggregation.py @@ -0,0 +1,117 @@ +import typing + +import numpy +from fate.arch import Context +from fate.arch.protocol.diffie_hellman import DiffieHellman +from fate_utils.secure_aggregation_helper import MixAggregate, RandomMix + + +class _SecureAggregatorMeta: + _send_name = "mixed_client_values" + _recv_name = "aggregated_values" + prefix: str + + def _get_name(self, name): + if self.prefix: + return f"{self.prefix}_{name}" + return name + + +class SecureAggregatorClient(_SecureAggregatorMeta): + def __init__(self, prefix: typing.Optional[str] = None, is_mock: bool = False): + """ + secure aggregation client + Args: + prefix: unique prefix for this aggregator + is_mock: mock the aggregator, 
do not perform secure aggregation, for test only + """ + self.prefix = prefix + self._mixer = None + self._is_mock = is_mock + + def _get_mixer(self): + if self._mixer is None: + raise RuntimeError("mixer not initialized, run dh_exchange first") + return self._mixer + + def dh_exchange(self, ctx: Context, ranks: typing.List[int]): + if self._is_mock: + return + local_rank = ctx.local.rank + dh = {} + seeds = {} + for rank in ranks: + if rank == local_rank: + continue + dh[rank] = DiffieHellman() + ctx.parties[rank].put(self._get_name(f"dh_pubkey"), dh[rank].get_public_key()) + for rank in ranks: + if rank == local_rank: + continue + public_key = ctx.parties[rank].get(self._get_name(f"dh_pubkey")) + seeds[rank] = dh[rank].diffie_hellman(public_key) + self._mixer = RandomMix(seeds, local_rank) + + def secure_aggregate(self, ctx: Context, array: typing.List[numpy.ndarray], weight: typing.Optional[int] = None): + if self._is_mock: + ctx.arbiter.put(self._get_name(self._send_name), (array, weight)) + return ctx.arbiter.get(self._get_name(self._recv_name)) + else: + mixed = self._get_mixer().mix(array, weight) + ctx.arbiter.put(self._get_name(self._send_name), (mixed, weight)) + return ctx.arbiter.get(self._get_name(self._recv_name)) + + +class SecureAggregatorServer(_SecureAggregatorMeta): + def __init__(self, ranks, prefix: typing.Optional[str] = None, is_mock: bool = False): + """ + secure aggregation server + Args: + ranks: all ranks + prefix: unique prefix for this aggregator + is_mock: mock the aggregator, do not perform secure aggregation, for test only + """ + self.prefix = prefix + self.ranks = ranks + self._is_mock = is_mock + + def secure_aggregate(self, ctx: Context, ranks: typing.Optional[int] = None): + """ + perform secure aggregate once + Args: + ctx: Context to use + ranks: ranks to aggregate, if None, use all ranks + """ + if ranks is None: + ranks = self.ranks + aggregated_weight = 0.0 + has_weight = False + + if self._is_mock: + aggregated = [] + for rank in ranks: + arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) + for i in range(len(arrays)): + if len(aggregated) <= i: + aggregated.append(arrays[i]) + else: + aggregated[i] += arrays[i] + if weight is not None: + has_weight = True + aggregated_weight += weight + if has_weight: + aggregated = [x / aggregated_weight for x in aggregated] + else: + mix_aggregator = MixAggregate() + for rank in ranks: + mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) + mix_aggregator.aggregate(mix_arrays) + if weight is not None: + has_weight = True + aggregated_weight += weight + if not has_weight: + aggregated_weight = None + aggregated = mix_aggregator.finalize(aggregated_weight) + + for rank in ranks: + ctx.parties[rank].put(self._get_name(self._recv_name), aggregated) diff --git a/python/fate/arch/tensor/__init__.py b/python/fate/arch/tensor/__init__.py index 843b821e7f..145dab97ea 100644 --- a/python/fate/arch/tensor/__init__.py +++ b/python/fate/arch/tensor/__init__.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
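The SecureAggregatorClient/SecureAggregatorServer pair above wraps the masking primitives from fate_utils. A local two-client sketch of the underlying idea, with no Context involved (hypothetical standalone use; it assumes the compiled fate_utils helpers and that mix/finalize accept a None weight, as the wrapper classes above do):

```python
import numpy
from fate_utils.secure_aggregation_helper import DiffieHellman, MixAggregate, RandomMix

# the two clients derive a shared seed via DH, then add pairwise masks that cancel in the sum
dh0, dh1 = DiffieHellman(), DiffieHellman()
seed_0_to_1 = dh0.diffie_hellman(dh1.get_public_key())
seed_1_to_0 = dh1.diffie_hellman(dh0.get_public_key())

mixer0 = RandomMix({1: seed_0_to_1}, 0)   # local rank 0, peer rank 1
mixer1 = RandomMix({0: seed_1_to_0}, 1)

x0, x1 = numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])
server = MixAggregate()
server.aggregate(mixer0.mix([x0], None))
server.aggregate(mixer1.mix([x1], None))
print(server.finalize(None))              # masks cancel: ≈ [array([4., 6.])]
```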
-from ._tensor import distributed_tensor, randn, tensor -from .ops import * -from .types import * +from ._custom_ops import * +from .distributed import DTensor -__all__ = ["tensor", "randn", "distributed_tensor"] +__all__ = [ + "DTensor", + "encrypt_f", + "decrypt_f", +] diff --git a/python/fate/arch/tensor/_custom_ops.py b/python/fate/arch/tensor/_custom_ops.py new file mode 100644 index 0000000000..123d6d9a30 --- /dev/null +++ b/python/fate/arch/tensor/_custom_ops.py @@ -0,0 +1,114 @@ +import torch + + +def encrypt_f(tensor, encryptor): + if isinstance(tensor, torch.Tensor): + return encryptor.encrypt_tensor(tensor.detach()) + else: + # torch tensor-like + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(encrypt_f, (type(tensor),), (tensor, encryptor), None) + raise NotImplementedError("") + + +def encrypt_encoded_f(tensor, encryptor): + if isinstance(tensor, torch.Tensor): + return encryptor.encrypt_encoded(tensor.detach()) + else: + # torch tensor-like + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(encrypt_encoded_f, (type(tensor),), (tensor, encryptor), None) + raise NotImplementedError("") + + +def decrypt_encoded_f(tensor, decryptor): + # torch tensor-like + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(decrypt_encoded_f, (type(tensor),), (tensor, decryptor), None) + raise NotImplementedError("") + + +def encode_f(tensor, coder): + if isinstance(tensor, torch.Tensor): + return coder.encode(tensor.detach()) + else: + # torch tensor-like + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(encode_f, (type(tensor),), (tensor, coder), None) + raise NotImplementedError("") + + +def decrypt_f(tensor, decryptor): + # torch tensor-like + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(decrypt_f, (type(tensor),), (tensor, decryptor), None) + raise NotImplementedError(f"{type(tensor)}") + + +def decode_f(tensor, coder): + if hasattr(tensor, "__torch_function__"): + return tensor.__torch_function__(decode_f, (type(tensor),), (tensor, coder), None) + raise NotImplementedError(f"{type(tensor)}") + + +def rmatmul_f(input, other): + if isinstance(input, torch.Tensor) and isinstance(other, torch.Tensor): + return torch.matmul(other, input) + else: + # torch tensor-like + if isinstance(input, torch.Tensor): + return torch.matmul(other, input) + + else: + if hasattr(input, "__torch_function__"): + return input.__torch_function__(rmatmul_f, (type(input), type(other)), (input, other), None) + raise NotImplementedError("") + + +def to_local_f(input): + if isinstance(input, torch.Tensor): + return input + + else: + # torch tensor-like + if hasattr(input, "__torch_function__"): + return input.__torch_function__(to_local_f, (type(input),), (input,), None) + raise NotImplementedError("") + + +def slice_f(input, arg): + if isinstance(input, torch.Tensor): + return input[arg] + + else: + # torch tensor-like + if hasattr(input, "__torch_function__"): + out = input.__torch_function__(slice_f, (type(input),), (input, arg), None) + if out == NotImplemented: + raise NotImplementedError(f"slice_f: {input}") + return out + + raise NotImplementedError(f"slice_f: {input}") + + +def encode_as_int_f(input, precision: int): + if isinstance(input, torch.Tensor): + return (input * 2 ** precision).astype(torch.int64) + else: + # torch tensor-like + if hasattr(input, "__torch_function__"): + return input.__torch_function__(encode_as_int_f, (type(input),), (input, precision), None) 
+ raise NotImplementedError("") + + +# hook custom ops to torch +torch.encrypt_f = encrypt_f +torch.encrypt_encoded_f = encrypt_encoded_f +torch.decrypt_encoded_f = decrypt_encoded_f +torch.decrypt_f = decrypt_f +torch.encode_f = encode_f +torch.decode_f = decode_f +torch.rmatmul_f = rmatmul_f +torch.to_local_f = to_local_f +torch.slice_f = slice_f +torch.encode_as_int_f = encode_as_int_f diff --git a/python/fate/arch/tensor/_generate.py b/python/fate/arch/tensor/_generate.py deleted file mode 100644 index 140ac85043..0000000000 --- a/python/fate/arch/tensor/_generate.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -if __name__ == "__main__": - import pathlib - - _unary_path = pathlib.Path(__file__).parent.joinpath("_unary_ops.py") - _binary_path = pathlib.Path(__file__).parent.joinpath("_binary_ops.py") - _unary_funcs = [ - ("abs", "arc cosine"), - ("asin", "arc sin"), - ("atan", ""), - ("atan2", ""), - ("ceil", "ceiling"), - ("cos", ""), - ("cosh", ""), - ("erf", "Gaussian error functiom"), - ("erfinv", "Gaussian error functiom"), - ("exp", ""), - ("expm1", "exponential of each element minus 1"), - ("floor", ""), - ("frac", "fraction part 3.4 -> 0.4"), - ("log", "natural log"), - ("log1p", "y = log(1 + x)"), - ("neg", ""), - ("reciprocal", "1/x"), - ("sigmoid", "sigmode(x)"), - ("sign", ""), - ("sin", ""), - ("sinh", ""), - ("sqrt", ""), - ("square", ""), - ("tan", ""), - ("tanh", ""), - ("trunc", "truncated integer"), - ("rsqrt", "the reciprocal of the square-root"), - ("round", ""), - ] - _binary_funcs = [ - ("add", ""), - ("sub", ""), - ("mul", ""), - ("div", ""), - ("pow", ""), - ("remainder", ""), - ("fmod", "element wise remainder of division"), - ] - with open(_unary_path, "w") as fw: - fw.write("from ._ops import auto_unary_op\n") - for name, comment in _unary_funcs: - fw.write("\n") - fw.write("\n") - fw.write("@auto_unary_op\n") - fw.write(f"def {name}(x, *args, **kwargs):\n") - fw.write(f' "{comment}"\n') - fw.write(f" ...\n") - - with open(_binary_path, "w") as fw: - fw.write("from ._ops import auto_binary_op\n") - for name, comment in _binary_funcs: - fw.write("\n") - fw.write("\n") - fw.write("@auto_binary_op\n") - fw.write(f"def {name}(x, y, *args, **kwargs):\n") - fw.write(f' "{comment}"\n') - fw.write(f" ...\n") diff --git a/python/fate/arch/tensor/_phe.py b/python/fate/arch/tensor/_phe.py deleted file mode 100644 index cfb913315e..0000000000 --- a/python/fate/arch/tensor/_phe.py +++ /dev/null @@ -1,93 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from enum import Enum -from typing import Tuple - -from fate.interface import PHECipher as PHECipherInterface - -from ..unify import device - - -class PHEKind(Enum): - AUTO = "auto" - PAILLIER = "Paillier" - RUST_PAILLIER = "rust_paillier" - INTEL_PAILLIER = "intel_paillier" - - -class PHECipher(PHECipherInterface): - def __init__(self, device: device) -> None: - self.device = device - - def keygen(self, kind: PHEKind = PHEKind.AUTO, options={}) -> Tuple["PHEEncryptor", "PHEDecryptor"]: - - if kind == PHEKind.AUTO or PHEKind.PAILLIER: - if self.device == device.CPU: - from .storage.local.device.cpu.multithread_cpu_paillier_block import ( - BlockPaillierCipher, - ) - - key_length = options.get("key_length", 1024) - encryptor, decryptor = BlockPaillierCipher().keygen(key_length=key_length) - return PHEEncryptor(encryptor), PHEDecryptor(decryptor) - - raise NotImplementedError(f"keygen for kind<{kind}>-device<{self.device}> is not implemented") - - -class PHEEncryptor: - def __init__(self, storage_encryptor) -> None: - self._encryptor = storage_encryptor - - def encrypt(self, tensor): - from ..tensor import Tensor - from .storage.local.device.cpu.paillier import _RustPaillierStorage - from .types import DStorage, dtype - - if tensor.device == device.CPU: - storage = tensor.storage - if tensor.is_distributed: - encrypted_storage = DStorage.elemwise_unary_op( - storage, - lambda s: _RustPaillierStorage(dtype.paillier, storage.shape, self._encryptor.encrypt(s.data)), - dtype.paillier, - ) - else: - encrypted_storage = _RustPaillierStorage( - dtype.paillier, storage.shape, self._encryptor.encrypt(storage.data) - ) - else: - raise NotImplementedError() - return Tensor(encrypted_storage) - - -class PHEDecryptor: - def __init__(self, storage_decryptor) -> None: - self._decryptor = storage_decryptor - - def decrypt(self, tensor): - from ..tensor import Tensor - from .storage.local.device.cpu.plain import _TorchStorage - from .types import DStorage, dtype - - storage = tensor.storage - if isinstance(storage, DStorage): - encrypted_storage = DStorage.elemwise_unary_op( - storage, - lambda s: _TorchStorage(dtype.paillier, storage.shape, self._decryptor.decrypt(s.data)), - dtype.paillier, - ) - else: - encrypted_storage = _TorchStorage(dtype.float32, storage.shape, self._decryptor.decrypt(storage.data)) - return Tensor(encrypted_storage) diff --git a/python/fate/arch/tensor/_tensor.py b/python/fate/arch/tensor/_tensor.py deleted file mode 100644 index 55cfdcf562..0000000000 --- a/python/fate/arch/tensor/_tensor.py +++ /dev/null @@ -1,187 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Union - -import torch -from fate.interface import Context - -from .types import DStorage, LStorage, Shape, dtype - - -def tensor(t: torch.Tensor): - from .storage.local.device.cpu.plain import _TorchStorage - - storage = _TorchStorage(dtype.from_torch_dtype(t.dtype), Shape(t.shape), t) - return Tensor(storage) - - -def randn(shape, dtype): - torch_tensor = torch.randn(shape, dtype=dtype.to_torch_dtype()) - return tensor(torch_tensor) - - -def distributed_tensor(ctx: Context, tensors: List[torch.Tensor], d_axis=0, partitions=3): - from .storage.local.device.cpu.plain import _TorchStorage - - storages = [_TorchStorage(dtype.from_torch_dtype(t.dtype), Shape(t.shape), t) for t in tensors] - storage = DStorage.from_storages(ctx, storages, d_axis, partitions) - return Tensor(storage) - - -def _inject_op_sinature1(func): - method = func.__name__ - - def _wrap(input): - from .ops._ops import dispatch_signature1 - - return dispatch_signature1(method, input, [], {}) - - return _wrap - - -def _inject_op_sinature2(func): - method = func.__name__ - - def _wrap(input, other): - from .ops._ops import dispatch_signature2 - - return dispatch_signature2(method, input, other, [], {}) - - return _wrap - - -class Tensor: - def __init__(self, storage: Union[DStorage, LStorage]) -> None: - self._storage = storage - - @property - def is_distributed(self): - from .types import DStorage - - return isinstance(self._storage, DStorage) - - def to(self, party, name: str): - return party.put(name, self) - - @property - def dtype(self): - return self._storage.dtype - - @property - def storage(self): - return self._storage - - @property - def device(self): - return self._storage.device - - @property - def T(self): - return Tensor(self._storage.transpose()) - - @property - def shape(self): - return self._storage.shape - - def to_local(self): - if isinstance(self._storage, DStorage): - return Tensor(self._storage.to_local()) - return self - - def tolist(self): - if isinstance(self._storage, DStorage): - return self._storage.to_local().tolist() - return self._storage.tolist() - - def __eq__(self, __o: object) -> bool: - return isinstance(__o, Tensor) and self._storage == __o._storage - - def __str__(self) -> str: - return f"Tensor(storage={self.storage})" - - def __repr__(self) -> str: - return self.__str__() - - def __add__(self, other): - return self.add(other) - - def __radd__(self, other): - return self.add(other) - - def __sub__(self, other): - return self.sub(other) - - def __rsub__(self, other): - return self.rsub(other) - - def __mul__(self, other): - return self.mul(other) - - def __rmul__(self, other): - return self.mul(other) - - def __div__(self, other): - return self.div(other) - - def __truediv__(self, other): - return self.truediv(other) - - def __matmul__(self, other): - from .ops import matmul - - return matmul(self, other) - - def __getitem__(self, key): - from .ops import slice - - return slice(self, key) - - """and others""" - - @_inject_op_sinature2 - def add(self, other) -> "Tensor": - ... - - @_inject_op_sinature2 - def sub(self, other) -> "Tensor": - ... - - @_inject_op_sinature2 - def rsub(self, other) -> "Tensor": - ... - - @_inject_op_sinature2 - def mul(self, other) -> "Tensor": - ... - - @_inject_op_sinature2 - def div(self, other) -> "Tensor": - ... - - @_inject_op_sinature2 - def truediv(self, other) -> "Tensor": - ... 
- - def mean(self, *args, **kwargs) -> "Tensor": - return Tensor(self.storage.mean(*args, **kwargs)) - - def std(self, *args, **kwargs) -> "Tensor": - return Tensor(self.storage.std(*args, **kwargs)) - - def max(self, *args, **kwargs) -> "Tensor": - return Tensor(self.storage.max(*args, **kwargs)) - - def min(self, *args, **kwargs) -> "Tensor": - return Tensor(self.storage.min(*args, **kwargs)) diff --git a/python/fate/arch/tensor/distributed/__init__.py b/python/fate/arch/tensor/distributed/__init__.py new file mode 100644 index 0000000000..ff2d442642 --- /dev/null +++ b/python/fate/arch/tensor/distributed/__init__.py @@ -0,0 +1,11 @@ +from ._op_matmul import * +from ._op_slice import * +from ._op_transpose import * +from ._ops_agg import * +from ._ops_binary import * +from ._ops_cipher import * +from ._ops_others import * +from ._ops_unary import * +from ._tensor import DTensor + +__all__ = ["DTensor"] diff --git a/python/fate/arch/tensor/distributed/_op_matmul.py b/python/fate/arch/tensor/distributed/_op_matmul.py new file mode 100644 index 0000000000..027b99d746 --- /dev/null +++ b/python/fate/arch/tensor/distributed/_op_matmul.py @@ -0,0 +1,211 @@ +import logging + +import torch +from fate.arch.tensor import _custom_ops + +from ._tensor import DTensor, implements + +logger = logging.getLogger(__name__) + + +@implements(_custom_ops.rmatmul_f) +def rmatmul_f(a: DTensor, b: DTensor): + assert isinstance(a, DTensor) or isinstance(b, DTensor), "atleast one dtensor" + if not isinstance(a, DTensor): + return matmul(b, a) + + if len(a.shape) == 1 and len(b.shape) == 1: + if isinstance(b, DTensor): + return a.shardings.join_reduce_shard(b.shardings, _custom_ops.rmatmul_f, torch.add) + else: + assert a.shape[0] == b.shape[0], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("matmul shape 1 distributed tensor with local shape 1 tensor maybe slow") + return a.shardings.map_reduce_shard_with_stride( + lambda stride, size, s: _custom_ops.rmatmul_f(s, b[stride : stride + size]), torch.add + ) + + if len(a.shape) == 1 and len(b.shape) > 1: + if isinstance(b, DTensor): + assert b.shardings.shapes.axis == len(b.shardings.shape) - 1, "distributed axis mismatch" + return a.shardings.join_reduce_shard(b.shardings, _custom_ops.rmatmul_f, torch.add) + else: + assert a.shape[0] == b.shape[-2:][-1], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("rmatmul shape 1 distributed tensor with local tensor maybe slow") + axis = len(b.shape) - 1 + + def _mapper(stride, size, s): + slices = tuple( + slice(stride, stride + size) if i == axis else slice(None, None, None) for i in range(len(b.shape)) + ) + return _custom_ops.rmatmul_f(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) + + if len(a.shape) > 1 and len(b.shape) == 1: + if isinstance(b, DTensor): + assert a.shardings.shapes.axis == len(a.shardings.shape) - 2, "distributed axis mismatch" + return a.shardings.join_reduce_shard(b.shardings, _custom_ops.rmatmul_f, torch.add) + else: + assert a.shape[-2] == b.shape[0], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("matmul shape 1 distributed tensor with local tensor maybe slow") + + def _mapper(stride, size, s): + slices = slice(stride, stride + size, 1) + return _custom_ops.rmatmul_f(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) + + else: + if isinstance(b, DTensor): + na_axis = a.shardings.shapes.axis - len(a.shape) + nb_axis = b.shardings.shapes.axis - len(b.shape) + # distributed axis in 
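# A DTensor is essentially a table of per-partition torch tensors plus the
# axis they were split on; merging shards back is a concatenation along that
# axis. A pure-torch sketch of the bookkeeping (no FATE computing context):
import torch

full = torch.arange(12.0).reshape(6, 2)
shards = list(torch.chunk(full, chunks=3, dim=0))      # three shards along axis 0
shapes = [s.shape for s in shards]                      # what Shardings keeps per shard
merged = torch.cat(shards, dim=0)                       # what to_local / merge() produces
assert torch.equal(merged, full)
assert sum(s[0] for s in shapes) == full.shape[0]       # merged size along the split axis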
broadcast part + if na_axis < -2 and nb_axis < -2 and na_axis == nb_axis: + shapes = [ + torch.Size([*torch.broadcast_shapes(sa[:-2], sb[:-2]), sb[-2], sa[-1]]) + for sa, sb in zip(a.shardings.shapes.shapes, b.shardings.shapes.shapes) + ] + axis = len(shapes[0]) + na_axis + return DTensor( + a.shardings.join_shard(b.shardings, func=_custom_ops.rmatmul_f, out_shapes=shapes, out_axis=axis) + ) + + # distributed axis in matmul part + elif na_axis == -2 and nb_axis == -1: + return a.shardings.join_reduce_shard( + b.shardings, mapper_func=_custom_ops.rmatmul_f, reduce_func=torch.add + ) + else: + raise RuntimeError(f"invalid shape {a.shape} and {b.shape}") + + else: + assert a.shape[-1] == b.shape[-2], f"shapes mismatch: {a.shape} and {b.shape}" + na_axis = a.shardings.shapes.axis - len(a.shape) + if na_axis != -2: + shapes = [ + torch.Size([*torch.broadcast_shapes(sa[:-2], b.shape[:-2]), b.shape[-2], sa[-1]]) + for sa in a.shardings.shapes.shapes + ] + axis = len(shapes[0]) + na_axis + return DTensor(a.shardings.map_shard(lambda x: _custom_ops.rmatmul_f(x, b), shapes=shapes, axis=axis)) + else: + logger.warning("matmul shape 1 distributed tensor with local tensor maybe slow") + axis = len(b.shape) - 1 + + def _mapper(stride, size, s): + slices = tuple( + slice(stride, stride + size, 1) if i == axis else slice(None, None, None) + for i in range(len(b.shape)) + ) + return _custom_ops.rmatmul_f(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) + + +def promote_torch_matmul(a: torch.Tensor, b: torch.Tensor): + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + target_dtype = torch.promote_types(a.dtype, b.dtype) + a = a.type(target_dtype) + b = b.type(target_dtype) + return torch.matmul(_maybe_detach(a), _maybe_detach(b)) + + +def _maybe_detach(a): + if isinstance(a, torch.Tensor): + return a.detach() + return a + + +@implements(torch.matmul) +def matmul(a: DTensor, b: DTensor): + assert isinstance(a, DTensor) or isinstance(b, DTensor), "atleast one dtensor" + if not isinstance(a, DTensor): + return rmatmul_f(b, a) + + if len(a.shape) == 1 and len(b.shape) == 1: + if isinstance(b, DTensor): + return a.shardings.join_reduce_shard(b.shardings, promote_torch_matmul, torch.add) + else: + assert a.shape[0] == b.shape[0], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("matmul shape 1 distributed tensor with local shape 1 tensor maybe slow") + return a.shardings.map_reduce_shard_with_stride( + lambda stride, size, s: promote_torch_matmul(s, b[stride : stride + size]), torch.add + ) + + elif len(a.shape) == 1 and len(b.shape) > 1: + if isinstance(b, DTensor): + assert b.shardings.shapes.axis == len(b.shardings.shape) - 2, "distributed axis mismatch" + return a.shardings.join_reduce_shard(b.shardings, promote_torch_matmul, torch.add) + else: + assert a.shape[0] == b.shape[-2:][0], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("matmul shape 1 distributed tensor with local tensor maybe slow") + axis = len(b.shape) - 2 + + def _mapper(stride, size, s): + slices = tuple( + slice(stride, stride + size, 1) if i == axis else slice(None, None, None) + for i in range(len(b.shape)) + ) + return promote_torch_matmul(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) + + elif len(a.shape) > 1 and len(b.shape) == 1: + if isinstance(b, DTensor): + assert a.shardings.shapes.axis == len(a.shardings.shape) - 1, "distributed axis mismatch" + return a.shardings.join_reduce_shard(b.shardings, 
promote_torch_matmul, torch.add) + else: + assert a.shape[-1] == b.shape[0], f"shapes mismatch: {a.shape} and {b.shape}" + logger.warning("matmul shape 1 distributed tensor with local tensor maybe slow") + + def _mapper(stride, size, s): + slices = slice(stride, stride + size, 1) + return promote_torch_matmul(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) + + else: + if isinstance(b, DTensor): + na_axis = a.shardings.shapes.axis - len(a.shape) + nb_axis = b.shardings.shapes.axis - len(b.shape) + # distributed axis in broadcast part + if na_axis < -2 and nb_axis < -2 and na_axis == nb_axis: + shapes = [ + torch.Size([*torch.broadcast_shapes(sa[:-2], sb[:-2]), sa[-2], sb[-1]]) + for sa, sb in zip(a.shardings.shapes.shapes, b.shardings.shapes.shapes) + ] + axis = len(shapes[0]) + na_axis + return DTensor( + a.shardings.join_shard(b.shardings, func=promote_torch_matmul, out_shapes=shapes, out_axis=axis) + ) + + # distributed axis in matmul part + elif na_axis == -1 and nb_axis == -2: + return a.shardings.join_reduce_shard( + b.shardings, mapper_func=promote_torch_matmul, reduce_func=torch.add + ) + else: + raise RuntimeError(f"invalid shape {a.shardings.shapes} and {b.shardings.shapes}") + + else: + assert a.shape[-1] == b.shape[-2], f"shapes mismatch: {a.shape} and {b.shape}" + na_axis = a.shardings.shapes.axis - len(a.shape) + if na_axis != -1: + shapes = [ + torch.Size([*torch.broadcast_shapes(sa[:-2], b.shape[:-2]), sa[-2], b.shape[-1]]) + for sa in a.shardings.shapes.shapes + ] + axis = len(shapes[0]) + na_axis + return DTensor(a.shardings.map_shard(lambda x: promote_torch_matmul(x, b), shapes=shapes, axis=axis)) + else: + logger.warning("matmul shape 1 distributed tensor with local tensor maybe slow") + axis = len(b.shape) - 2 + + def _mapper(stride, size, s): + slices = tuple( + slice(stride, stride + size, 1) if i == axis else slice(None, None, None) + for i in range(len(b.shape)) + ) + return promote_torch_matmul(s, b[slices]) + + return a.shardings.map_reduce_shard_with_stride(_mapper, torch.add) diff --git a/python/fate/arch/tensor/distributed/_op_slice.py b/python/fate/arch/tensor/distributed/_op_slice.py new file mode 100644 index 0000000000..c98004c9a9 --- /dev/null +++ b/python/fate/arch/tensor/distributed/_op_slice.py @@ -0,0 +1,63 @@ +import torch +from fate.arch.tensor import _custom_ops + +from ._tensor import DTensor, implements + + +@implements(_custom_ops.slice_f) +def slice_f(input: DTensor, key): + # 1: int slice key means slice 0 dimention + if isinstance(key, int): + if 0 <= key < input.shape[0]: + # 1.1: slice output in one of shardings + if input.shardings.shapes.axis == 0: + return input.shardings.map_reduce_shard_with_stride( + stride_mapper_func=lambda stride, _, s: [s[key - stride]] + if stride <= key < stride + s.shape[0] + else [], + reducer_func=lambda x, y: [*x, *y], + )[0] + # 1.2: slice output is distributed + else: + return DTensor( + input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,))) + ) + + else: + raise IndexError(f"index {key} is out of bounds for dimension 0 with size {input.shape[0]}") + + # 2: list slice key + if isinstance(key, list): + for k in key: + if k < 0 or k >= input.shape[0]: + raise IndexError(f"index {k} is out of bounds for dimension 0 with size {input.shape[0]}") + + if input.shardings.shapes.axis == 0: + outputs = input.shardings.map_reduce_shard_with_stride( + stride_mapper_func=lambda stride, _, s: [ + (i, s[k - stride]) for i, k in enumerate(key) if 
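# Two facts the matmul handlers above rely on, checked locally with torch:
# (1) torch.matmul wants matching dtypes, hence promote_torch_matmul;
# (2) when the contraction dimension is the distributed axis, every shard
#     contributes a partial product and the shards are reduced with add.
import torch

a32, b64 = torch.ones(2, 3), torch.ones(3, 4, dtype=torch.float64)
t = torch.promote_types(a32.dtype, b64.dtype)            # float64
assert torch.matmul(a32.type(t), b64.type(t)).dtype == torch.float64

a, b = torch.randn(4, 6), torch.randn(6, 5)
a_shards = torch.chunk(a, 3, dim=1)                      # shard a along k
b_shards = torch.chunk(b, 3, dim=0)                      # shard b along k
partials = [x @ y for x, y in zip(a_shards, b_shards)]   # per-shard (4, 5) products
assert torch.allclose(sum(partials), a @ b)              # reduce with add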
stride <= k < stride + s.shape[0] + ], + reducer_func=lambda x, y: [*x, *y], + ) + return torch.cat([v for _, v in sorted(outputs)]) + else: + return DTensor(input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,)))) + + # 3: slice key + if isinstance(key, slice): + start, stop, step = key.indices(input.shape[0]) + indices = list(range(start, stop, step)) + return slice_f(input, indices) + + # 4: tuple key for multi-dimensional slicing + if isinstance(key, tuple): + raise NotImplementedError("tuple key {key}") + # result = input + # for dim, k in enumerate(key): + # if isinstance(k, (int, list, slice)): + # ... + # else: + # raise NotImplementedError(f"slice_f on {key}") + # return result + + raise NotImplementedError(f"slice_f on {key}") diff --git a/python/fate/arch/tensor/distributed/_op_transpose.py b/python/fate/arch/tensor/distributed/_op_transpose.py new file mode 100644 index 0000000000..593603e39f --- /dev/null +++ b/python/fate/arch/tensor/distributed/_op_transpose.py @@ -0,0 +1,50 @@ +import torch + +from ._tensor import DTensor, Shardings, _ShardingShapes, implements + + +@implements(torch.transpose) +def transpose(input: DTensor, dim0, dim1): + shapes = transpose_shape(input.shardings.shapes, dim0, dim1) + return DTensor( + input.shardings.map_shard(lambda x: x.transpose(dim0, dim1).detach(), shapes=shapes.shapes, axis=shapes.axis) + ) + + # TODO: lazy transpose + + # if dim0 and dim1 are not in distributed axis: + # 1. just transpose local tensor in each partition + # 2. shapes should be modified. + # if input.shardings.shapes.axis not in (dim0, dim1): + # return DTensor( + # input.shardings.map_shard(lambda x: x.transpose(dim0, dim1), shapes=shapes.shapes, axis=shapes.axis) + # ) + # # if dim0 and dim1 are in distributed axis: + # # 1. local tensor in each partition should not be transposed. + # # 2. only shapes and distributed axis should be modified. 
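# Slicing with a list of indices along the distributed axis gathers
# (position, row) pairs from whichever shard owns each index and restores the
# requested order afterwards. A local sketch of that stride arithmetic:
import torch

full = torch.arange(10.0).reshape(5, 2)
shards = list(torch.chunk(full, 3, dim=0))               # shard sizes 2, 2, 1
strides = [0, 2, 4]                                       # running row offsets
key = [4, 0, 3]
picked = []
for stride, shard in zip(strides, shards):
    picked.extend(
        (i, shard[k - stride]) for i, k in enumerate(key)
        if stride <= k < stride + shard.shape[0]
    )
gathered = torch.stack([row for _, row in sorted(picked, key=lambda kv: kv[0])])
assert torch.equal(gathered, full[key])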
+ # else: + # return DTensor( + # Shardings( + # data=input.shardings._data, + # shapes=shapes.shapes, + # axis=shapes.axis, + # dtype=input.shardings._dtype, + # device=input.shardings._device, + # ) + # ) + + +def transpose_shape(shape: _ShardingShapes, dim0, dim1): + # transpose shapes + shapes = [] + for s in shape.shapes: + s = list(s) + s[dim0], s[dim1] = s[dim1], s[dim0] + shapes.append(torch.Size(s)) + # transpose axis + axis = shape.axis + if dim0 == axis: + axis = dim1 + elif dim1 == axis: + axis = dim0 + return _ShardingShapes(shapes=shapes, axis=axis) diff --git a/python/fate/arch/tensor/distributed/_ops_agg.py b/python/fate/arch/tensor/distributed/_ops_agg.py new file mode 100644 index 0000000000..32d09093e4 --- /dev/null +++ b/python/fate/arch/tensor/distributed/_ops_agg.py @@ -0,0 +1,313 @@ +from typing import Tuple + +import torch + +from ._tensor import DTensor, implements + + +@implements(torch.sum) +def sum(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + dtype = kwargs.get("dtype", None) + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"sum() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + out = input.shardings.map_reduce_shard(lambda x: torch.sum(x), torch.add) + if dtype: + out = out.type(dtype) + return out + + keepdim = kwargs.get("keepdim", False) + if input.shardings.shapes.axis not in dim: + return DTensor( + input.shardings.map_shard( + lambda x: torch.sum(x, dim=dim, keepdim=keepdim, dtype=dtype), + shapes=input.shardings.shapes.squeeze(dim, keepdim), + ) + ) + + out = input.shardings.map_reduce_shard(lambda x: torch.sum(x, dim=dim, keepdim=keepdim), torch.add) + if dtype: + out = out.type(dtype) + return out + + +@implements(torch.mean) +def mean(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + dtype = kwargs.get("dtype", None) + if dtype is None: + if not input.dtype.is_floating_point: + raise RuntimeError( + f"mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. 
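# Aggregations reduce over the distributed axis by mapping each shard and
# adding the partial results; reducing over any other axis stays sharded.
import torch

full = torch.arange(12.0).reshape(6, 2)
shards = torch.chunk(full, 3, dim=0)                      # distributed axis = 0
assert torch.allclose(sum(torch.sum(s, dim=0) for s in shards), torch.sum(full, dim=0))
still_sharded = [torch.sum(s, dim=1) for s in shards]     # dim=1: shard-local sums
assert torch.equal(torch.cat(still_sharded), torch.sum(full, dim=1))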
Got: {input.dtype}" + ) + dtype = input.dtype + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"mean() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + return torch.div( + input.shardings.map_reduce_shard(lambda x: torch.sum(x, dtype=torch.float64), torch.add), + input.shape.numel(), + ).type(dtype) + + keepdim = kwargs.get("keepdim", False) + count = 1 + for d in dim: + count *= input.shape[d] + if input.shardings.shapes.axis not in dim: + return DTensor( + input.shardings.map_shard( + lambda x: torch.div(torch.sum(x, dim=dim, keepdim=keepdim, dtype=torch.float64), count).type(dtype), + shapes=input.shardings.shapes.squeeze(dim, keepdim), + ) + ) + + return torch.div( + input.shardings.map_reduce_shard( + lambda x: torch.sum(x, dim=dim, keepdim=keepdim, dtype=torch.float64), torch.add + ), + count, + ).type(dtype) + + +@implements(torch.std) +def std(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + dtype = kwargs.get("dtype", None) + if dtype is None: + if not input.dtype.is_floating_point: + raise RuntimeError( + f"std(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: {input.dtype}" + ) + dtype = input.dtype + unbiased = kwargs.get("unbiased", True) + keepdim = kwargs.get("keepdim", False) + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"std() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + + if dim is None or input.shardings.shapes.axis in dim: + if dim is None: + n = input.shape.numel() + sq, s = input.shardings.map_reduce_shard( + mapper_func=lambda x: (torch.sum(torch.square(x)), torch.sum(x)), + reducer_func=lambda a, b: (torch.add(a[0], b[0]), torch.add(a[1], b[1])), + ) + else: + n = 1 + for d in dim: + n *= input.shape[d] + sq, s = input.shardings.map_reduce_shard( + mapper_func=lambda x: ( + torch.sum(torch.square(x), dim=dim, keepdim=keepdim), + torch.sum(x, dim=dim, keepdim=keepdim), + ), + reducer_func=lambda a, b: (torch.add(a[0], b[0]), torch.add(a[1], b[1])), + ) + output = torch.sub(torch.div(sq, n), torch.square(torch.div(s, n))) + if unbiased: + output = torch.mul(output, n / (n - 1)) + output = torch.sqrt(output) + return output + + return DTensor( + input.shardings.map_shard( + lambda x: torch.std(x, dim=dim, unbiased=unbiased), shapes=input.shardings.shapes.squeeze(dim, keepdim) + ) + ) + + +@implements(torch.var) +def var(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + dtype = kwargs.get("dtype", None) + if dtype is None: + if not input.dtype.is_floating_point: + raise RuntimeError( + f"var(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. 
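# std/var over the distributed axis accumulate sum(x^2) and sum(x) per shard
# and combine them as E[x^2] - E[x]^2, with the n/(n-1) unbiased correction:
import torch

x = torch.randn(7, 3, dtype=torch.float64)
shards = torch.chunk(x, 3, dim=0)
sq = sum(torch.sum(torch.square(s)) for s in shards)
s_ = sum(torch.sum(s) for s in shards)
n = x.numel()
var_unbiased = (sq / n - (s_ / n) ** 2) * (n / (n - 1))
assert torch.allclose(torch.sqrt(var_unbiased), torch.std(x))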
Got: {input.dtype}" + ) + dtype = input.dtype + unbiased = kwargs.get("unbiased", True) + keepdim = kwargs.get("keepdim", False) + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"var() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + + if dim is None or input.shardings.shapes.axis in dim: + if dim is None: + n = input.shape.numel() + sq, s = input.shardings.map_reduce_shard( + mapper_func=lambda x: (torch.sum(torch.square(x)), torch.sum(x)), + reducer_func=lambda a, b: (torch.add(a[0], b[0]), torch.add(a[1], b[1])), + ) + else: + n = 1 + for d in dim: + n *= input.shape[d] + sq, s = input.shardings.map_reduce_shard( + mapper_func=lambda x: ( + torch.sum(torch.square(x), dim=dim, keepdim=keepdim), + torch.sum(x, dim=dim, keepdim=keepdim), + ), + reducer_func=lambda a, b: (torch.add(a[0], b[0]), torch.add(a[1], b[1])), + ) + output = torch.sub(torch.div(sq, n), torch.square(torch.div(s, n))) + if unbiased: + output = torch.mul(output, n / (n - 1)) + return output + + return DTensor( + input.shardings.map_shard( + lambda x: torch.var(x, dim=dim, unbiased=unbiased), shapes=input.shardings.shapes.squeeze(dim, keepdim) + ) + ) + + +@implements(torch.min) +def min(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + keepdim = kwargs.get("keepdim", False) + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"min() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + else: + return input.shardings.map_reduce_shard(lambda x: torch.min(x), lambda x, y: torch.minimum(x, y)) + + if input.shardings.shapes.axis == dim: + + def _mapper(stride: int, _: int, x: torch.Tensor): + r = torch.min(x, dim=dim, keepdim=keepdim) + return (stride, torch.return_types.min((r.values, r.indices + stride))) + + def _reducer(kv1: Tuple[int, torch.return_types.min], kv2: Tuple[int, torch.return_types.min]): + s1, r1 = kv1 + s2, r2 = kv2 + s1, r1, s2, r2 = (s1, r1, s2, r2) if s1 < s2 else (s2, r2, s1, r1) + values = torch.minimum(r1.values, r2.values) + indices = torch.add( + torch.mul(r1.indices, torch.le(r1.values, r2.values)), + torch.mul(r2.indices, torch.gt(r1.values, r2.values)), + ) + return (s1, torch.return_types.min((values, indices))) + + return input.shardings.map_reduce_shard_with_stride(_mapper, _reducer)[1] + + values = DTensor( + input.shardings.map_shard( + lambda x: torch.min(x, dim=dim, keepdim=keepdim).values, + shapes=input.shardings.shapes.squeeze((dim,), keepdim=keepdim), + ) + ) + indices = DTensor( + input.shardings.map_shard( + lambda x: torch.min(x, dim=dim, keepdim=keepdim).indices, + shapes=input.shardings.shapes.squeeze((dim,), keepdim=keepdim), + ) + ) + return torch.return_types.min((values, indices)) + + +@implements(torch.max) +def max(input: DTensor, *args, **kwargs): + dim = None + if len(args) > 0: + dim = args[0] + if "dim" in kwargs: + dim = kwargs["dim"] + keepdim = kwargs.get("keepdim", False) + if dim is None: + if "keepdim" in kwargs: + raise TypeError( + f"max() received an invalid combination of arguments - got (Tensor, keepdim=bool), but expected one of\n" + "* (Tensor input)\n" + "* (Tensor input, tuple of 
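# min/max along the distributed axis must also return global indices, so each
# shard reports (values, local_indices + stride) and pairs are combined shard
# by shard. A local sketch of that combination for torch.min:
import torch

full = torch.tensor([[3.0, 9.0], [1.0, 4.0], [7.0, 0.0], [5.0, 2.0]])
shards = torch.chunk(full, 2, dim=0)                      # strides 0 and 2
best_vals, best_idx = None, None
for stride, shard in zip((0, 2), shards):
    r = torch.min(shard, dim=0)
    vals, idx = r.values, r.indices + stride              # lift to global indices
    if best_vals is None:
        best_vals, best_idx = vals, idx
    else:
        keep_old = best_vals <= vals
        best_idx = torch.where(keep_old, best_idx, idx)
        best_vals = torch.minimum(best_vals, vals)
expected = torch.min(full, dim=0)
assert torch.equal(best_vals, expected.values)
assert torch.equal(best_idx, expected.indices)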
ints dim, bool keepdim)\n" + "* (Tensor input, tuple of names dim, bool keepdim)" + ) + else: + return input.shardings.map_reduce_shard(lambda x: torch.max(x), lambda x, y: torch.minimum(x, y)) + + if input.shardings.shapes.axis == dim: + + def _mapper(stride: int, _: int, x: torch.Tensor): + r = torch.max(x, dim=dim, keepdim=keepdim) + return (stride, torch.return_types.max((r.values, r.indices + stride))) + + def _reducer(kv1: Tuple[int, torch.return_types.max], kv2: Tuple[int, torch.return_types.max]): + s1, r1 = kv1 + s2, r2 = kv2 + s1, r1, s2, r2 = (s1, r1, s2, r2) if s1 < s2 else (s2, r2, s1, r1) + values = torch.minimum(r1.values, r2.values) + indices = torch.add( + torch.mul(r1.indices, torch.ge(r1.values, r2.values)), + torch.mul(r2.indices, torch.lt(r1.values, r2.values)), + ) + return (s1, torch.return_types.max((values, indices))) + + return input.shardings.map_reduce_shard_with_stride(_mapper, _reducer)[1] + + values = DTensor( + input.shardings.map_shard( + lambda x: torch.max(x, dim=dim, keepdim=keepdim).values, + shapes=input.shardings.shapes.squeeze((dim,), keepdim=keepdim), + ) + ) + indices = DTensor( + input.shardings.map_shard( + lambda x: torch.max(x, dim=dim, keepdim=keepdim).indices, + shapes=input.shardings.shapes.squeeze((dim,), keepdim=keepdim), + ) + ) + return torch.return_types.max((values, indices)) diff --git a/python/fate/arch/tensor/distributed/_ops_binary.py b/python/fate/arch/tensor/distributed/_ops_binary.py new file mode 100644 index 0000000000..fe7b143d53 --- /dev/null +++ b/python/fate/arch/tensor/distributed/_ops_binary.py @@ -0,0 +1,62 @@ +import torch + +from ._tensor import DTensor, implements + + +@implements(torch.add) +def add(input, other): + return _binary(input, other, torch.add) + + +@implements(torch.sub) +def sub(input, other): + return _binary(input, other, torch.sub) + + +@implements(torch.rsub) +def rsub(input, other): + return _binary(input, other, torch.rsub) + + +@implements(torch.mul) +def mul(input, other): + return _binary(input, other, torch.mul) + + +@implements(torch.div) +def div(input, other): + return _binary(input, other, torch.div, dtype_promote_to=torch.float32) + + +def _binary(input, other, op, swap_operad=False, dtype_promote_to=None): + # swap input and output if input is not DTensor + if not isinstance(input, DTensor): + return _binary(other, input, op, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to) + + if isinstance(other, DTensor): + if swap_operad: + return DTensor(other.shardings.join_shard(input.shardings, op, dtype_promote_to=dtype_promote_to)) + else: + return DTensor(input.shardings.join_shard(other.shardings, op, dtype_promote_to=dtype_promote_to)) + + # other is local tensor, broadcast to partitions + else: + if isinstance(other, torch.Tensor): + shapes = input.shardings.shapes.bc_shapes(other.shape) + else: + # other is scalar + shapes = input.shardings.shapes.bc_shapes(torch.Size([])) + + if swap_operad: + return DTensor( + input.shardings.map_shard( + lambda x: op(other, x), dtype_promote_to=dtype_promote_to, shapes=shapes.shapes, axis=shapes.axis + ) + ) + + else: + return DTensor( + input.shardings.map_shard( + lambda x: op(x, other), dtype_promote_to=dtype_promote_to, shapes=shapes.shapes, axis=shapes.axis + ) + ) diff --git a/python/fate/arch/tensor/distributed/_ops_cipher.py b/python/fate/arch/tensor/distributed/_ops_cipher.py new file mode 100644 index 0000000000..64d505dea2 --- /dev/null +++ b/python/fate/arch/tensor/distributed/_ops_cipher.py @@ -0,0 +1,33 @@ +from 
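# Binary ops with a local tensor (or scalar) broadcast it onto every shard;
# only the per-shard shapes change while the distributed axis is preserved:
import torch

full = torch.arange(6.0).reshape(3, 2)
shards = torch.chunk(full, 3, dim=0)
row = torch.tensor([10.0, 20.0])                          # broadcast across the sharded dim
out_shards = [torch.add(s, row) for s in shards]          # what map_shard applies per shard
assert torch.equal(torch.cat(out_shards), torch.add(full, row))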
fate.arch.tensor import _custom_ops + +from ._tensor import DTensor, implements + + +@implements(_custom_ops.encrypt_encoded_f) +def encrypt_encoded_f(input: DTensor, encryptor): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.encrypt_encoded_f(x, encryptor), type="encrypted")) + + +@implements(_custom_ops.decrypt_encoded_f) +def decrypt_encoded_f(input: DTensor, decryptor): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.decrypt_encoded_f(x, decryptor), type="encoded")) + + +@implements(_custom_ops.encrypt_f) +def encrypt_f(input: DTensor, encryptor): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.encrypt_f(x, encryptor), type="encrypted")) + + +@implements(_custom_ops.decrypt_f) +def decrypt_f(input: DTensor, decryptor): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.decrypt_f(x, decryptor), type="plain")) + + +@implements(_custom_ops.decode_f) +def decode_f(input: DTensor, coder): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.decode_f(x, coder), type="plain")) + + +@implements(_custom_ops.encode_f) +def encode_f(input: DTensor, coder): + return DTensor(input.shardings.map_shard(lambda x: _custom_ops.encode_f(x, coder), type="encoded")) diff --git a/python/fate/arch/tensor/distributed/_ops_others.py b/python/fate/arch/tensor/distributed/_ops_others.py new file mode 100644 index 0000000000..b437e730fc --- /dev/null +++ b/python/fate/arch/tensor/distributed/_ops_others.py @@ -0,0 +1,14 @@ +import torch + +from fate.arch.tensor import _custom_ops +from ._tensor import DTensor, implements + + +@implements(_custom_ops.to_local_f) +def to_local_f(input: DTensor): + return input.shardings.merge() + + +@implements(_custom_ops.encode_as_int_f) +def encode_as_int_f(input: DTensor, precision): + return DTensor(input.shardings.map_shard(lambda x: (x * 2 ** precision).type(torch.int64), dtype=torch.int64)) diff --git a/python/fate/arch/tensor/distributed/_ops_unary.py b/python/fate/arch/tensor/distributed/_ops_unary.py new file mode 100644 index 0000000000..fa71c4b68d --- /dev/null +++ b/python/fate/arch/tensor/distributed/_ops_unary.py @@ -0,0 +1,23 @@ +import torch + +from ._tensor import DTensor, implements + + +@implements(torch.exp) +def exp(input: DTensor): + return DTensor(input.shardings.map_shard(torch.exp, dtype_promote_to=torch.float32)) + + +@implements(torch.log) +def log(input: DTensor): + return DTensor(input.shardings.map_shard(torch.log, dtype_promote_to=torch.float32)) + + +@implements(torch.square) +def square(input: DTensor): + return DTensor(input.shardings.map_shard(torch.square)) + + +@implements(torch.sigmoid) +def sigmoid(input: DTensor): + return DTensor(input.shardings.map_shard(torch.sigmoid)) diff --git a/python/fate/arch/tensor/distributed/_tensor.py b/python/fate/arch/tensor/distributed/_tensor.py new file mode 100644 index 0000000000..88d558cb9c --- /dev/null +++ b/python/fate/arch/tensor/distributed/_tensor.py @@ -0,0 +1,400 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
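# encode_as_int is plain fixed-point encoding: scale by 2**precision and cast
# to int64; dividing the integers back recovers the floats (up to precision):
import torch

x = torch.tensor([0.5, -1.25])
precision = 16
encoded = (x * 2 ** precision).type(torch.int64)
decoded = encoded.to(torch.float64) / 2 ** precision
assert torch.allclose(decoded, x.to(torch.float64))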
+# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import typing +from typing import List, Optional, Tuple, TypeVar, cast + +import torch +from fate.arch.abc import CTableABC +from fate.arch.context import Context + +_HANDLED_FUNCTIONS = {} + + +def implements(torch_function): + """Register a torch function override for DStorage""" + + @functools.wraps(torch_function) + def decorator(func): + _HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +class DTensor: + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in _HANDLED_FUNCTIONS or not all(issubclass(t, (torch.Tensor, DTensor)) for t in types): + return NotImplemented + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + + @property + def T(self): + return torch.transpose(self, 0, 1) + + def elem_type(self) -> Optional[str]: + return self.shardings._type + + def __init__(self, shardings: "Shardings") -> None: + self.shardings = shardings + + def __add__(self, other): + try: + return torch.add(self, other) + except Exception as e: + raise RuntimeError(f"Failed to add {self} and {other}") from e + + def __radd__(self, other): + return torch.add(other, self) + + def __sub__(self, other): + return torch.sub(self, other) + + def __rsub__(self, other): + return torch.rsub(self, other) + + def __mul__(self, other): + return torch.mul(self, other) + + def __rmul__(self, other): + return torch.mul(other, self) + + def __truediv__(self, other): + return torch.div(self, other) + + def __rtruediv__(self, other): + return torch.div(other, self) + + def __matmul__(self, other): + return torch.matmul(self, other) + + def __rmatmul__(self, other): + return torch.matmul(other, self) + + def encrypt(self, encryptor): + return torch.encrypt_f(self, encryptor) + + def encrypt_encoded(self, encryptor): + return torch.encrypt_encoded_f(self, encryptor) + + def decrypt_encoded(self, decryptor): + return torch.decrypt_encoded_f(self, decryptor) + + def encode(self, encoder): + return torch.encode_f(self, encoder) + + def decode(self, decoder): + return torch.decode_f(self, decoder) + + def decrypt(self, decryptor): + return torch.decrypt_f(self, decryptor) + + def exp(self): + return torch.exp(self) + + def log(self): + return torch.log(self) + + def square(self): + return torch.square(self) + + def sigmoid(self): + return torch.sigmoid(self) + + @property + def shape(self): + return self.shardings.shape + + @property + def dtype(self): + return self.shardings.dtype + + @property + def device(self): + return self.shardings.device + + def __eq__(self, __o: object) -> bool: + return isinstance(__o, DTensor) and self.shardings == __o.shardings + + def __str__(self) -> str: + return f"" + + @classmethod + def from_sharding_table( + cls, + data: CTableABC, + shapes: Optional[List[torch.Size]], + axis=0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + return DTensor(Shardings(data, shapes, axis, dtype, device)) + + @classmethod + def from_sharding_list(cls, ctx: Context, data: List[torch.Tensor], num_partitions=16, axis=0): + dtype = data[0].dtype + device = data[0].device + shapes = [] + for t in data: + shapes.append(t.shape) + assert dtype == t.dtype + assert device == t.device + + for shape in shapes[1:]: + for i, (s1, s2) in enumerate(zip(shapes[0], shape)): + if i == axis: + continue + else: + assert s1 == s2 + return cls.from_sharding_table( + 
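# The handler registry above follows the standard __torch_function__ protocol:
# `implements` maps a torch function to a DTensor-aware override, and torch
# dispatches to it whenever an argument defines __torch_function__. A toy
# version with a hypothetical `MyTensor` wrapper, for illustration only:
import torch

HANDLERS = {}


def implements(torch_function):
    def decorator(func):
        HANDLERS[torch_function] = func
        return func

    return decorator


class MyTensor:
    def __init__(self, t: torch.Tensor) -> None:
        self.t = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLERS:
            return NotImplemented
        return HANDLERS[func](*args, **kwargs)


@implements(torch.add)
def _add(a: "MyTensor", b):
    return MyTensor(torch.add(a.t, b))


out = torch.add(MyTensor(torch.ones(2)), 1.0)             # routed through _add
assert isinstance(out, MyTensor)
assert torch.equal(out.t, torch.full((2,), 2.0))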
ctx.computing.parallelize(data, partition=num_partitions, include_key=False), shapes, axis, dtype, device + ) + + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class Shardings: + def __init__( + self, + data: CTableABC[int, torch.Tensor], + shapes: Optional[List[torch.Size]] = None, + axis: int = 0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + type: Optional[str] = None, + ): + self._data = data + self._type = type + + if shapes is None: + shards_shape = sorted(self._data.map(lambda k, s: (k, s.shape)).collect()) + _shapes = [] + for i, (k, s) in enumerate(shards_shape): + assert i == k + _shapes.append(s) + else: + _shapes = shapes + self._shapes = _ShardingShapes(_shapes, axis) + + if dtype is None or device is None: + first_shard = self._data.first()[1] + shard_dtype = cast(torch.dtype, first_shard.dtype) + shard_device = cast(torch.device, first_shard.device) + if dtype is not None: + assert dtype == shard_dtype + if device is not None: + assert device == shard_device + self._dtype = shard_dtype + self._device = shard_device + else: + self._dtype = dtype + self._device = device + + @property + def shapes(self): + return self._shapes + + @property + def dtype(self): + return self._dtype + + @property + def shape(self): + return self.shapes.merge_shapes() + + def with_dtype(self, dtype: torch.dtype): + self._dtype = dtype + return self + + @property + def device(self): + return self._device + + def __eq__(self, __o: object) -> bool: + return ( + isinstance(__o, Shardings) + and self.device == __o.device + and self.dtype == __o.dtype + and self.shapes == __o.shapes + and all(self._data.join(__o._data, lambda s1, s2: torch.allclose(s1, s2)).collect()) + ) + + def __str__(self) -> str: + return f"Sharding" + + def merge(self): + shardings = [pair[1] for pair in sorted(self._data.collect())] + return torch.cat(shardings, self.shapes.axis) + + def map_shard( + self, + func: typing.Callable[[torch.Tensor], torch.Tensor], + shapes: Optional[List[torch.Size]] = None, + axis: Optional[int] = None, + dtype_promote_to: Optional[torch.dtype] = None, + type: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + if dtype is None: + if dtype_promote_to is not None: + dtype = torch.promote_types(self.dtype, dtype_promote_to) + else: + dtype = self._dtype + if shapes is None: + shapes = self.shapes.shapes + if axis is None: + axis = self.shapes.axis + if type is None: + type = self._type + return Shardings(self._data.mapValues(func), shapes, axis, dtype, self._device, type) + + def map_reduce_shard( + self, + mapper_func: typing.Callable[[torch.Tensor], T1], + reducer_func: typing.Callable[[T1, T1], T1], + ) -> T1: + """ + expect local output + """ + return self._data.mapValues(mapper_func).reduce(reducer_func) + + def map_reduce_shard_with_stride( + self, + stride_mapper_func: typing.Callable[[int, int, torch.Tensor], T1], + reducer_func: typing.Callable[[T1, T1], T1], + ) -> T1: + """ + map with stride + """ + strides = self.shapes.strides() + axis = self.shapes.axis + + def _stride_mapper(func: typing.Callable[[int, int, torch.Tensor], T1]): + def _wrap(i: int, t: torch.Tensor) -> Tuple[int, T1]: + stride = strides[i] + size = t.shape[axis] + return (i, func(stride, size, t)) + + return _wrap + + return self._data.map(_stride_mapper(stride_mapper_func)).reduce(reducer_func) + + def join_shard( + self, + other: "Shardings", + func: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + out_dtype: typing.Optional[torch.dtype] = None, + 
out_shapes: typing.Optional[List[torch.Size]] = None, + out_axis: typing.Optional[int] = None, + dtype_promote_to: Optional[torch.dtype] = None, + ): + if out_dtype is None: + out_dtype = torch.promote_types(self._dtype, other._dtype) + if dtype_promote_to is not None: + out_dtype = torch.promote_types(out_dtype, dtype_promote_to) + if out_shapes is None or out_axis is None: + shapes = self.shapes.bc_shapes(other.shapes) + out_shapes = shapes.shapes + out_axis = shapes.axis + return Shardings(self._data.join(other._data, func), out_shapes, out_axis, out_dtype, self._device) + + def join_reduce_shard( + self, + other: "Shardings", + mapper_func: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + reduce_func: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + ): + return self._data.join(other._data, mapper_func).reduce(reduce_func) + + +class _ShardingShapes: + def __init__(self, shapes: List[torch.Size], axis: int) -> None: + self.shapes = shapes + self.axis = axis + + def __eq__(self, __o: object) -> bool: + if isinstance(__o, _ShardingShapes) and self.axis == __o.axis and len(self.shapes) == len(__o.shapes): + for s1, s2 in zip(self.shapes, __o.shapes): + if s1 != s2: + return False + return True + + def __str__(self) -> str: + return f"" + + def __repr__(self): + return self.__str__() + + def bc_shapes(self, other: "_ShardingShapes") -> "_ShardingShapes": + if isinstance(other, _ShardingShapes): + assert len(self.shapes) == len(other.shapes), f"sharding num mismatch: {self.shapes} vs {other.shapes}" + _bc_shapes = [] + for s1, s2 in zip(self.shapes, other.shapes): + _bc_shapes.append(torch.broadcast_shapes(s1, s2)) + + self_axis = len(_bc_shapes[0]) - len(self.shapes[0]) + self.axis + other_axis = len(_bc_shapes[0]) - len(other.shapes[0]) + other.axis + assert self_axis == other_axis, f"sharding axis mismatch: {self_axis} vs {other_axis}" + return _ShardingShapes(_bc_shapes, self_axis) + elif isinstance(other, torch.Size): + _bc_shapes = [] + for s in self.shapes: + _bc_shapes.append(torch.broadcast_shapes(s, other)) + # assert other shape in distributed axis is 1 + other_align_axis = len(other) - len(s) + self.axis + if other_align_axis >= 0: + assert other[other_align_axis] == 1, f"shape in distributed axis should be 1: {self} vs {other}" + self_axis = len(_bc_shapes[0]) - len(self.shapes[0]) + self.axis + + return _ShardingShapes(_bc_shapes, self_axis) + else: + raise NotImplementedError(f"type `{other}`") + + def merge_shapes(self): + _shape = list(self.shapes[0]) + for s in self.shapes[1:]: + for i in range(len(_shape)): + if i == self.axis: + _shape[i] += s[i] + else: + assert _shape[i] == s[i] + return torch.Size(_shape) + + def strides(self): + _stride = [0] + agg = 0 + for s in self.shapes[:-1]: + agg += s[self.axis] + _stride.append(agg) + return _stride + + def squeeze(self, dims: Tuple[int], keepdim=False): + _shapes = [] + for s in self.shapes: + _s = [] + for i in range(len(s)): + if i in dims: + if keepdim: + _s.append(1) + else: + _s.append(s[i]) + _shapes.append(torch.Size(_s)) + return _shapes diff --git a/python/fate/arch/tensor/inside/__init__.py b/python/fate/arch/tensor/inside/__init__.py new file mode 100644 index 0000000000..fdbbfc3879 --- /dev/null +++ b/python/fate/arch/tensor/inside/__init__.py @@ -0,0 +1 @@ +from ._op_quantile import GKSummary diff --git a/python/fate/arch/tensor/inside/_op_quantile.py b/python/fate/arch/tensor/inside/_op_quantile.py new file mode 100644 index 0000000000..9c975c04c0 --- /dev/null +++ 
b/python/fate/arch/tensor/inside/_op_quantile.py @@ -0,0 +1,66 @@ +from typing import List, Union + +import numpy +import torch +from fate_utils import quantile + + +def quantile_fi(input: torch.Tensor, q, epsilon): + if input.dtype == torch.float64: + if len(input.shape) == 1: + return quantile.quantile_f64_ix1(input.numpy(), q, epsilon) + elif len(input.shape) == 2: + return quantile.quantile_f64_ix2(input.numpy(), q, epsilon) + raise NotImplementedError() + + +class GKSummary: + """ + GKSummary is a summary of a stream of numbers, which can be used to estimate quantiles. + + Examples: + >>> summary = GKSummary(0.001) + >>> summary += torch.tensor([1.0, 2.0, 3.0]) + >>> summary += torch.tensor([4.0, 5.0, 6.0]) + >>> summary2 = GKSummary(0.001) + >>> summary2 += torch.tensor([7.0, 8.0, 9.0, 10.0]) + >>> summary = summary + summary2 + >>> summary.queries([0.1, 0.2, 0.7, 0.8]) + [1.0, 2.0, 7.0, 8.0] + """ + + def __init__(self, epsilon: float) -> None: + self._epsilon = epsilon + self._summary = None + + def _get_summary(self): + if self._summary is None: + self._summary = quantile.QuantileSummaryStream(self._epsilon) + return self._summary + + def merge(self, other: "GKSummary"): + """merge other summary into self.""" + gk = GKSummary(self._epsilon) + gk._summary = self._get_summary().merge(other._get_summary()) + return gk + + def push(self, array: Union[torch.Tensor, numpy.ndarray]): + """push elements in array into summary.""" + if isinstance(array, torch.Tensor): + array = array.numpy() + self._get_summary().insert_array(array.astype(numpy.float64)) + return self + + def __add__(self, other: "GKSummary"): + if isinstance(other, GKSummary): + return self.merge(other) + return NotImplemented + + def __iadd__(self, other: Union[torch.Tensor, numpy.ndarray]): + if isinstance(other, torch.Tensor) or isinstance(other, numpy.ndarray): + return self.push(other) + return NotImplemented + + def queries(self, q: List[float]) -> List[float]: + """return quantile values of q.""" + return self._get_summary().queries(q) diff --git a/python/fate/arch/tensor/ops/__init__.py b/python/fate/arch/tensor/ops/__init__.py deleted file mode 100644 index aa8e04f264..0000000000 --- a/python/fate/arch/tensor/ops/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ._agg_ops import * -from ._binary_ops import * -from ._matmul_ops import * -from ._ops import * -from ._slice_ops import * -from ._unary_ops import * diff --git a/python/fate/arch/tensor/ops/_agg_ops.py b/python/fate/arch/tensor/ops/_agg_ops.py deleted file mode 100644 index 1a32c29a60..0000000000 --- a/python/fate/arch/tensor/ops/_agg_ops.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
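# GKSummary returns epsilon-approximate quantiles of everything pushed into it
# (it needs the fate_utils Rust extension). For reference, the exact rank-based
# quantiles it approximates can be reproduced with recent numpy:
import numpy as np

data = np.concatenate([np.arange(1.0, 7.0), np.arange(7.0, 11.0)])  # the docstring's two streams
exact = np.quantile(data, [0.1, 0.2, 0.7, 0.8], method="lower")
assert exact.tolist() == [1.0, 2.0, 7.0, 8.0]                        # matches the docstring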
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import overload - -from .._tensor import Tensor - -# TODO: parameter `keepdim` maybe a bit complex in distributed version, fix me later - - -@overload -def sum(a: Tensor, *, dtype=None) -> Tensor: - ... - - -@overload -def sum(a: Tensor, dim, keepdim=False, *, dtype=None) -> Tensor: - ... - - -def sum(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "sum"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"sum not impl for tensor `{a}` with storage `{storage}`") - - -def mean(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "mean"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"mean not impl for tensor `{a}` with storage `{storage}`") - - -def std(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "std"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"std not impl for tensor `{a}` with storage `{storage}`") - - -def var(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "var"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"var not impl for tensor `{a}` with storage `{storage}`") - - -def max(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "max"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"max not impl for tensor `{a}` with storage `{storage}`") - - -def min(a: Tensor, *args, **kwargs): - storage = a.storage - if func := getattr(storage, "min"): - return Tensor(func(*args, **kwargs)) - else: - raise NotImplementedError(f"min not impl for tensor `{a}` with storage `{storage}`") diff --git a/python/fate/arch/tensor/ops/_matmul_ops.py b/python/fate/arch/tensor/ops/_matmul_ops.py deleted file mode 100644 index 0fa8fa3a9b..0000000000 --- a/python/fate/arch/tensor/ops/_matmul_ops.py +++ /dev/null @@ -1,95 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from .._tensor import DStorage, Tensor -from ..types import DAxis, Shape -from ._ops import _get_dispatch_info, dispatch_signature2 - - -def matmul(a: Tensor, b: Tensor) -> Tensor: - """ - If both arguments are 2-D they are multiplied like conventional matrices. - If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly. - If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. 
- If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. - """ - _is_distributed, _device, _dtype = _get_dispatch_info([a, b]) - - # both local - from ..storage._helper import local_ops_helper - - local_ops = local_ops_helper(_device, _dtype) - - if not _is_distributed: - storage = local_ops.matmul(a.storage, b.storage) - return Tensor(storage) - - bc_shape_a = a.shape[:-2] - bc_shape_b = b.shape[:-2] - bs_shape = Shape.broadcast_shape([bc_shape_a, bc_shape_b], raise_exception=False) - if bs_shape is None: - raise ValueError("matmul: shape broadcast failed") - - if bc_shape_a.d_axis is not None: - # distributed along bc part: (...,d,...,m, k) x (...,d,...,k, n) -> (...,d,..., m, n) - # join and matmul - return dispatch_signature2("matmul", a, b, [], {}, bc_shape_validate=False) - - mul_shape_a = a.shape[-2:] - mul_shape_b = b.shape[-2:] - if mul_shape_a.size[-1] != mul_shape_b.size[0]: - raise ValueError("matmul: dimension mismatch: should be (..., n) x (...,n,?)") - - if mul_shape_a.is_d_axis(-2) and mul_shape_b.is_d_axis(-1): - raise ValueError( - f"not supported distributed axis position (...,d,?) for left tensor {a} and distributed axis position (...,?,d) for right tensor {b}" - ) - - if mul_shape_a.is_d_axis(-2) and mul_shape_b.d_axis is None: - shape = Shape( - size=[*bs_shape.size, mul_shape_a.size[0], mul_shape_b.size[-1]], - d_axis=DAxis(len(bs_shape.size) + mul_shape_a.d_axis.axis, mul_shape_a.d_axis.partitions), - ) - out_storage = DStorage.elemwise_bc_op(a.storage, b.storage, lambda l, r: local_ops.matmul(l, r), shape=shape) - elif mul_shape_b.is_d_axis(-1) and mul_shape_a.d_axis is None: - shape = ( - Shape( - size=[*bs_shape.size, mul_shape_a.size[0], mul_shape_b.size[-1]], - d_axis=DAxis(len(bs_shape.size) + mul_shape_b.d_axis.axis, mul_shape_b.d_axis.partitions), - ), - ) - out_storage = DStorage.elemwise_bc_op( - a.storage, b.storage, lambda l, r: local_ops.matmul(l, r), shape=bs_shape - ) - else: - out_storage = a.storage.blocks.join( - b.storage.blocks, - apply_transpose( - local_ops.matmul, - a.storage.transposed, - b.storage.transposed, - ), - ).reduce(local_ops.add) - return Tensor(out_storage) - - -def apply_transpose(func, lf, rf): - def _wrap(a, b): - if lf: - a = a.transpose() - if rf: - b = b.transpose() - return func(a, b) - - return _wrap diff --git a/python/fate/arch/tensor/ops/_ops.py b/python/fate/arch/tensor/ops/_ops.py deleted file mode 100644 index e70a4ee72e..0000000000 --- a/python/fate/arch/tensor/ops/_ops.py +++ /dev/null @@ -1,103 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from functools import wraps - -from .._exception import OpDispatchInvalidDevice, OpsDispatchBadSignatureError -from .._tensor import Tensor -from ..types import Shape - - -def auto_unary_op(func): - method = func.__name__ - - @wraps(func) - def wrapper(x, *args, **kwargs): - return dispatch_signature1(method, x, args, kwargs) - - return wrapper - - -def auto_binary_op(func): - method = func.__name__ - - @wraps(func) - def wrapper(x, y, *args, **kwargs): - return dispatch_signature2(method, x, y, args, kwargs) - - return wrapper - - -def _maybe_get_storage(tensor): - if isinstance(tensor, Tensor): - return tensor.storage - else: - return tensor - - -def _get_dispatch_info(tensors): - _is_distributed = False - _device = None - _dtype = None - for tensor in tensors: - if isinstance(tensor, Tensor): - # set distributed or local - _is_distributed = _is_distributed or tensor.is_distributed - - # set device - if _device is None: - _device = tensor.device - elif _device != tensor.device: - raise OpDispatchInvalidDevice(f"device mismatch: {_device} and {tensor.device}") - - # set dtypes - if _dtype is None: - _dtype = tensor.dtype - else: - _dtype = _dtype.type_promoted(tensor.dtype) - return _is_distributed, _device, _dtype - - -def dispatch_signature1(method, tensor, args, kwargs): - if not isinstance(tensor, Tensor): - raise OpsDispatchBadSignatureError(f"required exactly one tensor input, got {tensor}") - from ..storage._ops import _ops_dispatch_signature1_unknown_unknown_unknown - - storage_op = _ops_dispatch_signature1_unknown_unknown_unknown( - method=method, - distributed=tensor.is_distributed, - device=tensor.device, - dtype=tensor.dtype, - args=args, - kwargs=kwargs, - ) - storage = storage_op(_maybe_get_storage(tensor)) - return Tensor(storage) - - -def dispatch_signature2(method, tensor, other, args, kwargs, bc_shape_validate=True): - if not isinstance(tensor, Tensor) and not isinstance(other, Tensor): - raise OpsDispatchBadSignatureError(f"atleast one tensor input, got {tensor} and {other}") - from ..storage._ops import _ops_dispatch_signature2_unknown_unknown_unknown - - if bc_shape_validate: - if isinstance(tensor, Tensor) and isinstance(other, Tensor): - if Shape.broadcast_shape([tensor.shape, other.shape], raise_exception=False) is None: - raise RuntimeError(f"shape broadcast failed: {tensor.shape} and {other.shape}") - _is_distributed, _device, _dtype = _get_dispatch_info([tensor, other]) - storage_op = _ops_dispatch_signature2_unknown_unknown_unknown( - method, _is_distributed, _device, _dtype, args, kwargs - ) - storage = storage_op(_maybe_get_storage(tensor), _maybe_get_storage(other)) - return Tensor(storage) diff --git a/python/fate/arch/tensor/ops/_slice_ops.py b/python/fate/arch/tensor/ops/_slice_ops.py deleted file mode 100644 index 03f5b4ba39..0000000000 --- a/python/fate/arch/tensor/ops/_slice_ops.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from .._tensor import Tensor -from ..types import DAxis, DStorage, Shape -from ._ops import _get_dispatch_info - - -def slice(a: Tensor, key) -> Tensor: - _is_distributed, _device, _dtype = _get_dispatch_info([a]) - from ..storage._helper import local_ops_helper - - local_ops = local_ops_helper(_device, _dtype) - if not _is_distributed: - output_storage = local_ops.slice(a.storage, key) - else: - storage = a.storage - assert isinstance(storage, DStorage), "" - - if isinstance(key, list): - partition_keys = [[] for _ in storage.d_axis.partitions] - agg = 0 - i = 0 - j = 0 - while j < len(key) and i < len(storage.d_axis.partitions): - if key[j] >= agg and key[j] < agg + storage.d_axis.partitions[i]: - partition_keys[i].append(key[j] - agg) - j += 1 - else: - agg += storage.d_axis.partitions[i] - i += 1 - if j != len(key): - raise ValueError(f"out of bound: {key}") - - def mapper(ind, s): - return (ind, local_ops.slice(s, partition_keys[ind])) - - blocks = storage.blocks.map(mapper) - size = (len(key), *storage.shape.size[1:]) - d_axis = DAxis(axis=storage.d_axis.axis, partitions=[len(p) for p in partition_keys]) - - output_storage = DStorage( - blocks, - shape=Shape(size, d_axis), - dtype=storage.dtype, - device=storage.device, - transposed=storage.transposed, - ) - else: - raise NotImplementedError(f"key {key}") - - return Tensor(output_storage) diff --git a/python/fate/arch/tensor/ops/_unary_ops.py b/python/fate/arch/tensor/ops/_unary_ops.py deleted file mode 100644 index 76b40d5eb0..0000000000 --- a/python/fate/arch/tensor/ops/_unary_ops.py +++ /dev/null @@ -1,183 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ._ops import auto_unary_op - - -@auto_unary_op -def abs(x, *args, **kwargs): - "arc cosine" - ... - - -@auto_unary_op -def asin(x, *args, **kwargs): - "arc sin" - ... - - -@auto_unary_op -def atan(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def atan2(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def ceil(x, *args, **kwargs): - "ceiling" - ... - - -@auto_unary_op -def cos(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def cosh(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def erf(x, *args, **kwargs): - "Gaussian error functiom" - ... - - -@auto_unary_op -def erfinv(x, *args, **kwargs): - "Gaussian error functiom" - ... - - -@auto_unary_op -def exp(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def expm1(x, *args, **kwargs): - "exponential of each element minus 1" - ... - - -@auto_unary_op -def floor(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def frac(x, *args, **kwargs): - "fraction part 3.4 -> 0.4" - ... - - -@auto_unary_op -def log(x, *args, **kwargs): - "natural log" - ... - - -@auto_unary_op -def log1p(x, *args, **kwargs): - "y = log(1 + x)" - ... - - -@auto_unary_op -def neg(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def reciprocal(x, *args, **kwargs): - "1/x" - ... 
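The removed slice op walks a sorted list of global row indices and assigns each one to the partition that owns it, converting global indices to per-partition local indices. The same loop, extracted as a standalone function with a small worked example:

```python
# Standalone sketch of the index-mapping loop from the removed slice op:
# map sorted global indices onto per-partition local indices.
def map_to_partitions(key, partitions):
    partition_keys = [[] for _ in partitions]
    agg, i, j = 0, 0, 0
    while j < len(key) and i < len(partitions):
        if agg <= key[j] < agg + partitions[i]:
            partition_keys[i].append(key[j] - agg)
            j += 1
        else:
            agg += partitions[i]
            i += 1
    if j != len(key):
        raise ValueError(f"out of bound: {key}")
    return partition_keys

# rows 0-2 live in partition 0, rows 3-6 in partition 1, rows 7-8 in partition 2
print(map_to_partitions([1, 4, 5, 8], [3, 4, 2]))  # [[1], [1, 2], [1]]
```

As in the original, the indices are assumed to be sorted ascending; an index past the last partition raises the same out-of-bound error.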
- - -@auto_unary_op -def sigmoid(x, *args, **kwargs): - "sigmode(x)" - ... - - -@auto_unary_op -def sign(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def sin(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def sinh(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def sqrt(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def square(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def tan(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def tanh(x, *args, **kwargs): - """""" - ... - - -@auto_unary_op -def trunc(x, *args, **kwargs): - "truncated integer" - ... - - -@auto_unary_op -def rsqrt(x, *args, **kwargs): - "the reciprocal of the square-root" - ... - - -@auto_unary_op -def round(x, *args, **kwargs): - """""" - ... diff --git a/python/fate/arch/tensor/phe/__init__.py b/python/fate/arch/tensor/phe/__init__.py new file mode 100644 index 0000000000..db2ae2f56f --- /dev/null +++ b/python/fate/arch/tensor/phe/__init__.py @@ -0,0 +1,5 @@ +from ._keypair import PHETensorCipher +from ._ops import * +from ._tensor import PHETensor + +__all__ = ["PHETensor", "PHETensorCipher"] diff --git a/python/fate/arch/tensor/phe/_keypair.py b/python/fate/arch/tensor/phe/_keypair.py new file mode 100644 index 0000000000..339f1ff2b8 --- /dev/null +++ b/python/fate/arch/tensor/phe/_keypair.py @@ -0,0 +1,135 @@ +import typing + +import torch + +if typing.TYPE_CHECKING: + from fate.arch.protocol.phe.paillier import PK, SK, Coder + + from ._tensor import PHETensor, PHETensorEncoded + + +class PHETensorCipher: + def __init__(self, pk: "PHETensorEncryptor", coder: "PHETensorCoder", sk: "PHETensorDecryptor", evaluator) -> None: + self._pk = pk + self._coder = coder + self._sk = sk + self._evaluator = evaluator + + @classmethod + def from_raw_cipher(cls, pk: "PK", coder: "Coder", sk: "SK", evaluator): + coder = PHETensorCoder(coder) + encryptor = PHETensorEncryptor(pk, coder, evaluator) + decryptor = PHETensorDecryptor(sk, coder) + return cls(encryptor, coder, decryptor, evaluator) + + @property + def pk(self): + return self._pk + + @property + def coder(self): + return self._coder + + @property + def sk(self): + return self._sk + + +class PHETensorCoder: + def __init__(self, coder: "Coder") -> None: + self._coder = coder + + def encode(self, tensor: torch.Tensor): + if isinstance(tensor, torch.Tensor): + from fate.arch.unify import device + + from ._tensor import PHETensorEncoded + + shape = tensor.shape + dtype = tensor.dtype + data = self._coder.encode_tensor(tensor, dtype) + return PHETensorEncoded(self._coder, shape, data, tensor.dtype, device.from_torch_device(tensor.device)) + elif hasattr(tensor, "encode"): + return tensor.encode(self) + else: + raise NotImplementedError(f"`{tensor}` not supported") + + def pack_encode_float_tensor(self, tensor: torch.DoubleTensor, offset_bit: int, pack_num: int, precision: int): + if isinstance(tensor, torch.Tensor): + from fate.arch.unify import device + + from ._tensor import PHETensorEncoded + + shape = tensor.shape + dtype = tensor.dtype + data = self._coder.pack_floats(tensor, offset_bit, pack_num, precision) + return PHETensorEncoded(self._coder, shape, data, tensor.dtype, device.from_torch_device(tensor.device)) + elif hasattr(tensor, "pack_encode_float_tensor"): + return tensor.pack_encode_float_tensor(self, offset_bit, pack_num, precision) + else: + raise NotImplementedError(f"`{tensor}` not supported") + + def decode(self, tensor: "PHETensorEncoded"): + from ._tensor import PHETensorEncoded + + if 
isinstance(tensor, PHETensorEncoded): + return self._coder.decode_tensor(tensor.data, tensor.dtype, tensor.shape, tensor.device) + elif hasattr(tensor, "decode"): + return tensor.decode(self) + else: + raise NotImplementedError(f"`{tensor}` not supported") + + def pack_decode_float_tensor(self, tensor: "PHETensorEncoded", offset_bit: int, pack_num: int, precision: int): + from ._tensor import PHETensorEncoded + + if isinstance(tensor, PHETensorEncoded): + return self._coder.unpack_floats(tensor.data, offset_bit, pack_num, precision, tensor.shape.numel()) + elif hasattr(tensor, "pack_decode_float_tensor"): + return tensor.pack_decode_float_tensor(self, offset_bit, pack_num, precision) + else: + raise NotImplementedError(f"`{tensor}` not supported") + + +class PHETensorEncryptor: + def __init__(self, pk: "PK", coder: "PHETensorCoder", evaluator) -> None: + self._pk = pk + self._coder = coder + self._evaluator = evaluator + + def encrypt_encoded(self, tensor: "PHETensorEncoded", obfuscate=False): + from ._tensor import PHETensor, PHETensorEncoded + + if isinstance(tensor, PHETensorEncoded): + data = self._pk.encrypt_encoded(tensor.data, obfuscate) + return PHETensor(self._pk, self._evaluator, tensor.coder, tensor.shape, data, tensor.dtype, tensor.device) + elif hasattr(tensor, "encrypt_encoded"): + return tensor.encrypt_encoded(self) + raise NotImplementedError(f"`{tensor}` not supported") + + def encrypt_tensor(self, tensor: torch.Tensor, obfuscate=False): + coded = self._coder.encode(tensor) + return self.encrypt_encoded(coded, obfuscate) + + def lift(self, data, shape, dtype, device): + from ._tensor import PHETensor + return PHETensor(self._pk, self._evaluator, self._coder, shape, data, dtype, device) + + +class PHETensorDecryptor: + def __init__(self, sk: "SK", coder: "PHETensorCoder") -> None: + self._sk = sk + self._coder = coder + + def decrypt_encoded(self, tensor: "PHETensor"): + from ._tensor import PHETensor, PHETensorEncoded + + if isinstance(tensor, PHETensor): + data = self._sk.decrypt_to_encoded(tensor.data) + return PHETensorEncoded(tensor.coder, tensor.shape, data, tensor.dtype, tensor.device) + + elif hasattr(tensor, "decrypt_encoded"): + return tensor.decrypt_encoded(self) + raise NotImplementedError(f"`{tensor}` not supported") + + def decrypt_tensor(self, tensor: "PHETensor"): + return self._coder.decode(self.decrypt_encoded(tensor)) diff --git a/python/fate/arch/tensor/phe/_ops.py b/python/fate/arch/tensor/phe/_ops.py new file mode 100644 index 0000000000..e38749aed0 --- /dev/null +++ b/python/fate/arch/tensor/phe/_ops.py @@ -0,0 +1,250 @@ +import torch +from fate.arch.tensor import _custom_ops + +from ._tensor import PHETensor, implements, implements_encoded + + +@implements(_custom_ops.encrypt_f) +def encrypt(input, encryptor): + return encryptor.encrypt_tensor(input) + + +@implements_encoded(_custom_ops.encrypt_encoded_f) +def encrypt_encoded(input, encryptor): + return encryptor.encrypt_encoded(input) + + +@implements(_custom_ops.decrypt_encoded_f) +def decrypt_encoded(input, decryptor): + return decryptor.decrypt_encoded(input) + + +@implements(_custom_ops.decrypt_f) +def decrypt(input, decryptor): + return decryptor.decrypt_tensor(input) + + +@implements(_custom_ops.encode_f) +def encode(input, coder): + return coder.encode(input) + + +@implements_encoded(_custom_ops.decode_f) +def decode(input, coder): + return coder.decode(input) + + +@implements(torch.add) +def add(input: PHETensor, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): 
+ return add(other, input) + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + dtype = input.dtype + if isinstance(other, PHETensor): + assert shape == other.shape, f"shape mismatch {shape} != {other.shape}" + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.add(input.data, other.data, pk) + return input.with_template(data, dtype=output_dtype) + + elif isinstance(other, torch.Tensor): + # TODO: support broadcast + if shape == other.shape: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.add_plain(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + elif other.ndim == 0: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.add_plain_scalar(input.data, other.detach().item(), pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + else: + raise NotImplementedError(f"broadcast not supported") + + elif isinstance(other, (float, int)): + output_dtype = torch.promote_types(dtype, torch.get_default_dtype()) + data = evaluator.add_plain_scalar(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + else: + return NotImplemented + + +@implements(torch.rsub) +def rsub(input, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): + return sub(other, input) + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + dtype = input.dtype + if isinstance(other, PHETensor): + assert shape == other.shape, f"shape mismatch {shape} != {other.shape}" + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.rsub(input.data, other.data, pk) + return input.with_template(data, dtype=output_dtype) + + elif isinstance(other, torch.Tensor): + if shape == other.shape: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.rsub_plain(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + elif other.ndim == 0: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.rsub_plain_scalar(input.data, other.detach().item(), pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + else: + raise NotImplementedError(f"broadcast not supported") + + elif isinstance(other, (float, int)): + output_dtype = torch.promote_types(dtype, torch.get_default_dtype()) + data = evaluator.rsub_plain_scalar(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + + else: + return NotImplemented + + +@implements(torch.sub) +def sub(input, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): + return rsub(other, input) + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + dtype = input.dtype + if isinstance(other, PHETensor): + assert shape == other.shape, f"shape mismatch {shape} != {other.shape}" + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.sub(input.data, other.data, pk) + return input.with_template(data, dtype=output_dtype) + + elif isinstance(other, torch.Tensor): + if shape == other.shape: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.sub_plain(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + elif other.ndim == 0: + output_dtype = torch.promote_types(dtype, other.dtype) + data = 
evaluator.sub_plain_scalar(input.data, other.detach().item(), pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + else: + raise NotImplementedError(f"broadcast not supported") + + elif isinstance(other, (float, int)): + output_dtype = torch.promote_types(dtype, torch.get_default_dtype()) + data = evaluator.sub_plain_scalar(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + + else: + return NotImplemented + + +@implements(torch.mul) +def mul(input, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): + return mul(other, input) + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + dtype = input.dtype + if isinstance(other, PHETensor): + raise NotImplementedError( + f"mul {input} with {other} not supported, paillier is not multiplicative homomorphic" + ) + + elif isinstance(other, torch.Tensor): + if shape == other.shape: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.mul_plain(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + elif other.ndim == 0: + output_dtype = torch.promote_types(dtype, other.dtype) + data = evaluator.mul_plain_scalar(input.data, other.detach().item(), pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + else: + raise NotImplementedError(f"broadcast not supported") + + elif isinstance(other, (float, int)): + output_dtype = torch.promote_types(dtype, torch.get_default_dtype()) + data = evaluator.mul_plain_scalar(input.data, other, pk, coder, output_dtype) + return input.with_template(data, dtype=output_dtype) + + else: + return NotImplemented + + +@implements(_custom_ops.rmatmul_f) +def rmatmul_f(input, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): + return matmul(other, input) + + if input.ndim > 2 or input.ndim < 1: + raise ValueError(f"can't rmatmul `PHETensor` with `torch.Tensor` with dim `{input.ndim}`") + + if isinstance(other, PHETensor): + raise NotImplementedError( + f"rmatmul {input} with {other} not supported, phe is not multiplicative homomorphic" + ) + + if not isinstance(other, torch.Tensor): + return NotImplemented + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + device = input.device + other_shape = other.shape + output_dtype = torch.promote_types(input.dtype, other.dtype) + output_shape = torch.matmul(torch.rand(*other_shape, device="meta"), torch.rand(*shape, device="meta")).shape + data = evaluator.rmatmul(input.data, other, shape, other_shape, pk, coder, output_dtype) + return PHETensor(pk, evaluator, coder, output_shape, data, output_dtype, device) + + +@implements(torch.matmul) +def matmul(input, other): + if not isinstance(input, PHETensor) and isinstance(other, PHETensor): + return rmatmul_f(other, input) + + if input.ndim > 2 or input.ndim < 1: + raise ValueError(f"can't matmul `PHETensor` with `torch.Tensor` with dim `{input.ndim}`") + + if isinstance(other, PHETensor): + raise ValueError("can't matmul `PHETensor` with `PHETensor`") + + if not isinstance(other, torch.Tensor): + return NotImplemented + + evaluator = input.evaluator + pk = input.pk + coder = input.coder + shape = input.shape + device = input.device + other_shape = other.shape + output_dtype = torch.promote_types(input.dtype, other.dtype) + output_shape = torch.matmul(torch.rand(*shape, device="meta"), torch.rand(*other_shape, device="meta")).shape + 
data = evaluator.matmul(input.data, other, shape, other_shape, pk, coder, output_dtype) + return PHETensor(pk, evaluator, coder, output_shape, data, output_dtype, device) + + +@implements(_custom_ops.slice_f) +def slice_f(input, item): + evaluator = input.evaluator + stride = input.shape[1] + start = stride * item + data = evaluator.slice(input._data, start, stride) + device = input.device + return PHETensor(input.pk, evaluator, input.coder, torch.Size([*input.shape[1:]]), data, input.dtype, device) + + +@implements(_custom_ops.to_local_f) +def to_local_f(input): + return input diff --git a/python/fate/arch/tensor/phe/_tensor.py b/python/fate/arch/tensor/phe/_tensor.py new file mode 100644 index 0000000000..839b524e70 --- /dev/null +++ b/python/fate/arch/tensor/phe/_tensor.py @@ -0,0 +1,169 @@ +import functools + +import torch + +_HANDLED_FUNCTIONS = {} +_PHE_TENSOR_ENCODED_HANDLED_FUNCTIONS = {} + + +class PHETensorEncoded: + def __init__(self, coder, shape: torch.Size, data, dtype, device) -> None: + self.coder = coder + self.data = data + self.shape = shape + self.dtype = dtype + self.device = device + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in _PHE_TENSOR_ENCODED_HANDLED_FUNCTIONS or not all( + issubclass(t, (torch.Tensor, PHETensorEncoded)) for t in types + ): + return NotImplemented + return _PHE_TENSOR_ENCODED_HANDLED_FUNCTIONS[func](*args, **kwargs) + + +class PHETensor: + def __init__(self, pk, evaluator, coder, shape: torch.Size, data, dtype, device) -> None: + self._pk = pk + self._evaluator = evaluator + self._coder = coder + self._data = data + self._shape = shape + self._dtype = dtype + if isinstance(device, torch.device): + from fate.arch import unify + + self._device = unify.device.from_torch_device(device) + else: + self._device = device + + def type(self, dtype): + return self.with_template(self._data, dtype) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.__repr__() + + def __getitem__(self, item): + from ._ops import slice_f + + if isinstance(item, int): + return slice_f(self, item) + else: + raise NotImplementedError(f"item {item} not supported") + + def with_template(self, data, dtype=None, shape=None): + if dtype is None: + dtype = self._dtype + if shape is None: + shape = self._shape + return PHETensor(self._pk, self._evaluator, self._coder, shape, data, dtype, self._device) + + @property + def pk(self): + return self._pk + + @property + def evaluator(self): + return self._evaluator + + @property + def coder(self): + return self._coder + + @property + def data(self): + return self._data + + @property + def shape(self): + return self._shape + + @property + def dtype(self): + return self._dtype + + @property + def ndim(self): + return len(self.shape) + + @property + def device(self): + return self._device + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in _HANDLED_FUNCTIONS or not all(issubclass(t, (torch.Tensor, PHETensor)) for t in types): + return NotImplemented + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + + """implement arth magic""" + + def __add__(self, other): + from ._ops import add + + return add(self, other) + + def __radd__(self, other): + from ._ops import add + + return add(other, self) + + def __sub__(self, other): + from ._ops import sub + + return sub(self, other) + + def __rsub__(self, other): + from ._ops import rsub + + return 
rsub(self, other) + + def __mul__(self, other): + from ._ops import mul + + return mul(self, other) + + def __rmul__(self, other): + from ._ops import mul + + return mul(other, self) + + def __matmul__(self, other): + from ._ops import matmul + + return matmul(self, other) + + def __rmatmul__(self, other): + from ._ops import rmatmul_f + + return rmatmul_f(self, other) + + +def implements(torch_function): + """Register a torch function override for PHETensor""" + + @functools.wraps(torch_function) + def decorator(func): + _HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +def implements_encoded(torch_function): + """Register a torch function override for PHEEncodedTensor""" + + @functools.wraps(torch_function) + def decorator(func): + _PHE_TENSOR_ENCODED_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator diff --git a/python/fate/arch/tensor/storage/_helper.py b/python/fate/arch/tensor/storage/_helper.py deleted file mode 100644 index a5ecb231da..0000000000 --- a/python/fate/arch/tensor/storage/_helper.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -class local_ops_helper: - def __init__(self, device, dtype) -> None: - self.device = device - self.dtype = dtype - - def square(self, x, *args, **kwargs): - return self.apply_signature1("square", args, kwargs)(x) - - def var(self, x, *args, **kwargs): - return self.apply_signature1("var", args, kwargs)(x) - - def std(self, x, *args, **kwargs): - return self.apply_signature1("std", args, kwargs)(x) - - def max(self, x, *args, **kwargs): - return self.apply_signature1("max", args, kwargs)(x) - - def min(self, x, *args, **kwargs): - return self.apply_signature1("min", args, kwargs)(x) - - def sum(self, x, *args, **kwargs): - return self.apply_signature1("sum", args, kwargs)(x) - - def sqrt(self, x, *args, **kwargs): - return self.apply_signature1("sqrt", args, kwargs)(x) - - def add(self, x, y, *args, **kwargs): - return self.apply_signature2("add", args, kwargs)(x, y) - - def maximum(self, x, y, *args, **kwargs): - return self.apply_signature2("maximum", args, kwargs)(x, y) - - def minimum(self, x, y, *args, **kwargs): - return self.apply_signature2("minimum", args, kwargs)(x, y) - - def div(self, x, y, *args, **kwargs): - return self.apply_signature2("div", args, kwargs)(x, y) - - def sub(self, x, y, *args, **kwargs): - return self.apply_signature2("sub", args, kwargs)(x, y) - - def mul(self, x, y, *args, **kwargs): - return self.apply_signature2("mul", args, kwargs)(x, y) - - def truediv(self, x, y, *args, **kwargs): - return self.apply_signature2("true_divide", args, kwargs)(x, y) - - def matmul(self, x, y, *args, **kwargs): - return self.apply_signature2("matmul", args, kwargs)(x, y) - - def slice(self, x, *args, **kwargs): - return self.apply_signature1("slice", args, kwargs)(x) - - def apply_signature1(self, method, args, kwargs): - from .local.device import 
_ops_dispatch_signature1_local_unknown_unknown - - return _ops_dispatch_signature1_local_unknown_unknown(method, self.device, self.dtype, args, kwargs) - - def apply_signature2(self, method, args, kwargs): - from .local.device import _ops_dispatch_signature2_local_unknown_unknown - - return _ops_dispatch_signature2_local_unknown_unknown(method, self.device, self.dtype, args, kwargs) diff --git a/python/fate/arch/tensor/storage/_ops.py b/python/fate/arch/tensor/storage/_ops.py deleted file mode 100644 index 60e389fa42..0000000000 --- a/python/fate/arch/tensor/storage/_ops.py +++ /dev/null @@ -1,72 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable - -from ..types import DStorage, Storage -from .local.device import ( - _ops_dispatch_signature1_local_unknown_unknown, - _ops_dispatch_signature2_local_unknown_unknown, -) - - -# signature1: elemwise (tensor) -> tensor -def _ops_dispatch_signature1_unknown_unknown_unknown( - method, distributed, device, dtype, args, kwargs -) -> Callable[[Storage], Storage]: - if distributed: - return _ops_dispatch_signature1_distributed_unknown_unknown(method, device, dtype, args, kwargs) - else: - - return _ops_dispatch_signature1_local_unknown_unknown(method, device, dtype, args, kwargs) - - -def _ops_dispatch_signature1_distributed_unknown_unknown( - method, device, dtype, args, kwargs -) -> Callable[[Storage], Storage]: - - local_ops = _ops_dispatch_signature1_local_unknown_unknown(method, device, dtype, args, kwargs) - - def _wrap(storage: Storage) -> DStorage: - return DStorage.elemwise_unary_op( - storage, local_ops, dtype - ) # FIXME: infer output dtype is hard without additional table call - - return _wrap - - -# signature2: elemwise (tensor, tensor) -> tensor -def _ops_dispatch_signature2_unknown_unknown_unknown( - method, distributed, device, dtype, args, kwargs -) -> Callable[[Any, Any], Storage]: - if distributed: - return _ops_dispatch_signature2_distributed_unknown_unknown(method, device, dtype, args, kwargs) - else: - - return _ops_dispatch_signature2_local_unknown_unknown(method, device, dtype, args, kwargs) - - -def _ops_dispatch_signature2_distributed_unknown_unknown( - method, device, dtype, args, kwargs -) -> Callable[[Any, Any], Storage]: - local_ops = _ops_dispatch_signature2_local_unknown_unknown(method, device, dtype, args, kwargs) - - def _wrap(storage1, storage2, **kwargs) -> DStorage: - if isinstance(storage1, DStorage) and isinstance(storage2, DStorage): - return DStorage.elemwise_binary_op(storage1, storage2, local_ops, dtype, **kwargs) - else: - # then storage2 should be broadcast - return DStorage.elemwise_bc_op(storage1, storage2, local_ops, dtype, **kwargs) - - return _wrap diff --git a/python/fate/arch/tensor/storage/distributed/agg.py b/python/fate/arch/tensor/storage/distributed/agg.py deleted file mode 100644 index 0cb0f97b1e..0000000000 --- a/python/fate/arch/tensor/storage/distributed/agg.py +++ /dev/null @@ -1,165 +0,0 @@ -# -# 
Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate.arch.tensor.types import DStorage - - -def sum(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - if dim is not None and not kwargs.get("keepdim", False): - kwargs["keepdim"] = True - local_ops = storage.local_ops_helper() - output = DStorage.unary_op(storage, lambda x: local_ops.sum(x, *args, **kwargs)) - if dim is None or dim == storage.d_axis.axis: - output = output.blocks.reduce(lambda x, y: local_ops.add(x, y)) - return output - - -def mean(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - local_ops = storage.local_ops_helper() - if dim is not None and dim != storage.d_axis.axis: - count = storage.shape[dim] - return DStorage.unary_op(storage, lambda x: local_ops.truediv(local_ops.sum(x, *args, **kwargs), count)) - else: - output = DStorage.unary_op(storage, lambda x: local_ops.sum(x, *args, **kwargs)) - if dim is None: - count = storage.shape.prod() - else: - count = storage.shape[dim] - output = output.blocks.reduce(lambda x, y: local_ops.add(x, y)) - return local_ops.truediv(output, count) - - -def var(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - unbiased = kwargs.get("unbiased", True) - - local_ops = storage.local_ops_helper() - if dim is not None and dim != storage.d_axis.axis: - return DStorage.unary_op(storage, lambda x: local_ops.var(x, dim=dim, unbiased=unbiased)) - - else: - if dim is None: - n = storage.shape.prod() - - def _mapper(x): - return (local_ops.sum(local_ops.square(x)), local_ops.sum(x)) - - else: - n = storage.shape[dim] - - def _mapper(x): - return (local_ops.sum(local_ops.square(x), dim=dim), local_ops.sum(x, dim=dim)) - - def _reducer(x, y): - return (local_ops.add(x[0], y[0]), local_ops.add(x[1], y[1])) - - sq, s = storage.blocks.mapValues(_mapper).reduce(_reducer) - output = local_ops.sub(local_ops.div(sq, n), local_ops.square(local_ops.div(s, n))) - if unbiased: - output = local_ops.mul(output, n / (n - 1)) - return output - - -def std(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - unbiased = kwargs.get("unbiased", True) - - local_ops = storage.local_ops_helper() - if dim is not None and dim != storage.d_axis.axis: - return DStorage.unary_op(storage, lambda x: local_ops.std(x, dim=dim, unbiased=unbiased)) - - else: - if dim is None: - n = storage.shape.prod() - - def _mapper(x): - return (local_ops.sum(local_ops.square(x)), local_ops.sum(x)) - - else: - n = storage.shape[dim] - - def _mapper(x): - return (local_ops.sum(local_ops.square(x), dim=dim), local_ops.sum(x, dim=dim)) - - def _reducer(x, y): - return (local_ops.add(x[0], y[0]), local_ops.add(x[1], y[1])) - - sq, s = 
storage.blocks.mapValues(_mapper).reduce(_reducer) - output = local_ops.sub(local_ops.div(sq, n), local_ops.square(local_ops.div(s, n))) - if unbiased: - output = local_ops.mul(output, n / (n - 1)) - output = local_ops.sqrt(output) - return output - - -def max(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - local_ops = storage.local_ops_helper() - if dim is None: - - def _mapper(x): - return local_ops.max(x) - - return storage.blocks.mapValues(_mapper).reduce(lambda x, y: local_ops.maximum(x, y)) - else: - if dim == storage.d_axis.axis: - return storage.blocks.mapValues(lambda x: local_ops.max(x, dim=dim)).reduce( - lambda x, y: local_ops.maximum(x, y) - ) - else: - return DStorage.unary_op(storage, lambda s: local_ops.max(s, *args, **kwargs)) - - -def min(storage: DStorage, *args, **kwargs): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - local_ops = storage.local_ops_helper() - if dim is None: - - def _mapper(x): - return local_ops.min(x) - - return storage.blocks.mapValues(_mapper).reduce(lambda x, y: local_ops.minimum(x, y)) - else: - if dim == storage.d_axis.axis: - return storage.blocks.mapValues(lambda x: local_ops.min(x, dim=dim)).reduce( - lambda x, y: local_ops.minimum(x, y) - ) - else: - return DStorage.unary_op(storage, lambda s: local_ops.min(s, *args, **kwargs)) diff --git a/python/fate/arch/tensor/storage/local/device/__init__.py b/python/fate/arch/tensor/storage/local/device/__init__.py deleted file mode 100644 index 5d2b3a4991..0000000000 --- a/python/fate/arch/tensor/storage/local/device/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable - -from fate.arch.unify import device - -from ....types import LStorage - - -def _ops_dispatch_signature1_local_unknown_unknown( - method, - _device, - dtype, - args, - kwargs, -) -> Callable[[LStorage], LStorage]: - if _device == device.CPU: - from .cpu._base import _ops_dispatch_signature_1_local_cpu_unknown - - return _ops_dispatch_signature_1_local_cpu_unknown(method, dtype, args, kwargs) - raise ValueError() - - -def _ops_dispatch_signature2_local_unknown_unknown( - method, _device, dtype, args, kwargs -) -> Callable[[Any, Any], LStorage]: - if _device == device.CPU: - from .cpu._base import _ops_dispatch_signature_2_local_cpu_unknown - - return _ops_dispatch_signature_2_local_cpu_unknown(method, dtype, args, kwargs) - raise ValueError() diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_base.py b/python/fate/arch/tensor/storage/local/device/cpu/_base.py deleted file mode 100644 index 8a12448e84..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/_base.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable - -from fate.arch.tensor.types import LStorage, dtype - - -def _ops_dispatch_signature_1_local_cpu_unknown( - method, - dtype: dtype, - args, - kwargs, -) -> Callable[[LStorage], LStorage]: - if dtype.is_basic(): - from .plain import _TorchStorage - - return _TorchStorage.unary(method, args, kwargs) - elif dtype.is_paillier(): - from .paillier import _ops_dispatch_signature_1_local_cpu_paillier - - return _ops_dispatch_signature_1_local_cpu_paillier(method, args, kwargs) - - -def _ops_dispatch_signature_2_local_cpu_unknown(method, dtype: dtype, args, kwargs) -> Callable[[Any, Any], LStorage]: - if dtype.is_basic(): - from .plain import _TorchStorage - - return _TorchStorage.binary(method, args, kwargs) - elif dtype.is_paillier(): - from .paillier import _ops_dispatch_signature_2_local_cpu_paillier - - return _ops_dispatch_signature_2_local_cpu_paillier(method, args, kwargs) diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py b/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py deleted file mode 100644 index 61d74a0438..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py +++ /dev/null @@ -1,364 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
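The dispatch helpers deleted just above select an implementation in two steps: first by device (only CPU is handled), then by dtype (plain torch-backed storage versus Paillier). A condensed sketch of that two-level dispatch; the enum names are illustrative stand-ins, not the real FATE types:

```python
# Condensed sketch of the removed two-level dispatch: branch on device, then dtype.
from enum import Enum, auto

class Device(Enum):
    CPU = auto()

class DType(Enum):
    BASIC = auto()
    PAILLIER = auto()

def dispatch_unary(method, device, dtype):
    if device is not Device.CPU:
        raise ValueError(f"unsupported device: {device}")
    if dtype is DType.BASIC:
        return lambda storage: f"torch-backed `{method}` on {storage}"
    if dtype is DType.PAILLIER:
        return lambda storage: f"paillier-backed `{method}` on {storage}"
    raise ValueError(f"unsupported dtype: {dtype}")

op = dispatch_unary("sqrt", Device.CPU, DType.BASIC)
print(op("x"))  # torch-backed `sqrt` on x
```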
-import pickle - -import numpy as np -import torch - - -def _impl_ops(class_obj, method_name, ops): - def func(self, other): - cb = ops(self._cb, other, class_obj) - if cb is NotImplemented: - return NotImplemented - else: - return class_obj(cb) - - func.__name__ = method_name - return func - - -def _impl_init(): - def __init__(self, cb): - self._cb = cb - - return __init__ - - -def _impl_encryptor_init(): - def __init__(self, pk): - self._pk = pk - - return __init__ - - -def _impl_decryptor_init(): - def __init__(self, sk): - self._sk = sk - - return __init__ - - -def _impl_encrypt(pheblock_cls, fpbloke_cls, encrypt_op): - def encrypt(self, other) -> pheblock_cls: - if isinstance(other, fpbloke_cls): - return pheblock_cls(encrypt_op(self._pk, other.numpy())) - - raise NotImplementedError(f"type {other} not supported") - - return encrypt - - -def _impl_decrypt(pheblock_cls, fpbloke_cls, decrypt_op): - def decrypt(self, other, dtype=np.float32) -> fpbloke_cls: - if isinstance(other, pheblock_cls): - return torch.from_numpy(decrypt_op(self._sk, other._cb, dtype)) - raise NotImplementedError(f"type {other} not supported") - - return decrypt - - -def _impl_serialize(): - def serialize(self) -> bytes: - return pickle.dumps(self._cb) - - return serialize - - -def _impl_keygen(encrypt_cls, decrypt_cls, keygen_op): - @classmethod - def keygen(cls, key_length=1024): - pk, sk = keygen_op(bit_size=key_length) - return (encrypt_cls(pk), decrypt_cls(sk)) - - return keygen - - -def _maybe_setattr(obj, name, value): - if not hasattr(obj, name): - setattr(obj, name, value) - - -def phe_keygen_metaclass(encrypt_cls, decrypt_cls, keygen_op): - class PHEKeygenMetaclass(type): - def __new__(cls, name, bases, dict): - keygen_cls = super().__new__(cls, name, bases, dict) - - setattr(keygen_cls, "keygen", _impl_keygen(encrypt_cls, decrypt_cls, keygen_op)) - return keygen_cls - - return PHEKeygenMetaclass - - -def phe_decryptor_metaclass(pheblock_cls, fpblock_cls): - class PHEDecryptorMetaclass(type): - def __new__(cls, name, bases, dict): - decryptor_cls = super().__new__(cls, name, bases, dict) - - setattr(decryptor_cls, "__init__", _impl_decryptor_init()) - setattr( - decryptor_cls, - "decrypt", - _impl_decrypt(pheblock_cls, fpblock_cls, PHEDecryptorMetaclass._decrypt_numpy), - ) - return decryptor_cls - - @staticmethod - def _decrypt_numpy(sk, cb, dtype): - if dtype == np.float64: - return sk.decrypt_f64(cb) - if dtype == np.float32: - return sk.decrypt_f32(cb) - if dtype == np.int64: - return sk.decrypt_i64(cb) - if dtype == np.int32: - return sk.decrypt_i32(cb) - raise NotImplementedError("dtype = {dtype}") - - return PHEDecryptorMetaclass - - -def phe_encryptor_metaclass(pheblock_cls, fpblock_cls): - class PHEEncryptorMetaclass(type): - def __new__(cls, name, bases, dict): - encryptor_cls = super().__new__(cls, name, bases, dict) - - setattr(encryptor_cls, "__init__", _impl_encryptor_init()) - setattr( - encryptor_cls, - "encrypt", - _impl_encrypt(pheblock_cls, fpblock_cls, PHEEncryptorMetaclass._encrypt_numpy), - ) - return encryptor_cls - - @staticmethod - def _encrypt_numpy(pk, other): - if is_ndarray(other): - if is_nd_float64(other): - return pk.encrypt_f64(other) - if is_nd_float32(other): - return pk.encrypt_f32(other) - if is_nd_int64(other): - return pk.encrypt_i64(other) - if is_nd_int32(other): - return pk.encrypt_i32(other) - raise NotImplementedError(f"type {other} {other.dtype} not supported") - - return PHEEncryptorMetaclass - - -class PHEBlockMetaclass(type): - def __new__(cls, name, bases, 
dict): - class_obj = super().__new__(cls, name, bases, dict) - - setattr(class_obj, "__init__", _impl_init()) - - @property - def shape(self): - return self._cb.shape - - setattr(class_obj, "shape", shape) - _maybe_setattr(class_obj, "serialize", _impl_serialize()) - for impl_name, ops in { - "__add__": PHEBlockMetaclass._add, - "__radd__": PHEBlockMetaclass._radd, - "__sub__": PHEBlockMetaclass._sub, - "__rsub__": PHEBlockMetaclass._rsub, - "__mul__": PHEBlockMetaclass._mul, - "__rmul__": PHEBlockMetaclass._rmul, - "__matmul__": PHEBlockMetaclass._matmul, - "__rmatmul__": PHEBlockMetaclass._rmatmul, - }.items(): - _maybe_setattr(class_obj, impl_name, _impl_ops(class_obj, impl_name, ops)) - - return class_obj - - @staticmethod - def _rmatmul(cb, other, class_obj): - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if len(other.shape) == 2: - if is_nd_float64(other): - return cb.rmatmul_plaintext_ix2_f64(other) - if is_nd_float32(other): - return cb.rmatmul_plaintext_ix2_f32(other) - if is_nd_int64(other): - return cb.rmatmul_plaintext_ix2_i64(other) - if is_nd_int32(other): - return cb.rmatmul_plaintext_ix2_i32(other) - if len(other.shape) == 1: - if is_nd_float64(other): - return cb.rmatmul_plaintext_ix1_f64(other) - if is_nd_float32(other): - return cb.rmatmul_plaintext_ix1_f32(other) - if is_nd_int64(other): - return cb.rmatmul_plaintext_ix1_i64(other) - if is_nd_int32(other): - return cb.rmatmul_plaintext_ix1_i32(other) - return NotImplemented - - @staticmethod - def _matmul(cb, other, class_obj): - if isinstance(other, torch.Tensor): - other = other.numpy() - if is_ndarray(other): - if len(other.shape) == 2: - if is_nd_float64(other): - return cb.matmul_plaintext_ix2_f64(other) - if is_nd_float32(other): - return cb.matmul_plaintext_ix2_f32(other) - if is_nd_int64(other): - return cb.matmul_plaintext_ix2_i64(other) - if is_nd_int32(other): - return cb.matmul_plaintext_ix2_i32(other) - if len(other.shape) == 1: - if is_nd_float64(other): - return cb.matmul_plaintext_ix1_f64(other) - if is_nd_float32(other): - return cb.matmul_plaintext_ix1_f32(other) - if is_nd_int64(other): - return cb.matmul_plaintext_ix1_i64(other) - if is_nd_int32(other): - return cb.matmul_plaintext_ix1_i32(other) - return NotImplemented - - @staticmethod - def _mul(cb, other, class_obj): - if isinstance(other, torch.Tensor): - other = other.numpy() - if is_ndarray(other): - if is_nd_float64(other): - return cb.mul_plaintext_f64(other) - if is_nd_float32(other): - return cb.mul_plaintext_f32(other) - if is_nd_int64(other): - return cb.mul_plaintext_i64(other) - if is_nd_int32(other): - return cb.mul_plaintext_i32(other) - raise NotImplemented - if is_float(other): - return cb.mul_plaintext_scalar_f64(other) - if is_float32(other): - return cb.mul_plaintext_scalar_f32(other) - if is_int(other): - return cb.mul_plaintext_scalar_i64(other) - if is_int32(other): - return cb.mul_plaintext_scalar_i32(other) - return NotImplemented - - @staticmethod - def _sub(cb, other, class_obj): - if isinstance(other, torch.Tensor): - other = other.numpy() - if is_ndarray(other): - if is_nd_float64(other): - return cb.sub_plaintext_f64(other) - if is_nd_float32(other): - return cb.sub_plaintext_f32(other) - if is_nd_int64(other): - return cb.sub_plaintext_i64(other) - if is_nd_int32(other): - return cb.sub_plaintext_i32(other) - return NotImplemented - - if isinstance(other, class_obj): - return cb.sub_cipherblock(other._cb) - if is_float(other): - return 
cb.sub_plaintext_scalar_f64(other) - if is_float32(other): - return cb.sub_plaintext_scalar_f32(other) - if is_int(other): - return cb.sub_plaintext_scalar_i64(other) - if is_int32(other): - return cb.sub_plaintext_scalar_i32(other) - - return NotImplemented - - @staticmethod - def _add(cb, other, class_obj): - if isinstance(other, torch.Tensor): - other = other.numpy() - if is_ndarray(other): - if is_nd_float64(other): - return cb.add_plaintext_f64(other) - if is_nd_float32(other): - return cb.add_plaintext_f32(other) - if is_nd_int64(other): - return cb.add_plaintext_i64(other) - if is_nd_int32(other): - return cb.add_plaintext_i32(other) - return NotImplemented - - if isinstance(other, class_obj): - return cb.add_cipherblock(other._cb) - if is_float(other): - return cb.add_plaintext_scalar_f64(other) - if is_float32(other): - return cb.add_plaintext_scalar_f32(other) - if is_int(other): - return cb.add_plaintext_scalar_i64(other) - if is_int32(other): - return cb.add_plaintext_scalar_i32(other) - - return NotImplemented - - @staticmethod - def _radd(cb, other, class_obj): - return PHEBlockMetaclass._add(cb, other, class_obj) - - @staticmethod - def _rsub(cb, other, class_obj): - return PHEBlockMetaclass._add(PHEBlockMetaclass._mul(cb, -1, class_obj), other, class_obj) - - @staticmethod - def _rmul(cb, other, class_obj): - return PHEBlockMetaclass._mul(cb, other, class_obj) - - -def is_ndarray(v): - return isinstance(v, np.ndarray) - - -def is_float(v): - return isinstance(v, (float, np.float64)) - - -def is_float32(v): - return isinstance(v, np.float32) - - -def is_int(v): - return isinstance(v, (int, np.int64)) - - -def is_int32(v): - return isinstance(v, np.int32) - - -def is_nd_float64(v): - return v.dtype == np.float64 - - -def is_nd_float32(v): - return v.dtype == np.float32 - - -def is_nd_int64(v): - return v.dtype == np.int64 - - -def is_nd_int32(v): - return v.dtype == np.int32 diff --git a/python/fate/arch/tensor/storage/local/device/cpu/cpu_paillier_block.py b/python/fate/arch/tensor/storage/local/device/cpu/cpu_paillier_block.py deleted file mode 100644 index 8b532e4fea..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/cpu_paillier_block.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
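In the metaclass deleted above, `__rsub__` is built from the two operations a Paillier ciphertext does support: `_rsub(cb, other)` is `_add(_mul(cb, -1), other)`, i.e. `other - cb` rewritten as `(-1 * cb) + other`. A quick plaintext check of the identity, with a NumPy array standing in for the encrypted block:

```python
# The removed PHEBlockMetaclass implements __rsub__ as add(mul(cb, -1), other),
# since Paillier only offers ciphertext addition and plaintext multiplication.
import numpy as np

cb = np.array([1.0, 2.0, 3.0])        # stands in for the cipher block
other = np.array([10.0, 10.0, 10.0])

assert np.allclose(other - cb, (cb * -1.0) + other)
```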
-# - - -import rust_paillier -import torch - -from ._metaclass import ( - PHEBlockMetaclass, - phe_decryptor_metaclass, - phe_encryptor_metaclass, - phe_keygen_metaclass, -) - - -class PaillierBlock(metaclass=PHEBlockMetaclass): - pass - - -class BlockPaillierEncryptor(metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor)): - pass - - -class BlockPaillierDecryptor(metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor)): - pass - - -class BlockPaillierCipher( - metaclass=phe_keygen_metaclass(BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.keygen) -): - pass diff --git a/python/fate/arch/tensor/storage/local/device/cpu/multithread_cpu_paillier_block.py b/python/fate/arch/tensor/storage/local/device/cpu/multithread_cpu_paillier_block.py deleted file mode 100644 index 1716a9090e..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/multithread_cpu_paillier_block.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import rust_paillier.par -import torch - -from ._metaclass import ( - PHEBlockMetaclass, - phe_decryptor_metaclass, - phe_encryptor_metaclass, - phe_keygen_metaclass, -) - - -class PaillierBlock(metaclass=PHEBlockMetaclass): - pass - - -class BlockPaillierEncryptor(metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor)): - pass - - -class BlockPaillierDecryptor(metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor)): - pass - - -class BlockPaillierCipher( - metaclass=phe_keygen_metaclass(BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.par.keygen) -): - pass diff --git a/python/fate/arch/tensor/storage/local/device/cpu/paillier.py b/python/fate/arch/tensor/storage/local/device/cpu/paillier.py deleted file mode 100644 index a49d84ce12..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/paillier.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
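For context, a hedged usage sketch of the legacy block-Paillier classes being deleted here, reconstructed from the metaclass definitions above: `keygen` returns an encryptor/decryptor pair, `encrypt` takes a torch tensor, and `decrypt` returns one. The import path is the pre-change module location, and the code requires the old `rust_paillier` extension that this change set retires in favour of fate_utils and the new PHETensor API:

```python
# Legacy (pre-2.0.0-beta) block-Paillier usage, per the deleted metaclass code.
import numpy as np
import torch
from fate.arch.tensor.storage.local.device.cpu.cpu_paillier_block import (
    BlockPaillierCipher,  # module removed by this change; shown for context only
)

encryptor, decryptor = BlockPaillierCipher.keygen(key_length=1024)
x = torch.rand(4, 3, dtype=torch.float64)
block = encryptor.encrypt(x)                   # PaillierBlock over the tensor
y = decryptor.decrypt(block, dtype=np.float64)
assert torch.allclose(x, y)
```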
-from typing import Any, Callable, List - -import torch -from fate.arch.tensor._exception import OpsDispatchUnsupportedError -from fate.arch.tensor.types import LStorage, Shape, dtype -from fate.arch.unify import device - - -class _RustPaillierStorage(LStorage): - device = device.CPU - - def __init__(self, dtype: dtype, shape: Shape, data) -> None: - self.dtype = dtype - self.shape = shape - self.data = data - - def tolist(self): - return self.data.tolist() - - def to_local(self) -> "_RustPaillierStorage": - return self - - def transpose(self): - return _RustPaillierStorage(self.dtype, self.shape.transpose(), self.data.T) - - def __eq__(self, __o: object) -> bool: - return ( - isinstance(__o, _RustPaillierStorage) - and (self.dtype == __o.dtype) - and (isinstance(self.data, torch.Tensor)) - and (isinstance(__o.data, torch.Tensor)) - and torch.equal(self.data, __o.data) - ) - - def __str__(self) -> str: - if isinstance(self.data, torch.Tensor): - return f"_RustPaillierStorage({self.device}, {self.dtype}, {self.shape},\n)" - return f"_RustPaillierStorage({self.device}, {self.dtype}, {self.shape},\n{self.data})" - - def __repr__(self) -> str: - return self.__str__() - - def cat(self, others: List["_RustPaillierStorage"], axis): - device = self.device - d_type = self.dtype - tensors = [self.data] - for storage in others: - if not isinstance(storage, _RustPaillierStorage) or storage.dtype != d_type or storage.device != device: - raise RuntimeError(f"not supported type: {storage}") - tensors.extend([storage.data for storage in others]) - cat_tensor = torch.cat(tensors, axis) - return _RustPaillierStorage(d_type, Shape(cat_tensor.shape), cat_tensor) - - -def _ops_dispatch_signature_1_local_cpu_paillier( - method, args, kwargs -) -> Callable[[_RustPaillierStorage], _RustPaillierStorage]: - raise OpsDispatchUnsupportedError(method, False, device.CPU, dtype.paillier) - - -def _ops_dispatch_signature_2_local_cpu_paillier( - method, - args, - kwargs, -) -> Callable[[Any, Any], _RustPaillierStorage]: - - # TODO: implement ops directly in C/Rust side - def _wrap(a, b, **kwargs) -> _RustPaillierStorage: - import operator - - a, b = _maybe_unwrap_storage(a), _maybe_unwrap_storage(b) - func = getattr(operator, method) - output = func(a, b) - return _RustPaillierStorage(dtype.paillier, Shape(output.shape), output) - - return _wrap - - -def _ops_dispatch_signature_3_local_cpu_paillier( - method, - args, - kwargs, -) -> Callable[[_RustPaillierStorage], _RustPaillierStorage]: - raise OpsDispatchUnsupportedError(method, False, device.CPU, dtype.paillier) - - -def _maybe_unwrap_storage(s): - from .plain import _TorchStorage - - if isinstance(s, (_RustPaillierStorage, _TorchStorage)): - return s.data - else: - return s diff --git a/python/fate/arch/tensor/storage/local/device/cpu/plain.py b/python/fate/arch/tensor/storage/local/device/cpu/plain.py deleted file mode 100644 index bafe4a47ab..0000000000 --- a/python/fate/arch/tensor/storage/local/device/cpu/plain.py +++ /dev/null @@ -1,236 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, List - -import torch -from fate.arch.tensor.types import LStorage, Shape, dtype -from fate.arch.unify import device - - -class _TorchStorage(LStorage): - device = device.CPU - - def __init__(self, dtype: dtype, shape: Shape, data) -> None: - self.dtype = dtype - self.shape = shape - self.data = data - - def tolist(self): - return self.data.tolist() - - def to_local(self) -> "_TorchStorage": - return self - - def transpose(self): - return _TorchStorage(self.dtype, self.shape.transpose(), self.data.T) - - def __eq__(self, __o: object) -> bool: - return ( - isinstance(__o, _TorchStorage) - and (self.dtype == __o.dtype) - and (isinstance(self.data, torch.Tensor)) - and (isinstance(__o.data, torch.Tensor)) - and torch.equal(self.data, __o.data) - ) - - def __str__(self) -> str: - if isinstance(self.data, torch.Tensor): - return f"_CPUStorage({self.device}, {self.dtype}, {self.shape},\n)" - return f"_CPUStorage({self.device}, {self.dtype}, {self.shape},\n{self.data})" - - def __repr__(self) -> str: - return self.__str__() - - def cat(self, others: List["_TorchStorage"], axis): - device = self.device - d_type = self.dtype - tensors = [self.data] - for storage in others: - if not isinstance(storage, _TorchStorage) or storage.dtype != d_type or storage.device != device: - raise RuntimeError(f"not supported type: {storage}") - tensors.extend([storage.data for storage in others]) - cat_tensor = torch.cat(tensors, axis) - return _TorchStorage(d_type, Shape(cat_tensor.shape), cat_tensor) - - ### ops dispatch, use staticmethod here - @staticmethod - def unary(method, args, kwargs): - if _has_custom_unary(method): - return _ops_cpu_plain_unary_custom(method, args, kwargs) - else: - return _ops_cpu_plain_unary_buildin(method, args, kwargs) - - @staticmethod - def binary(method, args, kwargs): - if _has_custom_binary(method): - return _ops_cpu_plain_binary_custom(method, args, kwargs) - else: - return _ops_cpu_plain_binary_buildin(method, args, kwargs) - - def mean(self, *args, **kwargs): - return _ops_cpu_plain_unary_buildin("mean", args, kwargs)(self) - - def sum(self, *args, **kwargs): - return _ops_cpu_plain_unary_buildin("sum", args, kwargs)(self) - - def var(self, *args, **kwargs): - return _ops_cpu_plain_unary_buildin("var", args, kwargs)(self) - - def std(self, *args, **kwargs): - return _ops_cpu_plain_unary_buildin("std", args, kwargs)(self) - - def max(self, *args, **kwargs): - return _ops_cpu_plain_unary_custom("max", args, kwargs)(self) - - def min(self, *args, **kwargs): - return _ops_cpu_plain_unary_custom("min", args, kwargs)(self) - - -def _ops_cpu_plain_unary_buildin(method, args, kwargs) -> Callable[[_TorchStorage], _TorchStorage]: - if ( - func := { - "exp": torch.exp, - "log": torch.log, - "neg": torch.neg, - "reciprocal": torch.reciprocal, - "square": torch.square, - "abs": torch.abs, - "sum": torch.sum, - "sqrt": torch.sqrt, - "var": torch.var, - "std": torch.std, - "mean": torch.mean, - }.get(method) - ) is not None: - - def _wrap(storage: _TorchStorage) -> _TorchStorage: - output = func(storage.data, *args, **kwargs) - output_dtype = dtype.from_torch_dtype(output.dtype) - output_shape = Shape(output.shape) - return _TorchStorage(output_dtype, output_shape, output) - - return _wrap - raise NotImplementedError(f"method `{method}` not found in torch unary buildin, consider to add custom extending") - - -def _has_custom_unary(method): - return 
method in {"slice", "max", "min"} - - -def _ops_cpu_plain_unary_custom(method, args, kwargs) -> Callable[[_TorchStorage], _TorchStorage]: - if method == "slice": - - def _slice(storage: _TorchStorage): - output = storage.data[args[0]] - output_dtype = dtype.from_torch_dtype(output.dtype) - output_shape = Shape(output.shape) - return _TorchStorage(output_dtype, output_shape, output) - - return _slice - - if method == "max": - - def _max(storage: _TorchStorage): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - if dim is None: - output = torch.as_tensor(storage.data.max(*args, **kwargs)) - else: - output = storage.data.max(*args, **kwargs).values - output_dtype = dtype.from_torch_dtype(output.dtype) - output_shape = Shape(output.shape) - return _TorchStorage(output_dtype, output_shape, output) - - return _max - - if method == "min": - - def _min(storage: _TorchStorage): - dim = None - if len(args) > 0: - dim = args[0] - if "dim" in kwargs: - dim = kwargs["dim"] - if dim is None: - output = torch.as_tensor(storage.data.min(*args, **kwargs)) - else: - output = storage.data.min(*args, **kwargs).values - output_dtype = dtype.from_torch_dtype(output.dtype) - output_shape = Shape(output.shape) - return _TorchStorage(output_dtype, output_shape, output) - - return _min - - raise NotImplementedError(f"method `{method}` not found in torch unary custom, consider to add custom extending") - - -def _ops_cpu_plain_binary_buildin(method, args, kwargs) -> Callable[[Any, Any], _TorchStorage]: - if ( - func := { - "add": torch.add, - "sub": torch.sub, - "mul": torch.mul, - "div": torch.div, - "pow": torch.pow, - "remainder": torch.remainder, - "matmul": torch.matmul, - "true_divide": torch.true_divide, - "truediv": torch.true_divide, - "maximum": torch.maximum, - "minimum": torch.minimum, - }.get(method) - ) is not None: - - def _wrap(a, b) -> _TorchStorage: - output = func(_maybe_unwrap_storage(a), _maybe_unwrap_storage(b), *args, **kwargs) - output_dtype = dtype.from_torch_dtype(output.dtype) - output_shape = Shape(output.shape) - return _TorchStorage(output_dtype, output_shape, output) - - return _wrap - raise NotImplementedError(f"method `{method}` not found in torch binary buildin, consider to add custom extending") - - -def _has_custom_binary(method): - return False - - -def _ops_cpu_plain_binary_custom(method, args, kwargs) -> Callable[[_TorchStorage], _TorchStorage]: - raise NotImplementedError(f"method `{method}` not found in torch buildin, consider to add custom extending") - - -def _maybe_unwrap_storage(s): - if isinstance(s, _TorchStorage): - return s.data - else: - return s - - -# def _ops_dispatch_signature_3_local_cpu_plain( -# method, -# args, -# kwargs, -# ) -> Callable[[_CPUStorage], _CPUStorage]: -# def _wrap(storage: _CPUStorage) -> _CPUStorage: -# func = getattr(torch, method) -# output = func(storage.data, *args, **kwargs) -# output_dtype = dtype.from_torch_dtype(output.dtype) -# output_shape = Shape(output.shape) -# return _CPUStorage(output_dtype, output_shape, output) - -# return _wrap diff --git a/python/fate/arch/tensor/types/_dstorage.py b/python/fate/arch/tensor/types/_dstorage.py deleted file mode 100644 index 953bd7c164..0000000000 --- a/python/fate/arch/tensor/types/_dstorage.py +++ /dev/null @@ -1,264 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, List, Optional - -from fate.arch.unify import device - -from ._dtype import dtype -from ._lstorage import LStorage -from ._shape import DAxis, Shape - - -class DStorage: - def __init__(self, blocks, shape: Shape, dtype: dtype, device: device, transposed=False) -> None: - self.blocks = blocks - self._shape = shape - self._dtype = dtype - self._device = device - self.transposed = transposed - - @property - def shape(self): - return self._shape - - @property - def d_axis(self) -> DAxis: - if self._shape.d_axis is None: - raise ValueError(f"DStorage should not have none daxis") - return self._shape.d_axis - - @property - def dtype(self): - return self._dtype - - @property - def device(self): - return self._device - - def transpose(self) -> "DStorage": - return DStorage(self.blocks, self.shape.transpose(), self.dtype, self.device, not self.transposed) - - def sum(self, *args, **kwargs): - from ..storage.distributed.agg import sum - - return sum(self, *args, **kwargs) - - def max(self, *args, **kwargs): - from ..storage.distributed.agg import max - - return max(self, *args, **kwargs) - - def min(self, *args, **kwargs): - from ..storage.distributed.agg import min - - return min(self, *args, **kwargs) - - def mean(self, *args, **kwargs): - from ..storage.distributed.agg import mean - - return mean(self, *args, **kwargs) - - def std(self, *args, **kwargs): - from ..storage.distributed.agg import std - - return std(self, *args, **kwargs) - - def var(self, *args, **kwargs): - from ..storage.distributed.agg import var - - return var(self, *args, **kwargs) - - def __eq__(self, __o: object) -> bool: - if isinstance(__o, DStorage) and self._dtype == __o.dtype and self._device == __o.device: - return self.to_local() == __o.to_local() - else: - return False - - def __str__(self) -> str: - return f"DStorage({self.device}, {self.dtype}, {self.shape})" - - def num_blocks(self): - return self.blocks.count() - - def collect(self) -> List[LStorage]: - return [pair[1] for pair in sorted(self.blocks.collect())] - - def to_local(self): - storages = self.collect() - return storages[0].cat(storages[1:], self.shape.d_axis.axis) - - @classmethod - def from_storages(cls, ctx, storages: List[LStorage], d_axis=0, partitions=4): - d_type = storages[0].dtype - device = storages[0].device - shape_size = storages[0].shape.size - if storages[0].shape.d_axis is not None: - raise RuntimeError(f"can't create DStorage from list of DStorage") - if isinstance(shape_size, int): - shape_size = (shape_size,) - shape_len = len(shape_size) - if d_axis > shape_len or d_axis < 0: - raise RuntimeError(f"d_axis out of bound") - for storage in storages[1:]: - if storage.dtype != d_type: - raise RuntimeError(f"requires same dtype") - if storage.device != device: - raise RuntimeError(f"requires same device") - if storage.shape.d_axis is not None: - raise RuntimeError(f"can't create DStorage from list of DStorage") - if len(storage.shape.size) != shape_len: - raise RuntimeError(f"requires same shape len") - for i in range(shape_len): - if i == d_axis: - shape_size = ( - *shape_size[:d_axis], - 
shape_size[d_axis] + storage.shape.size[d_axis], - *shape_size[(d_axis + 1) :], - ) - else: - if shape_size[i] != storage.shape.size[i]: - raise RuntimeError(f"requires same shape except d_axis") - blocks = ctx.computing.parallelize(enumerate(storages), partition=partitions, include_key=True) - d_axis_cls = DAxis(d_axis, [s.shape.size[d_axis] for s in storages]) - return DStorage(blocks, Shape(shape_size, d_axis_cls), d_type, device) - - @classmethod - def unary_op( - cls, - a: "DStorage", - mapper: Callable[[LStorage], LStorage], - output_shape: Optional[Shape] = None, - output_dtype=None, - ): - def _apply_transpose(func, flag): - def _wrap(blk): - if flag: - blk = blk.transpose() - return func(blk) - - return _wrap - - mapper = _apply_transpose(mapper, a.transposed) - output_block = a.blocks.mapValues(mapper) - if output_dtype is None: - output_dtype = a._dtype - if output_shape is None: - output_shape = a.shape - return DStorage(output_block, output_shape, output_dtype, a._device) - - @classmethod - def elemwise_unary_op( - cls, - a, - mapper: Callable[[LStorage], LStorage], - output_dtype=None, - ): - def _apply_transpose(func, flag): - def _wrap(blk): - if flag: - blk = blk.transpose() - return func(blk) - - return _wrap - - mapper = _apply_transpose(mapper, a.transposed) - output_block = a.blocks.mapValues(mapper) - if output_dtype is None: - output_dtype = a._dtype - return DStorage(output_block, a.shape, output_dtype, a._device) - - @classmethod - def agg_unary_op( - cls, - a: "DStorage", - mapper: Callable[[LStorage], LStorage], - reducer, - post_func, - output_dtype=None, - ): - if output_dtype is None: - output_dtype = a._dtype - output_block = a.blocks.mapValues(mapper) - if reducer is not None: - output_block = output_block.reduce(reducer) - - if post_func is not None: - output_block = post_func(output_block) - return output_block - else: - return DStorage(output_block, a.shape, output_dtype, a._device) - - @classmethod - def elemwise_binary_op( - cls, - a: "DStorage", - b: "DStorage", - binary_mapper: Callable[[LStorage, LStorage], LStorage], - output_dtype=None, - ): - def _apply_transpose(func, lf, rf): - def _wrap(lblk, rblk): - if lf: - lblk = lblk.transpose() - if rf: - rblk = rblk.transpose() - return func(lblk, rblk) - - return _wrap - - binary_mapper = _apply_transpose(binary_mapper, a.transposed, b.transposed) - output_blocks = a.blocks.join(b.blocks, binary_mapper) - if output_dtype is None: - output_dtype = a._dtype - return DStorage(output_blocks, a.shape, output_dtype, a._device) - - @classmethod - def elemwise_bc_op( - cls, - a: "DStorage", - b: "DStorage", - func: Callable[[LStorage, LStorage], LStorage], - output_dtype=None, - shape=None, - **kwargs, - ): - def _apply_transpose(func, lf, rf): - def _wrap(lblk, rblk): - if lf: - lblk = lblk.transpose() - if rf: - rblk = rblk.transpose() - return func(lblk, rblk) - - return _wrap - - if isinstance(a, DStorage) and not isinstance(b, DStorage): - func = _apply_transpose(func, a.transposed, False) - output_blocks = a.blocks.mapValues(lambda x: func(x, b, **kwargs)) - elif isinstance(b, DStorage) and not isinstance(a, DStorage): - func = _apply_transpose(func, False, b.transposed) - output_blocks = b.blocks.mapValues(lambda x: func(a, x, **kwargs)) - else: - raise RuntimeError("exactly one DStorage required") - if output_dtype is None: - output_dtype = a._dtype - if shape is None: - shape = a.shape - return DStorage(output_blocks, shape, output_dtype, a._device) - - def local_ops_helper(self): - from 
..storage._helper import local_ops_helper - - return local_ops_helper(self.device, self.dtype) diff --git a/python/fate/arch/tensor/types/_dtype.py b/python/fate/arch/tensor/types/_dtype.py deleted file mode 100644 index 0e3fa441fe..0000000000 --- a/python/fate/arch/tensor/types/_dtype.py +++ /dev/null @@ -1,65 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from enum import Enum - -import torch - - -class dtype(Enum): - def __init__(self, is_floating_point, is_signed, rank) -> None: - self.is_floating_point = is_floating_point - self.is_signed = is_signed - self.rank = rank - - int32 = (False, True, 1) - int64 = (False, True, 2) - float32 = (True, True, 3) - float64 = (True, True, 4) - paillier = (True, True, 5) # partially homomorphic encryption - # - def is_basic(self): - return self == dtype.float32 or self == dtype.float64 or self == dtype.int32 or self == dtype.int64 - - def is_paillier(self): - return self == dtype.paillier - - def type_promoted(self, other: "dtype") -> "dtype": - if self.rank < other.rank: - return other - else: - return self - - def to_torch_dtype(self): - if self == dtype.int32: - return torch.int32 - if self == dtype.int64: - return torch.int64 - if self == dtype.float64: - return torch.float64 - if self == dtype.float32: - return torch.float32 - raise TypeError(f"unsupported type: {self}") - - @classmethod - def from_torch_dtype(cls, t_type): - if t_type == torch.int32: - return dtype.int32 - if t_type == torch.int64: - return dtype.int64 - if t_type == torch.float64: - return dtype.float64 - if t_type == torch.float32: - return dtype.float32 - raise TypeError(f"unsupported type: {t_type}") diff --git a/python/fate/arch/tensor/types/_shape.py b/python/fate/arch/tensor/types/_shape.py deleted file mode 100644 index 8abe0ef673..0000000000 --- a/python/fate/arch/tensor/types/_shape.py +++ /dev/null @@ -1,159 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from functools import reduce -from typing import List, Optional, overload - - -class DAxis: - def __init__(self, axis: int, partitions) -> None: - self.axis = axis - self.partitions = partitions - - def __str__(self) -> str: - return f"DAxis" - - -class Shape: - def __init__(self, size, d_axis: Optional[DAxis] = None) -> None: - if isinstance(size, int): - size = (size,) - self.size = size - self.d_axis = d_axis - - def transpose(self) -> "Shape": - if len(self.size) != 2: - raise RuntimeError(f"transpose of size {self.size} no supported") - size = self.size[::-1] - - if self.d_axis is not None: - d_axis = DAxis(len(self.size) - 1 - self.d_axis.axis, self.d_axis.partitions) - else: - d_axis = None - return Shape(size, d_axis) - - def is_d_axis(self, axis: int): - if self.d_axis is None: - return False - gap = abs(self.d_axis.axis - axis) - return gap == 0 or gap == len(self.size) - - def __len__(self): - return len(self.size) - - def prod(self): - return reduce(lambda x, y: x * y, self.size) - - def __str__(self) -> str: - return f"Shape" - - def __repr__(self) -> str: - return self.__str__() - - def slice(self, key): - if isinstance(key, int): - raise NotImplementedError(f"key {key}") - if isinstance(key, list): - if self.d_axis is None: - raise NotImplementedError(f"key {key}") - - @overload - def __getitem__(self, key: int) -> int: - ... - - @overload - def __getitem__(self, key: slice) -> "Shape": - ... - - def __getitem__(self, key): - if isinstance(key, int): - if -len(self.size) + 1 < key < len(self.size): - return self.size[key] - else: - raise ValueError("out of bound") - elif isinstance(key, slice): - out = self.size[key] - out_d_axis = None - if self.d_axis is not None: - d_axis_mask = [False] * len(self.size) - d_axis_mask[self.d_axis.axis] = True - out_d_axis = None - for i, v in enumerate(d_axis_mask[key]): - if v: - out_d_axis = DAxis(i, self.d_axis.partitions) - return Shape(out, out_d_axis) - else: - raise NotImplementedError(f"key type {type(key)}") - - @classmethod - def broadcast_shape(cls, shapes: List["Shape"], raise_exception=True): - max_len = 0 - for shape in shapes: - if isinstance(shape.size, int): - if max_len < 1: - max_len = 1 - elif isinstance(shape.size, tuple) or isinstance(shape.size, list): - s = len(shape.size) - if max_len < s: - max_len = s - result = [1] * max_len - d_axis = None - shapes = [Shape((s.size,), s.d_axis) if isinstance(s.size, int) else s for s in shapes] - for shape in shapes: - if isinstance(shape.size, tuple) or isinstance(shape.size, list): - if shape.d_axis is not None: - aligned_d_axis = DAxis(max_len - len(shape.size) + shape.d_axis.axis, shape.d_axis.partitions) - if d_axis is None: - d_axis = aligned_d_axis - elif d_axis.axis != aligned_d_axis.axis: - if raise_exception: - raise RuntimeError("d_axis mismatch: d_axis should be equal after shape broadcast") - else: - return None - for i in range(-1, -1 - len(shape.size), -1): - if shape.size[i] < 0: - if raise_exception: - raise RuntimeError( - "Trying to create tensor with negative dimension ({}): ({})".format( - shape.size[i], shape.size[i] - ) - ) - else: - return None - if shape.size[i] == 1 or shape.size[i] == result[i]: - continue - if result[i] != 1: - if raise_exception: - raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape") - else: - return None - result[i] = shape.size[i] - else: - if raise_exception: - raise RuntimeError( - "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", - shape, - ) - else: - return None - 
# check d_axis - # TODO: we may split local tensor into parts and distributed in future - if d_axis is not None: - for shape in shapes: - if shape.d_axis is not None: - continue - p = d_axis.axis - (len(result) - len(shape.size)) - if p >= 0 and shape.size[p] != 1: - raise RuntimeError("Can't broadcast along distributed axis for Local Storage ") - return Shape(result, d_axis) diff --git a/python/fate/arch/unify/__init__.py b/python/fate/arch/unify/__init__.py index ff9b292c11..5d41ad0872 100644 --- a/python/fate/arch/unify/__init__.py +++ b/python/fate/arch/unify/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from ._infra_def import Backend, device -from ._io import URI, EggrollURI, FileURI, HdfsURI, HttpsURI, HttpURI +from ._io import URI from ._uuid import generate_computing_uuid, uuid __all__ = [ @@ -22,9 +22,4 @@ "device", "uuid", "URI", - "EggrollURI", - "FileURI", - "HdfsURI", - "HttpURI", - "HttpsURI", ] diff --git a/python/fate/arch/unify/_infra_def.py b/python/fate/arch/unify/_infra_def.py index adbd69b987..d95fd3f602 100644 --- a/python/fate/arch/unify/_infra_def.py +++ b/python/fate/arch/unify/_infra_def.py @@ -23,6 +23,25 @@ def __init__(self, type: str, index) -> None: CPU = ("CPU", 1) CUDA = ("CUDA", 2) + @classmethod + def from_torch_device(cls, tensor_device): + if tensor_device.type == "cpu": + return device.CPU + elif tensor_device.type == "cuda": + return device.CUDA + else: + raise ValueError(f"device type {tensor_device.type} not supported") + + def to_torch_device(self): + import torch + + if self.type == "CPU": + return torch.device("cpu") + elif self.type == "CUDA": + return torch.device("cuda", self.index) + else: + raise ValueError(f"device type {self.type} not supported") + class Backend(Enum): STANDALONE = "STANDALONE" diff --git a/python/fate/arch/unify/_io.py b/python/fate/arch/unify/_io.py index 155aaa3bb9..811a7ada49 100644 --- a/python/fate/arch/unify/_io.py +++ b/python/fate/arch/unify/_io.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import hashlib import re -from abc import ABCMeta -from dataclasses import dataclass -from typing import Optional +from typing import List, Optional # see https://www.rfc-editor.org/rfc/rfc3986#appendix-B # scheme = $2 @@ -27,158 +24,55 @@ _uri_regex = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") -@dataclass class URI: - schema: str - path: str - query: Optional[str] = None - fragment: Optional[str] = None - authority: Optional[str] = None + def __init__( + self, + scheme: str, + path: str, + query: Optional[str] = None, + fragment: Optional[str] = None, + authority: Optional[str] = None, + original_uri: Optional[str] = None, + ): + self.scheme = scheme + self.path = path + self.query = query + self.fragment = fragment + self.authority = authority + + self.original_uri = original_uri + if self.original_uri is None: + self.original_uri = self.to_string() @classmethod def from_string(cls, uri: str) -> "URI": match = _uri_regex.fullmatch(uri) if match is None: raise ValueError(f"`{uri}` is not valid uri") - _, schema, _, authority, path, _, query, _, fragment = match.groups() - return URI(schema, path, query, fragment, authority) + _, scheme, _, authority, path, _, query, _, fragment = match.groups() + return URI(scheme=scheme, path=path, query=query, fragment=fragment, authority=authority, original_uri=uri) - def to_schema(self): - for cls in ConcrateURI.__subclasses__(): - if cls.schema() == self.schema: - return cls.from_uri(self) - raise NotImplementedError(f"uri schema `{self.schema}` not found") - - -class ConcrateURI(metaclass=ABCMeta): - @classmethod - def schema(cls) -> str: - ... - - @classmethod - def from_uri(cls, uri: URI) -> "ConcrateURI": - ... - - def create_file(self, name): - ... - - def to_string(self): - ... 
- - -_EGGROLL_NAME_MAX_SIZE = 128 - - -@dataclass -class FileURI(ConcrateURI): - path: str - - @classmethod - def schema(cls): - return "file" - - @classmethod - def from_uri(cls, uri: URI): - return FileURI(uri.path) - - def create_file(self, name): - return FileURI(f"{self.path}/{name}") - - def to_string(self): - return f"file://{self.path}" - - -@dataclass -class EggrollURI(ConcrateURI): - namespace: str - name: str - - @classmethod - def schema(cls): - return "eggroll" - - @classmethod - def from_uri(cls, uri: URI): - _, namespace, *names = uri.path.split("/") - name = "_".join(names) - if len(name) > _EGGROLL_NAME_MAX_SIZE: - name = hashlib.md5(name.encode(encoding="utf8")).hexdigest()[:_EGGROLL_NAME_MAX_SIZE] - return EggrollURI(namespace, name) - - def create_file(self, name): - name = f"{self.name}_{name}" - if len(name) > _EGGROLL_NAME_MAX_SIZE: - name = hashlib.md5(name.encode(encoding="utf8")).hexdigest()[:_EGGROLL_NAME_MAX_SIZE] - return EggrollURI(namespace=self.namespace, name=name) - - def to_string(self): - return f"eggroll:///{self.namespace}/{self.name}" - - -@dataclass -class HdfsURI(ConcrateURI): - path: str - authority: Optional[str] = None - - @classmethod - def schema(cls): - return "hdfs" - - @classmethod - def from_uri(cls, uri: URI): - return HdfsURI(uri.path, uri.authority) - - def create_file(self, name): - return HdfsURI(path=f"{self.path}/{name}", authority=self.authority) - - def to_string(self): - if self.authority: - return f"hdfs://{self.authority}{self.path}" - else: - return f"hdfs://{self.path}" - - -@dataclass -class HttpURI(ConcrateURI): - path: str - authority: Optional[str] = None - - @classmethod - def schema(cls): - return "http" - - @classmethod - def from_uri(cls, uri: URI): - return HttpURI(uri.path, uri.authority) - - def create_file(self, name): - return HttpURI(path=f"{self.path}/{name}", authority=self.authority) - - def to_string(self): - if self.authority: - return f"http://{self.authority}{self.path}" - else: - return f"http://{self.path}" - - -@dataclass -class HttpsURI(ConcrateURI): - path: str - authority: Optional[str] = None - - @classmethod - def schema(cls): - return "https" - - @classmethod - def from_uri(cls, uri: URI): - return HttpsURI(uri.path, uri.authority) - - def create_file(self, name): - return HttpURI(path=f"{self.path}/{name}", authority=self.authority) - - def to_string(self): + def to_string(self) -> str: + uri = "" + if self.scheme: + uri += f"{self.scheme}:" if self.authority: - return f"https://{self.authority}{self.path}" - else: - return f"https://{self.path}" + uri += f"//{self.authority}" + elif self.scheme: + uri += f"//" + uri += self.path + if self.query: + uri += f"?{self.query}" + if self.fragment: + uri += f"#{self.fragment}" + return uri + + def __str__(self): + return self.to_string() + + def __repr__(self): + return self.to_string() + + def path_splits(self) -> List[str]: + parts = self.path.split("/") + return parts diff --git a/python/fate/components/__init__.py b/python/fate/components/__init__.py index c8bfccd1cd..e69de29bb2 100644 --- a/python/fate/components/__init__.py +++ b/python/fate/components/__init__.py @@ -1,224 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, List, Literal, Optional, Type, TypeVar - -from typing_extensions import Annotated - -T_ROLE = Literal["guest", "host", "arbiter"] -T_STAGE = Literal["train", "predict", "default"] -T_LABEL = Literal["trainable"] - - -class Role: - def __init__(self, name: T_ROLE) -> None: - self.name: T_ROLE = name - - @property - def is_guest(self) -> bool: - return self.name == "guest" - - @property - def is_host(self) -> bool: - return self.name == "host" - - @property - def is_arbiter(self) -> bool: - return self.name == "arbiter" - - -GUEST = Role("guest") -HOST = Role("host") -ARBITER = Role("arbiter") - -T_ROLE = Literal["guest", "host", "arbiter"] -T_STAGE = Literal["train", "predict", "default"] -T_LABEL = Literal["trainable"] - - -class Stage: - def __init__(self, name: str) -> None: - self.name = name - - @property - def is_train(self): - return self.name == "train" - - @property - def is_predict(self): - return self.name == "predict" - - @property - def is_default(self): - return self.name == "default" - - -TRAIN = Stage("train") -PREDICT = Stage("predict") -DEFAULT = Stage("default") - - -class LABELS: - TRAINABLE = "trainable" - - -class OutputAnnotated: - ... - - -class InputAnnotated: - ... - - -T = TypeVar("T") -Output = Annotated[T, OutputAnnotated] -Input = Annotated[T, InputAnnotated] - - -class Artifact: - type: str = "artifact" - """Represents a generic machine learning artifact. - - This class and all artifact classes - store the name, uri, and metadata for a machine learning artifact. - Use this artifact type when an artifact - does not fit into another more specific artifact type (e.g., ``Model``, ``Dataset``). - - Args: - name: Name of the artifact. - uri: The artifact's location on disk or cloud storage. - metadata: Arbitrary key-value pairs about the artifact. - """ - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ) -> None: - """Initializes the Artifact with the given name, URI and metadata.""" - self.uri = uri or "" - self.name = name or "" - self.metadata = metadata or {} - - def __str__(self) -> str: - return f"<{type(self).__name__} {dict(name=self.name, uri=self.uri, metadata=self.metadata)}>" - - def __repr__(self) -> str: - return self.__str__() - - -class Artifacts: - type: str - artifact_type: Type[Artifact] - - def __init__(self, artifacts: List[Artifact]) -> None: - self.artifacts = artifacts - - def __str__(self) -> str: - return f"<{type(self).__name__} {self.artifacts}>" - - def __repr__(self) -> str: - return self.__str__() - - -class DatasetArtifact(Artifact): - type = "dataset" - """An artifact representing a machine learning dataset. - - Args: - name: Name of the dataset. - uri: The dataset's location on disk or cloud storage. - metadata: Arbitrary key-value pairs about the dataset. 
- """ - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ) -> None: - super().__init__(uri=uri, name=name, metadata=metadata) - - -class DatasetArtifacts(Artifacts): - type = "datasets" - artifact_type: Type[Artifact] = DatasetArtifact - - -class ModelArtifact(Artifact): - type = "model" - """An artifact representing a machine learning model. - - Args: - name: Name of the model. - uri: The model's location on disk or cloud storage. - metadata: Arbitrary key-value pairs about the model. - """ - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ) -> None: - super().__init__(uri=uri, name=name, metadata=metadata) - - -class ModelArtifacts(Artifacts): - type = "models" - - -class MetricArtifact(Artifact): - type = "metric" - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ) -> None: - super().__init__(uri=uri, name=name, metadata=metadata) - - -class LossMetrics(MetricArtifact): - type = "loss" - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ) -> None: - super().__init__(uri=uri, name=name, metadata=metadata) - - -class ClassificationMetrics(MetricArtifact): - """An artifact for storing classification metrics. - - Args: - name: Name of the metrics artifact. - uri: The metrics artifact's location on disk or cloud storage. - metadata: The key-value scalar metrics. - """ - - type = "classification_metrics" - - def __init__( - self, - name: Optional[str] = None, - uri: Optional[str] = None, - metadata: Optional[Dict] = None, - ): - super().__init__(uri=uri, name=name, metadata=metadata) diff --git a/python/fate/components/__main__.py b/python/fate/components/__main__.py index bb0c6df837..f093478b68 100644 --- a/python/fate/components/__main__.py +++ b/python/fate/components/__main__.py @@ -15,16 +15,15 @@ # """ -execute with python -m fate.components --execution_id xxx --config xxx +execute with python -m fate.components --config xxx """ - if __name__ == "__main__": import click - from fate.components.entrypoint.clean_cli import clean - from fate.components.entrypoint.component_cli import component + from fate.components.entrypoint.cli.component.__main__ import component + from fate.components.entrypoint.cli.test.__main__ import test cli = click.Group() cli.add_command(component) - cli.add_command(clean) + cli.add_command(test) cli(prog_name="python -m fate.components") diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index d8f7b64466..a33a031132 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -12,10 +12,179 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .evaluation import evaluation -from .feature_scale import feature_scale -from .hetero_lr import hetero_lr -from .intersection import intersection -from .reader import reader -BUILDIN_COMPONENTS = [hetero_lr, reader, feature_scale, intersection, evaluation] +import typing +from typing import List + +if typing.TYPE_CHECKING: + from fate.components.core import Component + + +class _ComponentDecorator: + def __init__(self): + self._component_map = {} + + def __call__(self, func): + self._component_map[func.__name__] = func + return func + + def __getitem__(self, item): + return self._component_map[item] + + def __contains__(self, item): + return item in self._component_map + + def __iter__(self): + return iter(self._component_map) + + +_lazy_cpn = _ComponentDecorator() + + +class LazyBuildInComponentsLoader: + @_lazy_cpn + def feature_scale(self): + from .feature_scale import feature_scale + + return feature_scale + + @_lazy_cpn + def reader(self): + from .reader import reader + + return reader + + @_lazy_cpn + def coordinated_lr(self): + from .coordinated_lr import coordinated_lr + + return coordinated_lr + + @_lazy_cpn + def coordinated_linr(self): + from .coordinated_linr import coordinated_linr + + return coordinated_linr + + @_lazy_cpn + def homo_nn(self): + from .homo_nn import homo_nn + + return homo_nn + + @_lazy_cpn + def homo_lr(self): + from .homo_lr import homo_lr + + return homo_lr + + @_lazy_cpn + def hetero_sbt(self): + from .hetero_sbt import hetero_sbt + + return hetero_sbt + + @_lazy_cpn + def dataframe_transformer(self): + from .dataframe_transformer import dataframe_transformer + + return dataframe_transformer + + @_lazy_cpn + def psi(self): + from .psi import psi + + return psi + + @_lazy_cpn + def evaluation(self): + from .evaluation import evaluation + + return evaluation + + @_lazy_cpn + def artifact_test(self): + from .artifact_test import artifact_test + + return artifact_test + + @_lazy_cpn + def statistics(self): + from .statistics import statistics + + return statistics + + @_lazy_cpn + def hetero_feature_binning(self): + from .hetero_feature_binning import hetero_feature_binning + + return hetero_feature_binning + + @_lazy_cpn + def hetero_feature_selection(self): + from .hetero_feature_selection import hetero_feature_selection + + return hetero_feature_selection + + @_lazy_cpn + def union(self): + from .union import union + + return union + + @_lazy_cpn + def sample(self): + from .sample import sample + + return sample + + @_lazy_cpn + def data_split(self): + from .data_split import data_split + + return data_split + + @_lazy_cpn + def toy_example(self): + from .toy_example import toy_example + + return toy_example + + @_lazy_cpn + def dataframe_io_test(self): + from .dataframe_io_test import dataframe_io_test + + return dataframe_io_test + + @_lazy_cpn + def multi_model_test(self): + from .multi_model_test import multi_model_test + + return multi_model_test + + @_lazy_cpn + def cv_test2(self): + from .cross_validation_test import cv_test + + return cv_test + + @classmethod + def contains(cls, cpn_name: str): + return cpn_name in _lazy_cpn + + @classmethod + def list(cls) -> List[str]: + return list(_lazy_cpn) + + def load_cpn(self, cpn_name: str) -> "Component": + if self.contains(cpn_name): + cpn = _lazy_cpn[cpn_name](self) + if cpn.name != cpn_name: + # TODO: add warning + # logger.warning(f"Component name {cpn_name} is not consistent with the name of the component class.") + # the cpn name updated by the lazy decorator, treat it as a reexport component 
+ cpn.name = cpn_name + return cpn + + else: + raise ValueError(f"Component {cpn_name} does not exist.") diff --git a/python/fate/components/components/artifact_test.py b/python/fate/components/components/artifact_test.py new file mode 100644 index 0000000000..b7954ed056 --- /dev/null +++ b/python/fate/components/components/artifact_test.py @@ -0,0 +1,86 @@ +import typing + +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params + +if typing.TYPE_CHECKING: + from fate.arch import Context + + +@cpn.component(roles=[GUEST, HOST, ARBITER]) +def artifact_test( + ctx: "Context", + role: Role, + parameter: cpn.parameter(type=params.string_choice(["a", "b"]), desc="parameter", optional=False), + # mix_input: cpn.dataframe_input(roles=[GUEST, HOST]) | cpn.data_directory_input(), + mix_input: cpn.union(cpn.dataframe_input, cpn.data_directory_input)(roles=[GUEST, HOST]), + dataframe_inputs: cpn.dataframe_inputs(roles=[GUEST, HOST]), + # dataframe_input: cpn.dataframe_input(roles=[GUEST, HOST]), + dataset_inputs: cpn.data_directory_inputs(roles=[GUEST, HOST]), + dataset_input: cpn.data_directory_input(roles=[GUEST, HOST]), + table_input: cpn.table_input(roles=[GUEST, HOST]), + table_inputs: cpn.table_inputs(roles=[GUEST, HOST]), + json_model_input: cpn.json_model_input(roles=[HOST]), + # dataframe_outputs: cpn.dataframe_outputs(roles=[GUEST, HOST]), + # dataframe_output: cpn.dataframe_output(roles=[GUEST, HOST]), + dataset_outputs: cpn.data_directory_outputs(roles=[GUEST, HOST]), + dataset_output: cpn.data_directory_output(roles=[GUEST, HOST]), + json_model_output: cpn.json_model_output(roles=[GUEST, HOST]), + json_model_outputs: cpn.json_model_outputs(roles=[GUEST, HOST]), + model_directory_output: cpn.model_directory_output(roles=[GUEST, HOST]), +): + # print("dataframe_input", dataframe_input) + # print("dataset_inputs", dataset_inputs) + # print("dataset_input", dataset_input) + # print("table_input", table_input) + # print("table_inputs", table_inputs) + # print("json_model_input", json_model_input) + # + # print("dataframe_outputs", dataframe_outputs) + # dataframe_outputs_0 = next(dataframe_outputs) + # dataframe_outputs_1 = next(dataframe_outputs) + # print(" dataframe_outputs_0", dataframe_outputs_0) + # print(" dataframe_outputs_1", dataframe_outputs_1) + # + print("dataset_outputs", dataset_outputs) + dataset_outputs_0 = next(dataset_outputs) + dataset_outputs_1 = next(dataset_outputs) + print(" dataset_outputs_0", dataset_outputs_0) + print(" dataset_outputs_1", dataset_outputs_1) + # + # print("dataframe_output", dataframe_output) + # dataframe_output.write(dataframe_input.read(), name="myname", namespace="mynamespace") + # print("dataset_output", dataset_output) + # + next(json_model_outputs).write({"aaa": 1}, metadata={"bbb": 2}) + next(json_model_outputs).write({"aaa": 2}, metadata={"bbb": 2}) + + json_model_output.write({"aaa": 1}, metadata={"bbb": 2}) + + model_directory_output.get_directory() + + # output_path = model_directory_output.get_directory() + # with open(output_path + "/a.txt", "w") as fw: + # fw.write("a") + # model_directory_output.write_metadata({"model_directory_output_metadata": 1}) + + # ctx.metrics.log_accuracy("s", 1.0, 0) + # for i, sub_ctx in ctx.ctxs_range(10): + # sub_ctx.metrics.log_accuracy("sub", 1.0) + # print(ctx.metrics.handler._metrics) + # ctx.sub_ctx("ssss").metrics.log_loss("loss", 1.0, 0) + # + # ctx.metrics.log_metrics( + # [1, 2, 3, 4], + # "summary", + # "summary", + # ) + # # print("dataframe_inputs", dataframe_inputs) + # + # 
ctx.metrics.log_accuracy("aaa", 1, 0) + ctx.metrics.log_loss("loss_aa", [1.0, 2.0], 0) + ctx.metrics.log_metrics(values=[1, 2, 3], name="metric_single", type="custom", metadata={"bbb": 2}) + ctx.sub_ctx("sub_metric").metrics.log_loss("loss", 1.0, 0) + ctx.sub_ctx("sub_metric").metrics.log_loss("loss", 0.9, 1) + + for i, auto_step_sub_ctx in ctx.ctxs_range(10): + auto_step_sub_ctx.metrics.log_accuracy("sub", 1.0) diff --git a/python/fate/components/components/coordinated_linr.py b/python/fate/components/components/coordinated_linr.py new file mode 100644 index 0000000000..81805cb35d --- /dev/null +++ b/python/fate/components/components/coordinated_linr.py @@ -0,0 +1,352 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.ml.glm import CoordinatedLinRModuleArbiter, CoordinatedLinRModuleGuest, CoordinatedLinRModuleHost + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST, ARBITER], provider="fate") +def coordinated_linr(ctx, role): + ... 
+ + +@coordinated_linr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", + scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler"), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, + desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + optimizer: cpn.parameter(type=params.optimizer_param(), + default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0, + optimizer_params={"lr": 1e-2, "weight_decay": 0})), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter(type=params.string_choice(["weight_diff", "diff", "abs"]), default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind="paillier", key_length=1024), + desc="homomorphic encryption param"), + floating_point_precision: cpn.parameter( + type=params.conint(ge=0), + default=23, + desc="floating point precision, "), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), + +): + logger.info(f"enter coordinated linr train") + # temp code start + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + ctx.cipher.set_phe(ctx.device, he_param.dict()) + # temp code end + if role.is_guest: + train_guest( + ctx, train_data, validate_data, train_output_data, output_model, epochs, + batch_size, optimizer, learning_rate_scheduler, init_param, floating_point_precision, + warm_start_model + ) + elif role.is_host: + train_host( + ctx, train_data, validate_data, train_output_data, output_model, epochs, + batch_size, optimizer, learning_rate_scheduler, init_param, floating_point_precision, + warm_start_model + ) + elif role.is_arbiter: + train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer, learning_rate_scheduler, + output_model, warm_start_model) + + +@coordinated_linr.predict() +def predict( + ctx, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]) +): + if role.is_guest: + predict_guest(ctx, input_model, test_data, test_output_data) + if role.is_host: + predict_host(ctx, input_model, test_data, test_output_data) + + +@coordinated_linr.cross_validation() +def cross_validation( + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + learning_rate_scheduler: cpn.parameter( + type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select 
method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler", + ), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + optimizer: cpn.parameter( + type=params.optimizer_param(), + default=params.OptimizerParam( + method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0} + ), + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + floating_point_precision: cpn.parameter( + type=params.conint(ge=0), + default=23, + desc="floating point precision, "), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind="paillier", key_length=1024), + desc="homomorphic encryption param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["mse"]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), +): + # temp code start + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + ctx.cipher.set_phe(ctx.device, he_param.dict()) + # temp code end + if role.is_arbiter: + i = 0 + for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))): + logger.info(f"enter fold {i}") + module = CoordinatedLinRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + ) + module.fit(fold_ctx) + i += 1 + return + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + if role.is_guest: + module = CoordinatedLinRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + floating_point_precision=floating_point_precision + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + train_predict_df = module.predict(sub_ctx, train_data) + """train_predict_result = transform_to_predict_result( + train_data, predict_score, data_type="train" + )""" + train_predict_result = tools.add_dataset_type(train_predict_df, consts.TRAIN_SET) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + validate_predict_df = module.predict(sub_ctx, validate_data) + validate_predict_result = tools.add_dataset_type(validate_predict_df, consts.VALIDATE_SET) + """validate_predict_result = transform_to_predict_result( + validate_data, 
predict_score, data_type="predict" + ) + """ + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + + # evaluation = evaluate(predicted) + elif role.is_host: + module = CoordinatedLinRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + floating_point_precision=floating_point_precision + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 + + +def train_guest(ctx, train_data, validate_data, train_output_data, output_model, epochs, + batch_size, optimizer_param, learning_rate_param, init_param, floating_point_precision, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLinRModuleGuest.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: + module = CoordinatedLinRModuleGuest(epochs=epochs, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, + init_param=init_param, floating_point_precision=floating_point_precision) + logger.info(f"coordinated linr guest start train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + if validate_data is not None: + validate_data = validate_data.read() + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={"optimizer_param": optimizer_param, + "learning_rate_param": learning_rate_param}) + + sub_ctx = ctx.sub_ctx("predict") + + predict_df = module.predict(sub_ctx, train_data) + """predict_result = transform_to_predict_result(train_data, predict_score, + data_type="train")""" + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + predict_df = module.predict(sub_ctx, validate_data) + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + + """validate_predict_result = transform_to_predict_result(validate_data, predict_score, + data_type="validate") + """ + predict_result = DataFrame.vstack([predict_result, validate_predict_result]) + train_output_data.write(predict_result) + + +def train_host(ctx, train_data, validate_data, train_output_data, output_model, epochs, batch_size, + optimizer_param, learning_rate_param, init_param, floating_point_precision, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLinRModuleHost.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: + module = CoordinatedLinRModuleHost(epochs=epochs, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param, + init_param=init_param, floating_point_precision=floating_point_precision) + logger.info(f"coordinated linr host start train") + sub_ctx = ctx.sub_ctx("train") + + train_data = train_data.read() + if validate_data is not None: + validate_data = validate_data.read() + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={"optimizer_param": optimizer_param, + "learning_rate_param": learning_rate_param}) + + sub_ctx = 
ctx.sub_ctx("predict") + module.predict(sub_ctx, train_data) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + module.predict(sub_ctx, validate_data) + + +def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, + learning_rate_param, output_model, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLinRModuleArbiter.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: + module = CoordinatedLinRModuleArbiter(epochs=epochs, early_stop=early_stop, tol=tol, batch_size=batch_size, + optimizer_param=optimizer_param, learning_rate_param=learning_rate_param) + logger.info(f"coordinated linr arbiter start train") + + sub_ctx = ctx.sub_ctx("train") + module.fit(sub_ctx) + + model = module.get_model() + output_model.write(model, metadata={"optimizer_param": optimizer_param, + "learning_rate_param": learning_rate_param}) + + +def predict_guest(ctx, input_model, test_data, test_output_data): + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + + module = CoordinatedLinRModuleGuest.from_model(model) + test_data = test_data.read() + predict_result = module.predict(sub_ctx, test_data) + predict_result = tools.add_dataset_type(predict_result, consts.TEST_SET) + # predict_result = transform_to_predict_result(test_data, predict_score, data_type="predict") + test_output_data.write(predict_result) + + +def predict_host(ctx, input_model, test_data, test_output_data): + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + module = CoordinatedLinRModuleHost.from_model(model) + test_data = test_data.read() + module.predict(sub_ctx, test_data) + + +"""def transform_to_predict_result(test_data, predict_score, data_type="test"): + df = test_data.create_frame(with_label=True, with_weight=False) + pred_res = test_data.create_frame(with_label=False, with_weight=False) + pred_res["predict_result"] = predict_score + df[["predict_result", "predict_score", "predict_detail", "type"]] = pred_res.apply_row(lambda v: [ + v[0], + v[0], + json.dumps({"label": v[0]}), + data_type], enable_type_align_checking=False) + return df""" diff --git a/python/fate/components/components/coordinated_lr.py b/python/fate/components/components/coordinated_lr.py new file mode 100644 index 0000000000..bb685ff26a --- /dev/null +++ b/python/fate/components/components/coordinated_lr.py @@ -0,0 +1,431 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.ml.glm import CoordinatedLRModuleGuest, CoordinatedLRModuleHost, CoordinatedLRModuleArbiter + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST, ARBITER], provider="fate") +def coordinated_lr(ctx, role): + ... + + +@coordinated_lr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + learning_rate_scheduler: cpn.parameter( + type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler", + ), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + optimizer: cpn.parameter( + type=params.optimizer_param(), + default=params.OptimizerParam( + method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0} + ), + ), + floating_point_precision: cpn.parameter( + type=params.conint(ge=0), + default=23, + desc="floating point precision, "), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind="paillier", key_length=1024), + desc="homomorphic encryption param"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), +): + logger.info(f"enter coordinated lr train") + # temp code start + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + ctx.cipher.set_phe(ctx.device, he_param.dict()) + + if role.is_guest: + train_guest( + ctx, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + optimizer, + learning_rate_scheduler, + init_param, + threshold, + floating_point_precision, + warm_start_model + ) + elif role.is_host: + train_host( + ctx, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + optimizer, + learning_rate_scheduler, + init_param, + floating_point_precision, + warm_start_model + ) + elif role.is_arbiter: + train_arbiter(ctx, + epochs, + early_stop, + tol, batch_size, + optimizer, + learning_rate_scheduler, + output_model, + warm_start_model) + + +@coordinated_lr.predict() +def predict( + ctx, + role: Role, + # threshold: 
cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5), + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): + if role.is_guest: + predict_guest(ctx, input_model, test_data, test_output_data) + if role.is_host: + predict_host(ctx, input_model, test_data, test_output_data) + + +@coordinated_lr.cross_validation() +def cross_validation( + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + learning_rate_scheduler: cpn.parameter( + type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="linear", scheduler_params={"start_factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler", + ), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + optimizer: cpn.parameter( + type=params.optimizer_param(), + default=params.OptimizerParam( + method="sgd", penalty="l2", alpha=1.0, optimizer_params={"lr": 1e-2, "weight_decay": 0} + ), + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind="paillier", key_length=1024), + desc="homomorphic encryption param"), + floating_point_precision: cpn.parameter( + type=params.conint(ge=0), + default=23, + desc="floating point precision, "), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["auc"]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), +): + optimizer = optimizer.dict() + learning_rate_scheduler = learning_rate_scheduler.dict() + init_param = init_param.dict() + ctx.cipher.set_phe(ctx.device, he_param.dict()) + + if role.is_arbiter: + i = 0 + for fold_ctx, _ in ctx.on_cross_validations.ctxs_zip(zip(range(cv_param.n_splits))): + logger.info(f"enter fold {i}") + module = CoordinatedLRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + ) + module.fit(fold_ctx) + i += 1 + return + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + if role.is_guest: + module = 
CoordinatedLRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + threshold=threshold, + floating_point_precision=floating_point_precision, + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_df = module.predict(sub_ctx, train_data) + """train_predict_result = transform_to_predict_result( + train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="train" + )""" + train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_df = module.predict(sub_ctx, validate_data) + """validate_predict_result = transform_to_predict_result( + validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="predict" + )""" + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + + # evaluation = evaluate(predicted) + elif role.is_host: + module = CoordinatedLRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer, + learning_rate_param=learning_rate_scheduler, + init_param=init_param, + floating_point_precision=floating_point_precision + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 + + +def train_guest( + ctx, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + optimizer_param, + learning_rate_param, + init_param, + threshold, + floating_point_precision, + input_model +): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleGuest.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + + else: + module = CoordinatedLRModuleGuest( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_param, + init_param=init_param, + threshold=threshold, + floating_point_precision=floating_point_precision + ) + # optimizer = optimizer_factory(optimizer_param) + logger.info(f"coordinated lr guest start train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + + if validate_data is not None: + logger.info(f"validate data provided") + validate_data = validate_data.read() + + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={}) + + sub_ctx = ctx.sub_ctx("predict") + + predict_df = module.predict(sub_ctx, train_data) + """predict_result = transform_to_predict_result( + train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="train" + )""" + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + predict_df = module.predict(sub_ctx, validate_data) + """validate_predict_result = transform_to_predict_result( + validate_data, + predict_score, + module.labels, + threshold=module.threshold, + is_ovr=module.ovr, + data_type="validate", + )""" + validate_predict_result = 
tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([predict_result, validate_predict_result]) + train_output_data.write(predict_result) + + +def train_host( + ctx, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + optimizer_param, + learning_rate_param, + init_param, + floating_point_precision, + input_model +): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleHost.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: + module = CoordinatedLRModuleHost( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_param, + init_param=init_param, + floating_point_precision=floating_point_precision + ) + logger.info(f"coordinated lr host start train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + + if validate_data is not None: + logger.info(f"validate data provided") + validate_data = validate_data.read() + + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={}) + sub_ctx = ctx.sub_ctx("predict") + module.predict(sub_ctx, train_data) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + module.predict(sub_ctx, validate_data) + + +def train_arbiter(ctx, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_scheduler, + output_model, input_model): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = CoordinatedLRModuleArbiter.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + else: + module = CoordinatedLRModuleArbiter( + epochs=epochs, + early_stop=early_stop, + tol=tol, + batch_size=batch_size, + optimizer_param=optimizer_param, + learning_rate_param=learning_rate_scheduler, + ) + logger.info(f"coordinated lr arbiter start train") + sub_ctx = ctx.sub_ctx("train") + module.fit(sub_ctx) + model = module.get_model() + output_model.write(model, metadata={}) + + +def predict_guest(ctx, input_model, test_data, test_output_data): + logger.info(f"coordinated lr guest start predict") + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + module = CoordinatedLRModuleGuest.from_model(model) + # if module.threshold != 0.5: + # module.threshold = threshold + test_data = test_data.read() + predict_df = module.predict(sub_ctx, test_data) + """predict_result = transform_to_predict_result( + test_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, data_type="test" + )""" + predict_result = tools.add_dataset_type(predict_df, consts.TEST_SET) + test_output_data.write(predict_result) + + +def predict_host(ctx, input_model, test_data, test_output_data): + logger.info(f"coordinated lr host start predict") + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + module = CoordinatedLRModuleHost.from_model(model) + test_data = test_data.read() + module.predict(sub_ctx, test_data) diff --git a/python/fate/components/components/cross_validation_test.py b/python/fate/components/components/cross_validation_test.py new file mode 100644 index 0000000000..032d9001e2 --- /dev/null +++ b/python/fate/components/components/cross_validation_test.py @@ -0,0 +1,67 @@ +import logging +import typing + +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params + +if typing.TYPE_CHECKING: + from 
fate.arch import Context + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST, ARBITER]) +def cv_test(ctx, role): + ... + + +@cv_test.cross_validation() +def cv( + ctx: "Context", + role: Role, + num_fold: cpn.parameter(type=params.conint(gt=1), desc="parameter", optional=False), + dataframe_input: cpn.dataframe_input(roles=[GUEST, HOST]), + json_model_outputs: cpn.json_model_outputs(roles=[GUEST, HOST]), + dataframe_outputs: cpn.dataframe_outputs(roles=[GUEST, HOST]), +): + # split data + for cv_ctx, (train_data, validata_data) in ctx.on_cross_validations.ctxs_zip( + split_data(dataframe_input.read(), num_fold) + ): + # train model + model = FakeTrainer() + model.fit(cv_ctx, train_data=train_data) + # predict model + predict_result = model.predict(cv_ctx, validata_data=validata_data) + # evaluate model + evaluation_result = fake_evaluation(cv_ctx, predict_result=predict_result) + next(json_model_outputs).write(data=model.get_model(), metadata=evaluation_result) + cv_ctx.metrics.log_auc("fold_auc", evaluation_result["auc"]) + cv_ctx.metrics.log_roc("fold_roc", [0.1, 0.2, 0.3, 0.4, 0.5]) + next(dataframe_outputs).write(df=predict_result) + + +def split_data(dataframe_input, num_fold): + """fake split data""" + for i in range(num_fold): + yield dataframe_input, dataframe_input + + +class FakeTrainer: + def __init__(self, **kwargs): + self.model = { + "data": "fake_model", + } + + def fit(self, ctx, train_data, **kwargs): + for i, sub_ctx in ctx.on_iterations.ctxs_range(5): + sub_ctx.metrics.log_auc("auc", i / 5) + + def predict(self, ctx, validata_data, **kwargs): + return validata_data + + def get_model(self): + return self.model + + +def fake_evaluation(ctx, **kwargs): + return {"auc": 0.9} diff --git a/python/fate/components/components/data_split.py b/python/fate/components/components/data_split.py new file mode 100644 index 0000000000..37ab025990 --- /dev/null +++ b/python/fate/components/components/data_split.py @@ -0,0 +1,79 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
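The cv_test component above exercises the cross-validation loop with fake stand-ins. A minimal self-contained sketch of that per-fold pattern (not part of this diff, all names hypothetical):

def split_data(dataset, num_fold):
    # fake splitter, mirrors the test component: every fold reuses the full data
    for _ in range(num_fold):
        yield dataset, dataset

def run_cv(dataset, num_fold):
    fold_models, fold_metrics = [], []
    for fold, (train_part, validate_part) in enumerate(split_data(dataset, num_fold)):
        model = {"data": "fake_model", "fold": fold}   # stand-in for FakeTrainer
        predictions = validate_part                    # fake predict: identity
        fold_models.append(model)
        fold_metrics.append({"auc": 0.9, "n_validate": len(predictions)})  # stand-in for fake_evaluation
    return fold_models, fold_metrics

models, metrics = run_cv(list(range(10)), num_fold=3)
assert len(models) == 3 and metrics[0]["auc"] == 0.9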
+ +import logging +from typing import Union + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.model_selection.data_split import DataSplitModuleGuest, DataSplitModuleHost + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def data_split( + ctx: Context, + role: Role, + input_data: cpn.dataframe_input(roles=[GUEST, HOST]), + train_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None, + desc="size of output training data, " + "should be either int for exact sample size or float for fraction"), + validate_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None, + desc="size of output validation data, " + "should be either int for exact sample size or float for fraction"), + test_size: cpn.parameter(type=Union[params.confloat(ge=0.0, le=1.0), params.conint(ge=0)], default=None, + desc="size of output test data, " + "should be either int for exact sample size or float for fraction"), + stratified: cpn.parameter(type=bool, default=False, + desc="whether sample with stratification, " + "should not use this for data with continuous label values"), + random_state: cpn.parameter(type=params.conint(ge=0), default=None, desc="random state"), + hetero_sync: cpn.parameter(type=bool, default=True, + desc="whether guest sync data set sids with host, " + "default True for hetero scenario, " + "should set to False for local and homo scenario"), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), + validate_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), +): + if train_size is None and validate_size is None and test_size is None: + train_size = 0.8 + validate_size = 0.2 + test_size = 0.0 + + # logger.info(f"in cpn received train_size: {train_size}, validate_size: {validate_size}, test_size: {test_size}") + # check if local but federated sample + if hetero_sync and len(ctx.parties.ranks) < 2: + raise ValueError(f"federated sample can only be called when both 'guest' and 'host' present. Please check") + + sub_ctx = ctx.sub_ctx("train") + if role.is_guest: + module = DataSplitModuleGuest(train_size, validate_size, test_size, stratified, random_state, hetero_sync) + elif role.is_host: + module = DataSplitModuleHost(train_size, validate_size, test_size, stratified, random_state, hetero_sync) + input_data = input_data.read() + + train_data_set, validate_data_set, test_data_set = module.fit(sub_ctx, input_data) + # train_data_set, validate_data_set, test_data_set = module.split_data(train_data) + logger.info(f"output train size: {train_data_set.shape if train_data_set else None}, " + f"validate size: {validate_data_set.shape if validate_data_set else None}," + f"test size: {test_data_set.shape if test_data_set else None}") + if train_data_set: + train_output_data.write(train_data_set) + if validate_data_set: + validate_output_data.write(validate_data_set) + if test_data_set: + test_output_data.write(test_data_set) diff --git a/python/fate/components/components/dataframe_io_test.py b/python/fate/components/components/dataframe_io_test.py new file mode 100644 index 0000000000..59e69bbc14 --- /dev/null +++ b/python/fate/components/components/dataframe_io_test.py @@ -0,0 +1,59 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
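The data_split component above defaults to an 80/20/0 split when no sizes are given. A plain-Python sketch of that defaulting and slicing behaviour (not part of this diff, and not the federated DataSplitModule implementation):

def resolve_sizes(train_size=None, validate_size=None, test_size=None):
    # mirrors the component: all-None falls back to 0.8 / 0.2 / 0.0
    if train_size is None and validate_size is None and test_size is None:
        return 0.8, 0.2, 0.0
    return train_size, validate_size, test_size

def toy_split(rows, train_size, validate_size, test_size):
    n = len(rows)
    n_train = int(n * train_size) if isinstance(train_size, float) else (train_size or 0)
    n_val = int(n * validate_size) if isinstance(validate_size, float) else (validate_size or 0)
    return rows[:n_train], rows[n_train:n_train + n_val], rows[n_train + n_val:]

train, validate, test = toy_split(list(range(10)), *resolve_sizes())
assert (len(train), len(validate), len(test)) == (8, 2, 0)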
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from fate.components.core import LOCAL, Role, cpn + + +@cpn.component(roles=[LOCAL]) +def dataframe_io_test( + ctx, + role: Role, + dataframe_input: cpn.dataframe_input(roles=[LOCAL]), + dataframe_output: cpn.dataframe_output(roles=[LOCAL]), + dataframe_inputs: cpn.dataframe_inputs(roles=[LOCAL]), + dataframe_outputs: cpn.dataframe_outputs(roles=[LOCAL]), + json_model_output: cpn.json_model_output(roles=[LOCAL]), + json_model_outputs: cpn.json_model_outputs(roles=[LOCAL]), + model_directory_output: cpn.model_directory_output(roles=[LOCAL]), + model_directory_outputs: cpn.model_directory_outputs(roles=[LOCAL]), +): + df = dataframe_input.read() + df = df + 1 + df_list = [_input.read() for _input in dataframe_inputs] + dataframe_output.write(df) + + assert len(df_list) == 4 + for i in range(10): + output = next(dataframe_outputs) + output.write(df_list[i % 4]) + + json_model_output.write({"aaa": 1}, metadata={"bbb": 2}) + for i in range(8): + model_output = next(json_model_outputs) + model_output.write({"io_model_out": i}, metadata={"i-th output": i}) + + path = model_directory_output.get_directory() + with open(path + "/output.txt", "w") as fw: + fw.write("xxx\n") + + model_directory_output.write_metadata({"model_directory_output_metadata": "test_directory"}) + + for i in range(5): + directory_output = next(model_directory_outputs) + path = directory_output.get_directory() + with open(path + f"/output_{i}.txt", "w") as fw: + fw.write("test for model directory output\n") + + directory_output.write_metadata({f"model_directory_output_{i}_metadata": f"test_directory_{i}"}) diff --git a/python/fate/components/components/dataframe_transformer.py b/python/fate/components/components/dataframe_transformer.py new file mode 100644 index 0000000000..8bafd30fb3 --- /dev/null +++ b/python/fate/components/components/dataframe_transformer.py @@ -0,0 +1,53 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
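dataframe_io_test above relies on the plural output slots (dataframe_outputs, json_model_outputs, ...) behaving like iterators of writable slots. A self-contained sketch of that pattern (not part of this diff; FakeSlot is a hypothetical stand-in):

class FakeSlot:
    def __init__(self, index):
        self.index = index
        self.payload = None

    def write(self, payload, metadata=None):
        self.payload = (payload, metadata)

def fake_outputs():
    index = 0
    while True:  # unbounded supply of output slots
        yield FakeSlot(index)
        index += 1

outputs = fake_outputs()
written = []
for i in range(8):
    slot = next(outputs)  # mirrors: model_output = next(json_model_outputs)
    slot.write({"io_model_out": i}, metadata={"i-th output": i})
    written.append(slot)
assert written[3].payload[0] == {"io_model_out": 3}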
+# +from fate.components.core import LOCAL, Role, cpn + + +@cpn.component(roles=[LOCAL]) +def dataframe_transformer( + ctx, + role: Role, + table: cpn.table_input(roles=[LOCAL]), + dataframe_output: cpn.dataframe_output(roles=[LOCAL]), + namespace: cpn.parameter(type=str, default=None, optional=True), + name: cpn.parameter(type=str, default=None, optional=True), + site_name: cpn.parameter(type=str, default=None, optional=True), +): + from fate.arch.dataframe import TableReader + + table = table.read() + metadata = table.schema + table_reader = TableReader( + sample_id_name=metadata.get("sample_id_name", None), + match_id_name=metadata.get("match_id_name", None), + match_id_list=metadata.get("match_id_list", None), + match_id_range=metadata.get("match_id_range", 1), + label_name=metadata.get("label_name", None), + label_type=metadata.get("label_type", "int32"), + weight_name=metadata.get("weight_name", None), + weight_type=metadata.get("weight_type", "float32"), + header=metadata.get("header", None), + na_values=metadata.get("na_values", None), + dtype=metadata.get("dtype", "float32"), + anonymous_site_name=site_name, + delimiter=metadata.get("delimiter", ","), + input_format=metadata.get("input_format", "dense"), + tag_with_value=metadata.get("tag_with_value", False), + tag_value_delimiter=metadata.get("tag_value_delimiter", ":"), + ) + + df = table_reader.to_frame(ctx, table) + dataframe_output.write(df, name=name, namespace=namespace) diff --git a/python/fate/components/components/evaluation.py b/python/fate/components/components/evaluation.py index 8a2437a861..252e41beb7 100644 --- a/python/fate/components/components/evaluation.py +++ b/python/fate/components/components/evaluation.py @@ -12,34 +12,99 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
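dataframe_transformer above assembles TableReader arguments from the table schema, falling back to per-key defaults. The same lookup pattern in plain Python with a toy schema (not part of this diff; the key set is abbreviated and the site name value is made up):

def reader_kwargs_from_schema(metadata, site_name=None):
    return {
        "sample_id_name": metadata.get("sample_id_name", None),
        "match_id_name": metadata.get("match_id_name", None),
        "label_name": metadata.get("label_name", None),
        "label_type": metadata.get("label_type", "int32"),
        "dtype": metadata.get("dtype", "float32"),
        "delimiter": metadata.get("delimiter", ","),
        "input_format": metadata.get("input_format", "dense"),
        "anonymous_site_name": site_name,
    }

kwargs = reader_kwargs_from_schema({"label_name": "y", "delimiter": ";"}, site_name="guest_9999")
assert kwargs["label_type"] == "int32" and kwargs["delimiter"] == ";"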
-from fate.components import ( - ARBITER, - GUEST, - HOST, - ClassificationMetrics, - DatasetArtifact, - Input, - Output, - Role, - cpn, +import logging +from typing import Dict +from fate.arch import Context +import numpy as np +import pandas as pd +from fate.arch import Context +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn +from fate.components.core.params import string_choice +from fate.ml.evaluation.tool import ( + get_binary_metrics, + get_multi_metrics, + get_regression_metrics, + get_specified_metrics, ) -from fate.ml.evaluation import BinaryEvaluator +from fate.ml.utils.predict_tools import PREDICT_SCORE, PREDICT_RESULT, LABEL +from fate.components.components.utils.consts import BINARY, REGRESSION, MULTI +logger = logging.getLogger(__name__) -@cpn.component(roles=[GUEST, HOST, ARBITER]) -@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[GUEST, HOST, ARBITER]) -@cpn.parameter("eval_type", type=str, default="binary", optional=True) -@cpn.artifact("output_metric", type=Output[ClassificationMetrics], roles=[GUEST, HOST, ARBITER]) -def evaluation(ctx, role: Role, input_data, eval_type, output_metric): - evaluate(ctx, input_data, eval_type, output_metric) +def split_dataframe_by_type(input_df: pd.DataFrame) -> Dict[str, pd.DataFrame]: -def evaluate(ctx, input_data, eval_type, output_metric): - data = ctx.reader(input_data).read_dataframe().data - y_true = data.label.tolist() - y_pred = data.predict_score.values.tolist() + if "type" in input_df.columns: + return {dataset_type: input_df[input_df["type"] == dataset_type] for dataset_type in input_df["type"].unique()} + else: + return {"origin": input_df} - if eval_type == "binary": - ctx.metrics.handler.register_metrics(auc=ctx.writer(output_metric)) - evaluator = BinaryEvaluator() - evaluator.fit(ctx, y_true, y_pred) + +@cpn.component(roles=[GUEST, HOST]) +def evaluation( + ctx: Context, + role: Role, + input_data: cpn.dataframe_inputs(roles=[GUEST, HOST]), + default_eval_setting: cpn.parameter( + type=string_choice(choice=["binary", "multi", "regression"]), default="binary", optional=True + ), + metrics: cpn.parameter(type=list, default=None, optional=True), + predict_column_name: cpn.parameter(type=str, default=None, optional=True, + desc="predict data column name, if None(default), will use \ + 'predict_score' in the input dataframe when the default setting are binary and regression, \ + and use 'predict_result' if default setting is multi"), + label_column_name: cpn.parameter(type=str, default=None, optional=True, desc="label data column namem if None(default), \ + will use 'label' in the input dataframe") +): + + if role.is_arbiter: + return + else: + + if metrics is not None: + metrics_ensemble = get_specified_metrics(metrics) + predict_col = predict_column_name if predict_column_name is not None else PREDICT_SCORE + label_col = label_column_name if label_column_name is not None else LABEL + else: + if default_eval_setting == MULTI: + metrics_ensemble = get_multi_metrics() + predict_col = predict_column_name if predict_column_name is not None else PREDICT_RESULT + label_col = label_column_name if label_column_name is not None else LABEL + else: + if default_eval_setting == BINARY: + metrics_ensemble = get_binary_metrics() + elif default_eval_setting == REGRESSION: + metrics_ensemble = get_regression_metrics() + else: + raise ValueError("default_eval_setting should be one of binary, multi, regression, got {}") + predict_col = predict_column_name if predict_column_name is not None else PREDICT_SCORE + 
label_col = label_column_name if label_column_name is not None else LABEL + + df_list = [_input.read() for _input in input_data] + task_names = [_input.artifact.metadata.source.task_name for _input in input_data] + eval_rs = {} + logger.info('components names are {}'.format(task_names)) + for name, df in zip(task_names, df_list): + rs_ = evaluate(df, metrics_ensemble, predict_col, label_col) + eval_rs[name] = rs_ + + ctx.metrics.log_metrics(eval_rs, name='evaluation', type='evaluation') + logger.info("eval result: {}".format(eval_rs)) + + +def evaluate(input_data, metrics, predict_col, label_col): + + data = input_data.as_pd_df() + split_dict = split_dataframe_by_type(data) + rs_dict = {} + logger.info('eval dataframe is {}'.format(data)) + for name, df in split_dict.items(): + + logger.info('eval dataframe is \n\n{}'.format(df)) + y_true = df[label_col] + # in case is multi result, use tolist + y_pred = df[predict_col] + rs = metrics(predict=y_pred, label=y_true) + rs_dict[name] = rs + + return rs_dict diff --git a/python/fate/components/components/feature_scale.py b/python/fate/components/components/feature_scale.py index 87542a9e29..0dd4fa34c5 100644 --- a/python/fate/components/components/feature_scale.py +++ b/python/fate/components/components/feature_scale.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The FATE Authors. All Rights Reserved. +# Copyright 2023 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,77 +12,160 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from fate.components import ( - GUEST, - HOST, - DatasetArtifact, - Input, - ModelArtifact, - Output, - Role, - cpn, -) - - -@cpn.component(roles=[GUEST, HOST]) + +import logging +from typing import List, Union + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") def feature_scale(ctx, role): ... @feature_scale.train() -@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.parameter("method", type=str, default="standard", optional=False) -@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) def feature_scale_train( - ctx, - role: Role, - train_data, - method, - train_output_data, - output_model, + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + method: cpn.parameter(type=params.string_choice(["standard", "min_max"]), default="standard", optional=False), + feature_range: cpn.parameter( + type=Union[list, dict], + default=[0, 1], + optional=True, + desc="Result feature value range for `min_max` method, " + "take either dict in format: {col_name: [min, max]} for specific columns " + "or [min, max] for all columns. 
Columns unspecified will be scaled to default range [0,1]", + ), + scale_col: cpn.parameter( + type=List[str], + default=None, + optional=True, + desc="list of column names to be scaled, if None, all columns will be scaled; " + "only one of {scale_col, scale_idx} should be specified", + ), + scale_idx: cpn.parameter( + type=List[params.conint(ge=0)], + default=None, + optional=True, + desc="list of column index to be scaled, if None, all columns will be scaled; " + "only one of {scale_col, scale_idx} should be specified", + ), + strict_range: cpn.parameter( + type=bool, + default=True, + desc="whether transformed value to be strictly restricted within given range; " + "effective for 'min_max' scale method only", + ), + use_anonymous: cpn.parameter( + type=bool, default=False, desc="bool, whether interpret `scale_col` as anonymous column names" + ), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST]), ): - train(ctx, train_data, train_output_data, output_model, method) + train( + ctx, + train_data, + train_output_data, + output_model, + method, + feature_range, + scale_col, + scale_idx, + strict_range, + use_anonymous, + ) @feature_scale.predict() -@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST]) -@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST]) -@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def feature_scale_predict( - ctx, - role: Role, - test_data, - input_model, - test_output_data, + ctx: Context, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), ): predict(ctx, input_model, test_data, test_output_data) -def train(ctx, train_data, train_output_data, output_model, method): - from fate.ml.feature_scale import FeatureScale +def train( + ctx, + train_data, + train_output_data, + output_model, + method, + feature_range, + scale_col, + scale_idx, + strict_range, + use_anonymous, +): + logger.info(f"start scale train") + from fate.ml.preprocessing import FeatureScale + + train_data = train_data.read() + + sub_ctx = ctx.sub_ctx("train") + columns = train_data.schema.columns.to_list() + anonymous_columns = None + if use_anonymous: + anonymous_columns = train_data.schema.anonymous_columns.to_list() + if method != "min_max": + feature_range = None + scale_col, feature_range = get_to_scale_cols(columns, anonymous_columns, scale_col, scale_idx, feature_range) - scaler = FeatureScale(method) - with ctx.sub_ctx("train") as sub_ctx: - train_data = sub_ctx.reader(train_data).read_dataframe().data - scaler.fit(sub_ctx, train_data) + scaler = FeatureScale(method, scale_col, feature_range, strict_range) + scaler.fit(sub_ctx, train_data) - model = scaler.to_model() - with output_model as model_writer: - model_writer.write_model("feature_scale", model, metadata={}) + model = scaler.get_model() + output_model.write(model, metadata={}) - with ctx.sub_ctx("predict") as sub_ctx: - output_data = scaler.transform(sub_ctx, train_data) - sub_ctx.writer(train_output_data).write_dataframe(output_data) + sub_ctx = ctx.sub_ctx("predict") + output_data = scaler.transform(sub_ctx, train_data) + train_output_data.write(output_data) def predict(ctx, input_model, test_data, test_output_data): - from fate.ml.feature_scale import FeatureScale - - with ctx.sub_ctx("predict") as sub_ctx: - with input_model as 
model_reader: - model = model_reader.read_model() - scaler = FeatureScale.from_model(model) - test_data = sub_ctx.reader(test_data).read_dataframe().data - output_data = scaler.transform(sub_ctx, test_data) - sub_ctx.writer(test_output_data).write_dataframe(output_data) + logger.info(f"start scale transform") + + from fate.ml.preprocessing import FeatureScale + + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + scaler = FeatureScale.from_model(model) + test_data = test_data.read() + output_data = scaler.transform(sub_ctx, test_data) + test_output_data.write(output_data) + + +def get_to_scale_cols(columns, anonymous_columns, scale_col, scale_idx, feature_range): + if anonymous_columns is not None: + scale_col = [columns[anonymous_columns.index(col)] for col in scale_col] + + if scale_col is not None: + if scale_idx is not None: + raise ValueError(f"`scale_col` and `scale_idx` cannot be specified simultaneously, please check.") + select_col = scale_col + elif scale_idx is not None: + select_col = [columns[i] for i in scale_idx] + else: + select_col = columns + col_set = set(columns) + if not all(col in col_set for col in select_col): + raise ValueError(f"Given scale columns not found in data schema, please check.") + + if feature_range is not None: + if isinstance(feature_range, dict): + for col in select_col: + if col not in feature_range: + feature_range[col] = [0, 1] + else: + feature_range = {col: feature_range for col in select_col} + if len(select_col) == 0: + logger.warning(f"No cols provided. " + f"To scale all columns, please set `scale_col` to None.") + return select_col, feature_range diff --git a/python/fate/components/components/hetero_feature_binning.py b/python/fate/components/components/hetero_feature_binning.py new file mode 100644 index 0000000000..a29f437205 --- /dev/null +++ b/python/fate/components/components/hetero_feature_binning.py @@ -0,0 +1,210 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.feature_binning import HeteroBinningModuleHost, HeteroBinningModuleGuest + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def hetero_feature_binning(ctx, role): + ... 
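A worked example of the column/range resolution performed by get_to_scale_cols above (illustration only, not part of this diff; assumes the function is in scope):

columns = ["x0", "x1", "x2"]
select_col, ranges = get_to_scale_cols(
    columns=columns, anonymous_columns=None,
    scale_col=["x0", "x2"], scale_idx=None, feature_range=[0, 1],
)
assert select_col == ["x0", "x2"]
assert ranges == {"x0": [0, 1], "x2": [0, 1]}  # a flat [min, max] is broadcast to every selected column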
+ + +""" +@cpn.parameter("bins", type=dict, default=[], + desc="dict of format {col_name: [bins]} which specifies bin edges for each feature, " + "including right edge of last bin") +""" + + +@hetero_feature_binning.train() +def feature_binning_train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + method: cpn.parameter(type=params.string_choice(["quantile", "bucket", "manual"]), + default="quantile", optional=False, + desc="binning method, options: {quantile, bucket, manual}"), + n_bins: cpn.parameter(type=params.conint(gt=1), default=10, + desc="max number of bins, should be no less than 2"), + split_pt_dict: cpn.parameter(type=dict, default=None, optional=True, + desc="dict, manually provided split points, " + "only effective when `method`='manual'"), + bin_col: cpn.parameter(type=List[str], default=None, + desc="list of column names to be binned, if None, all columns will be binned; " + "only one of {bin_col, bin_idx} should be specified"), + bin_idx: cpn.parameter(type=List[params.conint(ge=0)], default=None, + desc="list of column index to be binned, if None, all columns will be binned; " + "only one of {bin_col, bin_idx} should be specified"), + category_col: cpn.parameter(type=List[str], default=None, + desc="list of column names to be treated as categorical " + "features and will not be binned; " + "only one of {category_col, category_idx} should be specified" + "note that metrics will be computed over categorical features " + "if this param is specified"), + category_idx: cpn.parameter(type=List[params.conint(ge=0)], default=None, + desc="list of column index to be treated as categorical features " + "and will not be binned; " + "only one of {category_col, category_idx} should be specified" + "note that metrics will be computed over categorical features " + "if this param is specified"), + use_anonymous: cpn.parameter(type=bool, default=False, + desc="bool, whether interpret `bin_col` & `category_col` " + "as anonymous column names"), + transform_method: cpn.parameter(type=params.string_choice(['woe', 'bin_idx']), + default=None, # may support user-provided dict in future release + desc="str, values to which binned data will be transformed, " + "select from {'woe', 'bin_idx'}; " + "note that host will not transform features " + "to woe values regardless of setting"), + skip_metrics: cpn.parameter(type=bool, default=False, + desc="bool, whether compute host's metrics or not"), + local_only: cpn.parameter(type=bool, default=False, desc="bool, whether compute host's metrics or not"), + relative_error: cpn.parameter(type=params.confloat(gt=0, le=1), default=1e-6, + desc="float, error rate for quantile"), + adjustment_factor: cpn.parameter(type=params.confloat(gt=0), default=0.5, + desc="float, useful when here is no event or non-event in a bin"), + he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind="paillier", key_length=1024), + desc="homomorphic encryption param"), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST]), +): + ctx.cipher.set_phe(ctx.device, he_param.dict()) + train(ctx, train_data, train_output_data, output_model, role, method, n_bins, split_pt_dict, + bin_col, bin_idx, category_col, category_idx, use_anonymous, transform_method, + skip_metrics, local_only, relative_error, adjustment_factor) + + +@hetero_feature_binning.predict() +def feature_binning_predict( + ctx: Context, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, 
HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + transform_method: cpn.parameter(type=params.string_choice(['woe', 'bin_idx']), + default=None, # may support user-provided dict in future release + desc="str, values to which binned data will be transformed, " + "select from {'woe', 'bin_idx'}; " + "note that host will not transform features " + "to woe values regardless of setting"), + skip_metrics: cpn.parameter(type=bool, default=False, + desc="bool, whether compute host's metrics or not"), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): + predict(ctx, input_model, test_data, test_output_data, role, transform_method, skip_metrics) + + +def train(ctx, train_data, train_output_data, output_model, role, method, n_bins, split_pt_dict, + bin_col, bin_idx, category_col, category_idx, use_anonymous, transform_method, + skip_metrics, local_only, relative_error, adjustment_factor): + logger.info(f"start binning train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + columns = train_data.schema.columns.to_list() + anonymous_columns = None + if use_anonymous: + anonymous_columns = train_data.schema.anonymous_columns.to_list() + split_pt_dict = {columns[anonymous_columns.index(col)]: split_pt_dict[col] for col in split_pt_dict.keys()} + to_bin_cols, merged_category_col = get_to_bin_cols(columns, anonymous_columns, + bin_col, bin_idx, category_col, category_idx) + if split_pt_dict: + to_bin_cols = list(set(to_bin_cols).intersection(split_pt_dict.keys())) + + if role.is_guest: + binning = HeteroBinningModuleGuest(method, n_bins, split_pt_dict, to_bin_cols, transform_method, + merged_category_col, local_only, relative_error, adjustment_factor) + elif role.is_host: + binning = HeteroBinningModuleHost(method, n_bins, split_pt_dict, to_bin_cols, transform_method, + merged_category_col, local_only, relative_error, adjustment_factor) + else: + raise ValueError(f"unknown role: {role}") + binning.fit(sub_ctx, train_data) + binned_data = None + if not skip_metrics: + binned_data = binning._bin_obj.bucketize_data(train_data) + binning.compute_metrics(sub_ctx, binned_data) + model = binning.get_model() + output_model.write(model) + + sub_ctx = ctx.sub_ctx("predict") + output_data = train_data + if transform_method is not None: + if binned_data is None: + binned_data = binning._bin_obj.bucketize_data(train_data) + output_data = binning.transform(sub_ctx, binned_data) + train_output_data.write(output_data) + + +def predict(ctx, input_model, test_data, test_output_data, role, transform_method, skip_metrics): + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + if role.is_guest: + binning = HeteroBinningModuleGuest.from_model(model) + elif role.is_host: + binning = HeteroBinningModuleHost.from_model(model) + # model_meta = model["meta_data"] + else: + raise ValueError(f"unknown role: {role}") + + binning.set_transform_method(transform_method) + test_data = test_data.read() + if skip_metrics and transform_method is None: + return test_data + binned_data = binning._bin_obj.bucketize_data(test_data) + if not skip_metrics: + binning.compute_metrics(sub_ctx, binned_data) + output_data = test_data + if transform_method is not None: + output_data = binning.transform(sub_ctx, binned_data) + test_output_data.write(output_data) + + +def get_to_bin_cols(columns, anonymous_columns, bin_col, bin_idx, category_col, category_idx): + if anonymous_columns is not None: + if bin_col is not None: + bin_col = [columns[anonymous_columns.index(col)] for col in bin_col] 
+ if category_col is not None: + category_col = [columns[anonymous_columns.index(col)] for col in category_col] + + if bin_col is not None: + if bin_idx is not None: + raise ValueError(f"`bin_col` and `bin_idx` cannot be specified simultaneously, please check.") + select_col = bin_col + elif bin_idx is not None: + select_col = [columns[i] for i in bin_idx] + else: + select_col = columns + col_set = set(columns) + if not all(col in col_set for col in select_col): + raise ValueError(f"Given bin columns not found in data schema, please check.") + + if category_col is not None: + if category_idx is not None: + raise ValueError(f"`category_col` and `category_idx` cannot be specified simultaneously, please check.") + elif category_idx is not None: + category_col = [columns[i] for i in category_idx] + else: + return select_col, [] + if not all(col in col_set for col in category_col): + raise ValueError(f"Given category columns not found in data schema, please check.") + category_col_set = set(category_col) + select_col = [col for col in select_col if col not in category_col_set] + return select_col, category_col diff --git a/python/fate/components/components/hetero_feature_selection.py b/python/fate/components/components/hetero_feature_selection.py new file mode 100644 index 0000000000..16cb430558 --- /dev/null +++ b/python/fate/components/components/hetero_feature_selection.py @@ -0,0 +1,172 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def hetero_feature_selection(ctx, role): + ... 
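A worked example of get_to_bin_cols above (illustration only, not part of this diff), showing how categorical columns are pulled out of the binning set and returned separately:

columns = ["x0", "x1", "x2", "x3"]
to_bin, categorical = get_to_bin_cols(
    columns=columns, anonymous_columns=None,
    bin_col=None, bin_idx=None, category_col=["x3"], category_idx=None,
)
assert to_bin == ["x0", "x1", "x2"]
assert categorical == ["x3"]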
+ + +@hetero_feature_selection.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_models: cpn.json_model_inputs(roles=[GUEST, HOST], optional=True), + method: cpn.parameter( + type=List[params.string_choice(["manual", "iv", "statistics"])], + default=["manual"], + optional=False, + desc="selection method, options: {manual, binning, statistics}", + ), + select_col: cpn.parameter( + type=List[str], + default=None, + desc="list of column names to be selected, if None, all columns will be considered", + ), + iv_param: cpn.parameter( + type=params.iv_filter_param(), + default=params.IVFilterParam( + metrics="iv", + take_high=True, + threshold=1, + filter_type="threshold", + host_thresholds=1, + host_take_high=True, + select_federated=True, + ), + desc="iv filter param", + ), + statistic_param: cpn.parameter( + type=params.statistic_filter_param(), + default=params.StatisticFilterParam(metrics="mean", threshold=1, filter_type="threshold", take_high=True), + desc="statistic filter param", + ), + manual_param: cpn.parameter( + type=params.manual_filter_param(), + default=params.ManualFilterParam(filter_out_col=[], keep_col=[]), + desc="manual filter param", + ), + keep_one: cpn.parameter(type=bool, default=True, desc="whether to keep at least one feature among `select_col`"), + use_anonymous: cpn.parameter( + type=bool, + default=False, + desc="bool, whether interpret `select_col` & `filter_out_col` & `keep_col` " "as anonymous column names", + ), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + train_output_model: cpn.json_model_output(roles=[GUEST, HOST]), +): + from fate.ml.feature_selection import ( + HeteroSelectionModuleGuest, + HeteroSelectionModuleHost, + ) + + logger.info(f"start selection train") + + sub_ctx = ctx.sub_ctx("train") + + train_data = train_data.read() + columns = train_data.schema.columns.to_list() + if use_anonymous: + logger.debug(f"use anonymous columns") + anonymous_columns = train_data.schema.anonymous_columns.to_list() + if select_col is not None: + select_col = [columns[anonymous_columns.index(col)] for col in select_col] + if manual_param.filter_out_col is not None: + filter_out_col = [columns[anonymous_columns.index(col)] for col in manual_param.filter_out_col] + manual_param.filter_out_col = filter_out_col + if manual_param.keep_col is not None: + keep_col = [columns[anonymous_columns.index(col)] for col in manual_param.keep_col] + manual_param.keep_col = keep_col + iv_param = iv_param.dict() + statistic_param = statistic_param.dict() + manual_param = manual_param.dict() + # logger.info(f"input_models: {input_models}, len: {len(input_models)}") + + input_iso_models = [model.read() for model in input_models] if input_models is not None else None + # logger.info(f"read in input_models len: {len(input_iso_models)}; \n read in input models: {input_iso_models}") + if role.is_guest: + selection = HeteroSelectionModuleGuest( + method=method, + select_col=select_col, + input_models=input_iso_models, + iv_param=iv_param, + statistic_param=statistic_param, + manual_param=manual_param, + keep_one=keep_one, + ) + + elif role.is_host: + selection = HeteroSelectionModuleHost( + method=method, + select_col=select_col, + input_models=input_iso_models, + iv_param=iv_param, + statistic_param=statistic_param, + manual_param=manual_param, + keep_one=keep_one, + ) + else: + raise ValueError(f"role: {role} is not valid") + selection.fit(sub_ctx, train_data) + model = selection.get_model() + 
train_output_model.write(model, metadata={}) + + sub_ctx = ctx.sub_ctx("predict") + output_data = train_data + if method is not None: + output_data = selection.transform(sub_ctx, train_data) + # logger.info(f"output_data schema columns: {output_data.schema.columns}; " + # f"anonymous columns: {output_data.schema.anonymous_columns}") + train_output_data.write(output_data) + + +@hetero_feature_selection.predict() +def predict( + ctx: Context, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): + from fate.ml.feature_selection import ( + HeteroSelectionModuleGuest, + HeteroSelectionModuleHost, + ) + + logger.info(f"start selection predict") + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + if role.is_guest: + selection = HeteroSelectionModuleGuest.from_model(model) + elif role.is_host: + selection = HeteroSelectionModuleHost.from_model(model) + else: + raise ValueError(f"role: {role} is not valid") + + test_data = test_data.read() + + output_data = test_data + if selection.method is not None: + output_data = selection.transform(sub_ctx, test_data) + test_output_data.write(output_data) diff --git a/python/fate/components/components/hetero_lr.py b/python/fate/components/components/hetero_lr.py deleted file mode 100644 index e3f233f6b8..0000000000 --- a/python/fate/components/components/hetero_lr.py +++ /dev/null @@ -1,154 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate.components import ( - ARBITER, - GUEST, - HOST, - DatasetArtifact, - Input, - LossMetrics, - ModelArtifact, - Output, - Role, - cpn, - params, -) - - -@cpn.component(roles=[GUEST, HOST, ARBITER]) -def hetero_lr(ctx, role): - ... 
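The scale, binning and selection components above all translate user-supplied anonymous column names back to real column names by position. That mapping in isolation (illustration only, not part of this diff; the anonymous name format shown is made up):

columns = ["age", "income", "score"]
anonymous_columns = ["guest_9999_x0", "guest_9999_x1", "guest_9999_x2"]
requested = ["guest_9999_x2", "guest_9999_x0"]
resolved = [columns[anonymous_columns.index(col)] for col in requested]
assert resolved == ["score", "age"]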
- - -@hetero_lr.train() -@cpn.artifact("train_data", type=Input[DatasetArtifact], roles=[GUEST, HOST], desc="training data") -@cpn.artifact("validate_data", type=Input[DatasetArtifact], optional=True, roles=[GUEST, HOST], desc="validation data") -@cpn.parameter("learning_rate", type=params.ConFloat(gt=0.0), default=0.1, desc="learning rate") -@cpn.parameter("max_iter", type=params.ConInt(gt=0), default=100, desc="max iteration num") -@cpn.parameter( - "batch_size", type=params.ConInt(), default=100, desc="batch size, value less or equals to 0 means full batch" -) -@cpn.artifact("train_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.artifact("train_output_metric", type=Output[LossMetrics], roles=[ARBITER]) -@cpn.artifact("output_model", type=Output[ModelArtifact], roles=[GUEST, HOST]) -def train( - ctx, - role: Role, - train_data, - validate_data, - learning_rate, - max_iter, - batch_size, - train_output_data, - train_output_metric, - output_model, -): - if role.is_guest: - train_guest( - ctx, train_data, validate_data, train_output_data, output_model, max_iter, learning_rate, batch_size - ) - elif role.is_host: - train_host( - ctx, train_data, validate_data, train_output_data, output_model, max_iter, learning_rate, batch_size - ) - elif role.is_arbiter: - train_arbiter(ctx, max_iter, batch_size, train_output_metric) - - -@hetero_lr.predict() -@cpn.artifact("input_model", type=Input[ModelArtifact], roles=[GUEST, HOST]) -@cpn.artifact("test_data", type=Input[DatasetArtifact], optional=False, roles=[GUEST, HOST]) -@cpn.artifact("test_output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) -def predict( - ctx, - role: Role, - test_data, - input_model, - test_output_data, -): - if role.is_guest: - predict_guest(ctx, input_model, test_data, test_output_data) - if role.is_host: - predict_host(ctx, input_model, test_data, test_output_data) - - -def train_guest(ctx, train_data, validate_data, train_output_data, output_model, max_iter, learning_rate, batch_size): - - from fate.ml.lr.guest import LrModuleGuest - - with ctx.sub_ctx("train") as sub_ctx: - module = LrModuleGuest(max_iter=max_iter, learning_rate=learning_rate, batch_size=batch_size) - train_data = sub_ctx.reader(train_data).read_dataframe() - if validate_data is not None: - validate_data = sub_ctx.reader(validate_data).read_dataframe() - module.fit(sub_ctx, train_data, validate_data) - model = module.get_model() - with output_model as model_writer: - model_writer.write_model("hetero_lr_guest", model, metadata={}) - - with ctx.sub_ctx("predict") as sub_ctx: - predict_score = module.predict(sub_ctx, validate_data) - predict_result = validate_data.data.transform_to_predict_result(predict_score) - sub_ctx.writer(train_output_data).write_dataframe(predict_result) - - -def train_host(ctx, train_data, validate_data, train_output_data, output_model, max_iter, learning_rate, batch_size): - from fate.ml.lr.host import LrModuleHost - - with ctx.sub_ctx("train") as sub_ctx: - module = LrModuleHost(max_iter=max_iter, learning_rate=learning_rate, batch_size=batch_size) - train_data = sub_ctx.reader(train_data).read_dataframe() - if validate_data is not None: - validate_data = sub_ctx.reader(validate_data).read_dataframe() - module.fit(sub_ctx, train_data, validate_data) - model = module.get_model() - with output_model as model_writer: - model_writer.write_model("hetero_lr_host", model, metadata={}) - with ctx.sub_ctx("predict") as sub_ctx: - module.predict(sub_ctx, validate_data) - - -def train_arbiter(ctx, max_iter, 
batch_size, train_output_metric): - from fate.ml.lr.arbiter import LrModuleArbiter - - ctx.metrics.handler.register_metrics(lr_loss=ctx.writer(train_output_metric)) - - with ctx.sub_ctx("train") as sub_ctx: - module = LrModuleArbiter(max_iter=max_iter, batch_size=batch_size) - module.fit(sub_ctx) - - -def predict_guest(ctx, input_model, test_data, test_output_data): - from fate.ml.lr.guest import LrModuleGuest - - with ctx.sub_ctx("predict") as sub_ctx: - with input_model as model_reader: - model = model_reader.read_model() - module = LrModuleGuest.from_model(model) - test_data = sub_ctx.reader(test_data).read_dataframe() - predict_score = module.predict(sub_ctx, test_data) - predict_result = test_data.data.transform_to_predict_result(predict_score, data_type="predict") - sub_ctx.writer(test_output_data).write_dataframe(predict_result) - - -def predict_host(ctx, input_model, test_data, test_output_data): - from fate.ml.lr.host import LrModuleHost - - with ctx.sub_ctx("predict") as sub_ctx: - with input_model as model_reader: - model = model_reader.read_model() - module = LrModuleHost.from_model(model) - test_data = sub_ctx.reader(test_data).read_dataframe() - module.predict(sub_ctx, test_data) diff --git a/python/fate/components/components/hetero_sbt.py b/python/fate/components/components/hetero_sbt.py new file mode 100644 index 0000000000..95361fc631 --- /dev/null +++ b/python/fate/components/components/hetero_sbt.py @@ -0,0 +1,132 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from fate.arch import Context +from fate.components.components.utils import consts +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.ensemble import HeteroSecureBoostGuest, HeteroSecureBoostHost, BINARY_BCE, MULTI_CE, REGRESSION_L2 +from fate.components.components.utils.tools import add_dataset_type +from fate.components.components.utils import consts + + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def hetero_sbt(ctx, role): + ... 
+
+
+@hetero_sbt.train()
+def train(
+    ctx: Context,
+    role: Role,
+    train_data: cpn.dataframe_input(roles=[GUEST, HOST]),
+    validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True),
+    num_trees: cpn.parameter(type=params.conint(gt=0), default=3,
+                             desc="max tree num"),
+    learning_rate: cpn.parameter(type=params.confloat(gt=0), default=0.3, desc='decay factor of each tree'),
+    max_depth: cpn.parameter(type=params.conint(gt=0), default=3, desc='max depth of a tree'),
+    max_bin: cpn.parameter(type=params.conint(gt=0), default=32, desc='max bin number of feature binning'),
+    objective: cpn.parameter(type=params.string_choice(choice=[BINARY_BCE, MULTI_CE, REGRESSION_L2]), default=BINARY_BCE, \
+                             desc='objective function, available: {}'.format([BINARY_BCE, MULTI_CE, REGRESSION_L2])),
+    num_class: cpn.parameter(type=params.conint(gt=0), default=2, desc='class number of multi classification, active when objective is {}'.format(MULTI_CE)),
+    l2: cpn.parameter(type=params.confloat(gt=0), default=0.1, desc='L2 regularization'),
+    min_impurity_split: cpn.parameter(type=params.confloat(gt=0), default=1e-2, desc='min impurity when splitting a tree node'),
+    min_sample_split: cpn.parameter(type=params.conint(gt=0), default=2, desc='min sample to split a tree node'),
+    min_leaf_node: cpn.parameter(type=params.conint(gt=0), default=1, desc='minimum sample contained in a leaf node'),
+    min_child_weight: cpn.parameter(type=params.confloat(gt=0), default=1, desc='minimum hessian contained in a leaf node'),
+    gh_pack: cpn.parameter(type=bool, default=True, desc='whether to pack gradient and hessian together'),
+    split_info_pack: cpn.parameter(type=bool, default=True, desc='for host side, whether to pack split info together'),
+    hist_sub: cpn.parameter(type=bool, default=True, desc='whether to use histogram subtraction'),
+    he_param: cpn.parameter(type=params.he_param(), default=params.HEParam(kind='paillier', key_length=1024), desc='homomorphic encryption param, support paillier, ou and mock in current version'),
+    train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True),
+    train_model_output: cpn.json_model_output(roles=[GUEST, HOST], optional=True),
+    train_model_input: cpn.json_model_input(roles=[GUEST, HOST], optional=True)
+):
+
+    train_data = train_data.read()
+    if validate_data is not None:
+        validate_data = validate_data.read()
+
+    if train_model_input is not None:
+        train_model_input = train_model_input.read()
+
+    if role.is_guest:
+
+        # initialize encrypt kit
+        ctx.cipher.set_phe(ctx.device, he_param.dict())
+
+        booster = HeteroSecureBoostGuest(num_trees=num_trees, max_depth=max_depth, learning_rate=learning_rate, max_bin=max_bin,
+                                         l2=l2, min_impurity_split=min_impurity_split, min_sample_split=min_sample_split,
+                                         min_leaf_node=min_leaf_node, min_child_weight=min_child_weight, objective=objective, num_class=num_class,
+                                         gh_pack=gh_pack, split_info_pack=split_info_pack, hist_sub=hist_sub
+                                         )
+        if train_model_input is not None:
+            booster.from_model(train_model_input)
+            logger.info('sbt input model loaded, will start warmstarting')
+        booster.fit(ctx, train_data, validate_data)
+        # get cached train data score
+        train_scores = booster.get_train_predict()
+        train_scores = add_dataset_type(train_scores, consts.TRAIN_SET)
+        train_data_output.write(train_scores)
+        # get tree param
+        tree_dict = booster.get_model()
+        train_model_output.write(tree_dict, metadata={})
+
+    elif role.is_host:
+
+        booster = HeteroSecureBoostHost(num_trees=num_trees, max_depth=max_depth,
learning_rate=learning_rate, max_bin=max_bin, hist_sub=hist_sub) + if train_model_input is not None: + booster.from_model(train_model_input) + logger.info('sbt input model loaded, will start warmstarting') + booster.fit(ctx, train_data, validate_data) + tree_dict = booster.get_model() + train_model_output.write(tree_dict, metadata={}) + + else: + raise RuntimeError(f"Unknown role: {role}") + + + +@hetero_sbt.predict() +def predict( + ctx, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + predict_model_input: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): + + model_input = predict_model_input.read() + test_data = test_data.read() + if role.is_guest: + booster = HeteroSecureBoostGuest() + booster.from_model(model_input) + pred_table_rs = booster.predict(ctx, test_data) + pred_table_rs = add_dataset_type(pred_table_rs, consts.TEST_SET) + test_output_data.write(pred_table_rs) + + elif role.is_host: + booster = HeteroSecureBoostHost() + booster.from_model(model_input) + booster.predict(ctx, test_data) + + else: + raise RuntimeError(f"Unknown role: {role}") + + diff --git a/python/fate/components/components/homo_lr.py b/python/fate/components/components/homo_lr.py new file mode 100644 index 0000000000..1ac6331911 --- /dev/null +++ b/python/fate/components/components/homo_lr.py @@ -0,0 +1,136 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from fate.arch import Context +from fate.ml.glm.homo.lr.client import HomoLRClient +from fate.ml.glm.homo.lr.server import HomoLRServer +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.components.components.utils import consts +from fate.ml.utils.model_io import ModelIO +from fate.components.components.utils.tools import add_dataset_type +from fate.arch.dataframe import DataFrame + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST, ARBITER]) +def homo_lr(ctx, role): + ... 
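
Both the hetero_sbt stages above and the homo_lr stages below follow the same model round trip: train() serializes the fitted module with get_model() into a JSON model output, and predict() (or a warm-start train) rebuilds it with from_model(). A guest-side sketch of that flow for SecureBoost outside the component wrappers; it assumes an existing guest Context ctx, DataFrames train_df / test_df, and that constructor arguments not shown keep their defaults:

from fate.ml.ensemble import HeteroSecureBoostGuest, BINARY_BCE

# PHE must be configured on the guest context first, mirroring the train
# stage's he_param default.
ctx.cipher.set_phe(ctx.device, {"kind": "paillier", "key_length": 1024})

booster = HeteroSecureBoostGuest(num_trees=3, objective=BINARY_BCE)
booster.fit(ctx, train_df, None)          # federated training with the host side
model_dict = booster.get_model()          # serializable tree parameters

restored = HeteroSecureBoostGuest()       # later: predict, or warm-start another train
restored.from_model(model_dict)
scores = restored.predict(ctx, test_df)
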
+ + +@homo_lr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + learning_rate_scheduler: cpn.parameter(type=params.lr_scheduler_param(), + default=params.LRSchedulerParam(method="constant", + scheduler_params={"factor": 1.0}), + desc="learning rate scheduler, " + "select method from {'step', 'linear', 'constant'}" + "for list of configurable arguments, " + "refer to torch.optim.lr_scheduler"), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, + desc="max iteration num"), + batch_size: cpn.parameter(type=params.conint(ge=0), default=None, + desc="batch size, int > 0, if None means full batch" + "non"), + optimizer: cpn.parameter(type=params.optimizer_param(), + default=params.OptimizerParam(method="sgd", penalty='l2', alpha=1.0, + optimizer_params={"lr": 1e-2, "weight_decay": 0})), + init_param: cpn.parameter(type=params.init_param(), + default=params.InitParam(method='random', fit_intercept=True), + desc="Model param init setting."), + threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5, + desc="predict threshold for binary data"), + ovr: cpn.parameter(type=bool, default=False, + desc="enable ovr for multi-classifcation"), + label_num: cpn.parameter(type=params.conint(ge=2), default=None), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + train_input_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True), + train_output_model: cpn.json_model_output(roles=[GUEST, HOST]) +): + + sub_ctx = ctx.sub_ctx(consts.TRAIN) + + if role.is_guest or role.is_host: # is client + + logger.info('homo lr component: client start training') + logger.info('optim param {} \n init param {} \n learning rate param {}'.format( + optimizer.dict(), init_param.dict(), learning_rate_scheduler.dict())) + + client = HomoLRClient( + epochs=epochs, + batch_size=batch_size, + optimizer_param=optimizer.dict(), + init_param=init_param.dict(), + learning_rate_scheduler=learning_rate_scheduler.dict(), + threshold=threshold, + ovr=ovr, + label_num=label_num) + + if train_input_model is not None: + model_input = train_input_model.read() + client.from_model(model_input) + logger.info('model input loaded') + train_df = train_data.read() + validate_df = validate_data.read() if validate_data else None + client.fit(sub_ctx, train_df, validate_df) + model_dict = client.get_model() + + train_rs = client.predict(sub_ctx, train_df) + train_rs = add_dataset_type(train_rs, consts.TRAIN_SET) + if validate_df: + validate_rs = client.predict(sub_ctx, validate_df) + validate_rs = add_dataset_type(validate_rs, consts.VALIDATE_SET) + ret_df = DataFrame.vstack([train_rs, validate_rs]) + else: + ret_df = train_rs + + train_output_data.write(ret_df) + train_output_model.write(model_dict, metadata=model_dict['meta']) + + elif role.is_arbiter: # is server + logger.info('homo lr component: server start training') + server = HomoLRServer() + server.fit(sub_ctx) + + +@homo_lr.predict() +def predict( + ctx, + role: Role, + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + batch_size: cpn.parameter(type=params.conint(ge=-1), default=100, + desc="batch size, " + "value less or equals to 0 means full batch"), + threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5, + desc="predict threshold for binary data"), + predict_input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]) +): + + if 
role.is_guest or role.is_host: # is client + + client = HomoLRClient(batch_size=batch_size, threshold=threshold) + model_input = predict_input_model.read() + client.from_model(model_input) + pred_rs = client.predict(ctx, test_data.read()) + pred_rs = add_dataset_type(pred_rs, consts.TEST_SET) + test_output_data.write(pred_rs) + + elif role.is_arbiter: # is server + logger.info("arbiter skip predict") diff --git a/python/fate/components/components/homo_nn.py b/python/fate/components/components/homo_nn.py new file mode 100644 index 0000000000..fbe62d6b07 --- /dev/null +++ b/python/fate/components/components/homo_nn.py @@ -0,0 +1,245 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from fate.arch import Context +from fate.components.components.nn.loader import Loader +from fate.components.components.nn.nn_runner import NNRunner +from fate.components.components.nn.runner.default_runner import DefaultRunner +from fate.components.components.utils import consts +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn +from fate.arch.dataframe import DataFrame +from fate.components.components.utils.tools import add_dataset_type + +logger = logging.getLogger(__name__) + + +def is_path(s): + return os.path.exists(s) + + +""" +Input Functions +""" + + +def prepare_runner_class(runner_module, runner_class, runner_conf, source): + logger.info("runner conf is {}".format(runner_conf)) + logger.info("source is {}".format(source)) + if runner_module != "fate_runner": + if source is None: + # load from default folder + runner = Loader( + "fate.components.components.nn.runner." 
+ + runner_module, + runner_class, + **runner_conf)() + else: + runner = Loader( + runner_module, + runner_class, + source=source, + **runner_conf)() + assert isinstance( + runner, NNRunner), "loaded class must be a subclass of NNRunner class, but got {}".format( + type(runner)) + else: + logger.info("using default fate runner") + runner = DefaultRunner(**runner_conf) + + return runner + + +def prepare_context_and_role(runner, ctx, role, sub_ctx_name): + sub_ctx = ctx.sub_ctx(sub_ctx_name) + runner.set_context(sub_ctx) + runner.set_role(role) + + +def get_input_data(stage, cpn_input_data): + + if stage == 'train': + train_data, validate_data = cpn_input_data + train_data = train_data.read() + if validate_data is not None: + validate_data = validate_data.read() + + return train_data, validate_data + + elif stage == 'predict': + test_data = cpn_input_data + test_data = test_data.read() + return test_data + else: + raise ValueError(f"Unknown stage {stage}") + + +"""" +Output functions +""" + + +def get_model_output_conf(runner_module, + runner_class, + runner_conf, + source, + ): + return { + "runner_module": runner_module, + "runner_class": runner_class, + "runner_conf": runner_conf, + "source": source, + } + + +def prepared_saved_conf( + model_conf, + runner_class, + runner_module, + runner_conf, + source): + + logger.info("loaded model_conf is: {}".format(model_conf)) + if "source" in model_conf: + if source is None: + source = model_conf["source"] + + runner_class_, runner_module_ = model_conf['runner_class'], model_conf['runner_module'] + if runner_class_ == runner_class and runner_module_ == runner_module: + if "runner_conf" in model_conf: + saved_conf = model_conf['runner_conf'] + saved_conf.update(runner_conf) + runner_conf = saved_conf + logger.info("runner_conf is updated: {}".format(runner_conf)) + else: + logger.warning( + "runner_class or runner_module is not equal to the saved model, " + "use the new runner_conf, runner_class and runner module to train the model,\ + saved module & class: {} {}, new module & class: {} {}".format( + runner_module_, runner_class_, runner_module, runner_class)) + + return runner_conf, source, runner_class, runner_module + + +@cpn.component(roles=[GUEST, HOST, ARBITER]) +def homo_nn(ctx, role): + ... 
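
prepare_runner_class above resolves the runner purely from strings, which is what lets homo_nn host user-defined training logic. A rough illustration of the two resolution paths it implements; the conf values are illustrative, and the "workspace" / "my_runner" / "MyRunner" names in the second call are hypothetical:

# Built-in runner: with source=None the module is looked up under
# fate.components.components.nn.runner.<runner_module>.
runner = prepare_runner_class(
    runner_module="default_runner",
    runner_class="DefaultRunner",
    runner_conf={"algo": "fedavg", "task_type": "binary"},
    source=None,
)

# User runner: source names an entry in source.yaml whose path is appended to
# sys.path, and my_runner.py there must define an NNRunner subclass MyRunner.
runner = prepare_runner_class(
    runner_module="my_runner",
    runner_class="MyRunner",
    runner_conf={},
    source="workspace",
)
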
+ + +@homo_nn.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + runner_module: cpn.parameter(type=str, default="default_runner", desc="name of your runner script"), + runner_class: cpn.parameter(type=str, default="DefaultRunner", desc="class name of your runner class"), + runner_conf: cpn.parameter(type=dict, default={}, desc="the parameter dict of the NN runner class"), + source: cpn.parameter(type=str, default=None, desc="path to your runner script folder"), + train_data_output: cpn.dataframe_output(roles=[GUEST, HOST], optional=True), + train_model_output: cpn.model_directory_output(roles=[GUEST, HOST], optional=True), + train_model_input: cpn.model_directory_input(roles=[GUEST, HOST], optional=True) +): + + + if role.is_guest or role.is_host: # is client + + if train_model_input is not None: + model_conf = train_model_input.get_metadata() + runner_conf, source, runner_class, runner_module = prepared_saved_conf( + model_conf, runner_class, runner_module, runner_conf, source) + saved_model_path = str(train_model_input.get_directory()) + else: + saved_model_path = None + + runner: NNRunner = prepare_runner_class( + runner_module, runner_class, runner_conf, source) + prepare_context_and_role(runner, ctx, role, consts.TRAIN) + + output_dir = str(train_model_output.get_directory()) + train_data_, validate_data_ = get_input_data( + consts.TRAIN, [train_data, validate_data]) + runner.train(train_data_, validate_data_, output_dir, saved_model_path) + + logger.info('Predicting Train & Validate Data') + train_pred = runner.predict(train_data_, saved_model_path) + if train_pred is not None: + assert isinstance( + train_pred, DataFrame), "train predict result should be a DataFrame" + add_dataset_type(train_pred, consts.TRAIN_SET) + + if validate_data_ is not None: + validate_pred = runner.predict(validate_data_) + assert isinstance( + validate_pred, DataFrame), "validate predict result should be a DataFrame" + add_dataset_type(validate_pred, consts.VALIDATE_SET) + output_df = DataFrame.vstack([train_pred, validate_pred]) + else: + output_df = train_pred + logger.info('write result dataframe') + train_data_output.write(output_df) + else: + logger.warning( + "train_pred is None, It seems that the runner is not able to predict. 
Failed to output data") + + output_conf = get_model_output_conf(runner_module, + runner_class, + runner_conf, + source + ) + train_model_output.write_metadata(output_conf) + + elif role.is_arbiter: # is server + runner: NNRunner = prepare_runner_class( + runner_module, runner_class, runner_conf, source) + prepare_context_and_role(runner, ctx, role, consts.TRAIN) + runner.train() + + +@homo_nn.predict() +def predict( + ctx, role: Role, test_data: cpn.dataframe_input( + roles=[ + GUEST, HOST]), predict_model_input: cpn.model_directory_input( + roles=[ + GUEST, HOST]), predict_data_output: cpn.dataframe_output( + roles=[ + GUEST, HOST], optional=True)): + + if role.is_guest or role.is_host: # is client + + model_conf = predict_model_input.get_metadata() + runner_module = model_conf['runner_module'] + runner_class = model_conf['runner_class'] + runner_conf = model_conf['runner_conf'] + source = model_conf['source'] + saved_model_path = str(predict_model_input.get_directory()) + test_data_ = get_input_data(consts.PREDICT, test_data) + runner: NNRunner = prepare_runner_class( + runner_module, runner_class, runner_conf, source) + prepare_context_and_role(runner, ctx, role, consts.PREDICT) + test_pred = runner.predict( + test_data_, saved_model_path=saved_model_path) + if test_pred is not None: + assert isinstance( + test_pred, DataFrame), "test predict result should be a DataFrame" + add_dataset_type(test_pred, consts.TEST_SET) + predict_data_output.write(test_pred) + else: + logger.warning( + "test_pred is None, It seems that the runner is not able to predict. Failed to output data") + + elif role.is_arbiter: # is server + logger.info("arbiter skip predict") diff --git a/python/fate/components/components/intersection.py b/python/fate/components/components/intersection.py deleted file mode 100644 index 847f8ff30d..0000000000 --- a/python/fate/components/components/intersection.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from fate.components import GUEST, HOST, DatasetArtifact, Input, Output, Role, cpn - - -@cpn.component(roles=[GUEST, HOST], provider="fate") -@cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[GUEST, HOST]) -@cpn.parameter("method", type=str, default="raw", optional=True) -@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) -def intersection( - ctx, - role: Role, - input_data, - method, - output_data, -): - if role.is_guest: - if method == "raw": - raw_intersect_guest(ctx, input_data, output_data) - elif role.is_host: - if method == "raw": - raw_intersect_host(ctx, input_data, output_data) - - -def raw_intersect_guest(ctx, input_data, output_data): - from fate.ml.intersection import RawIntersectionGuest - - data = ctx.reader(input_data).read_dataframe().data - guest_intersect_obj = RawIntersectionGuest() - intersect_data = guest_intersect_obj.fit(ctx, data) - ctx.writer(output_data).write_dataframe(intersect_data) - - -def raw_intersect_host(ctx, input_data, output_data): - from fate.ml.intersection import RawIntersectionHost - - data = ctx.reader(input_data).read_dataframe().data - host_intersect_obj = RawIntersectionHost() - intersect_data = host_intersect_obj.fit(ctx, data) - ctx.writer(output_data).write_dataframe(intersect_data) diff --git a/python/fate/components/components/multi_model_test.py b/python/fate/components/components/multi_model_test.py new file mode 100644 index 0000000000..d5aa459285 --- /dev/null +++ b/python/fate/components/components/multi_model_test.py @@ -0,0 +1,66 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
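
The deleted intersection.py above still used the 1.x-era declaration style, attaching every artifact and parameter through separate @cpn.artifact / @cpn.parameter decorators. In 2.0 the same information lives in the function signature, as in the components above and in multi_model_test below. A side-by-side sketch, assuming the usual fate.components.core imports (GUEST, HOST, Role, cpn); the "echo" component is hypothetical and only mirrors the removed intersection interface:

# 1.x style (as in the removed file):
#   @cpn.component(roles=[GUEST, HOST], provider="fate")
#   @cpn.artifact("input_data", type=Input[DatasetArtifact], roles=[GUEST, HOST])
#   @cpn.parameter("method", type=str, default="raw", optional=True)
#   @cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST])
#   def intersection(ctx, role, input_data, method, output_data): ...

# 2.0 style equivalent:
@cpn.component(roles=[GUEST, HOST], provider="fate")
def echo(
    ctx: Context,
    role: Role,
    input_data: cpn.dataframe_input(roles=[GUEST, HOST]),
    method: cpn.parameter(type=str, default="raw", optional=True),
    output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
):
    output_data.write(input_data.read())
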
+# +from fate.arch import Context +from fate.components.core import LOCAL, Role, cpn + + +@cpn.component(roles=[LOCAL]) +def multi_model_test( + ctx: Context, + role: Role, + dataframe_input: cpn.dataframe_input(roles=[LOCAL]), + dataframe_output: cpn.dataframe_output(roles=[LOCAL]), + json_model_input: cpn.json_model_input(roles=[LOCAL]), + json_model_inputs: cpn.json_model_inputs(roles=[LOCAL]), + json_model_output: cpn.json_model_output(roles=[LOCAL]), + json_model_outputs: cpn.json_model_outputs(roles=[LOCAL]), + model_directory_output: cpn.model_directory_output(roles=[LOCAL]), + model_directory_outputs: cpn.model_directory_outputs(roles=[LOCAL]), +): + df = dataframe_input.read() + df = df + 1 + + dataframe_output.write(df) + json_model_input = json_model_input.read() + json_model_inputs = [model_input.read() for model_input in json_model_inputs] + + assert json_model_input == {"aaa": 1} + assert len(json_model_inputs) == 3 + json_model_output.write({"ccccccc": 1}, metadata={"ccc": 2}) + for i in range(10): + model_output = next(json_model_outputs) + model_output.write({"abc": i}, metadata={"i-th output": i}) + + output_path = model_directory_output.get_directory() + with open(output_path + "/a.txt", "w") as fw: + fw.write("test for model directory output\n") + + model_directory_output.write_metadata({"model_directory_output_metadata": "test_directory"}) + + for i in range(5): + directory_output = next(model_directory_outputs) + path = directory_output.get_directory() + with open(path + f"/output_{i}.txt", "w") as fw: + fw.write("test for model directory output\n") + + directory_output.write_metadata({f"model_directory_output_{i}_metadata": f"test_directory_{i}"}) + + ctx.metrics.log_metrics(values=[1, 2, 3], name="metric_single", type="custom", metadata={"bbb": 2}) + ctx.sub_ctx("sub_metric").metrics.log_loss("loss", 1.0, 0) + ctx.sub_ctx("sub_metric").metrics.log_loss("loss", 0.9, 1) + + for i, auto_step_sub_ctx in ctx.ctxs_range(10): + auto_step_sub_ctx.metrics.log_accuracy("sub", 1.0) diff --git a/examples/pipeline/__init__.py b/python/fate/components/components/nn/__init__.py similarity index 100% rename from examples/pipeline/__init__.py rename to python/fate/components/components/nn/__init__.py diff --git a/python/fate/components/components/nn/loader.py b/python/fate/components/components/nn/loader.py new file mode 100644 index 0000000000..6dfe1269ac --- /dev/null +++ b/python/fate/components/components/nn/loader.py @@ -0,0 +1,193 @@ +import os +import sys +import importlib.util +from abc import ABC, abstractmethod +import json +import yaml +import difflib + + +class _Source(object): + MODEL_ZOO = 'fate.ml.nn.model_zoo' + DATASET = 'fate.ml.nn.dataset' + CUST_FUNC = 'fate.ml.nn.cust_func' + + +SOURCE_FILE = 'source.yaml' + + +def is_path(s): + return os.path.exists(s) + + +def load_source(): + script_path = os.path.realpath(__file__) + script_dir = os.path.dirname(script_path) + with open(script_dir + '/' + SOURCE_FILE, 'r') as f: + source = yaml.safe_load(f) + return source + + +class AbstractLoader(ABC): + @abstractmethod + def __init__(self, module_name, item_name, source=None): + pass + + @abstractmethod + def call_item(self): + pass + + @abstractmethod + def load_item(self): + pass + + @abstractmethod + def to_json(self): + pass + + @abstractmethod + def to_dict(self): + pass + + +class Loader(AbstractLoader): + + def __init__(self, module_name, item_name, source=None, **kwargs): + + self.item_name = item_name + self.module_name = module_name + self.source = source + 
self.source_path = None + + if isinstance(source, str): + self.module_name = module_name + source_dict = load_source() + if self.source in source_dict: + self.source_path = source_dict[self.source] + else: + raise ValueError( + 'source name {} is not found in the source.yaml file. Please check the source name.'.format( + self.source)) + elif source is None: + self.module_name = module_name + + self.kwargs = kwargs + + def __call__(self): + return self.call_item() + + def call_item(self): + item = self._load_item() + + if item is not None and callable(item): + item = item(**self.kwargs) + + return item + + def load_item(self): + return self._load_item() + + def _load_item(self): + if self.source_path is not None: + sys.path.append(self.source_path) + + spec = importlib.util.find_spec(self.module_name) + if spec is None: + # Search for similar module names + suggestion = self._find_similar_module_names() + if suggestion: + raise ValueError( + "Module: {} not found in the import path. Do you mean {}?".format( + self.module_name, suggestion)) + else: + raise ValueError( + "Module: {} not found in the import path.".format( + self.module_name)) + + module = importlib.import_module(self.module_name) + + item = getattr(module, self.item_name, None) + if item is None: + raise ValueError( + "Item: {} not found in module: {}.".format( + self.item_name, self.module_name)) + + if self.source_path is not None: + sys.path.remove(self.source_path) + + return item + + def _find_similar_module_names(self): + + if self.source_path is None: + return None + files = os.listdir(self.source_path) + print('source matches are', files) + similar_names = difflib.get_close_matches(self.module_name, files) + return similar_names[0] if similar_names else None + + def to_json(self): + return json.dumps(self.to_dict()) + + def to_dict(self): + return { + 'module_name': self.module_name, + 'item_name': self.item_name, + 'kwargs': self.kwargs, + 'source': self.source + } + + @staticmethod + def from_json(json_str): + data = json.loads(json_str) + return Loader.from_dict(data) + + @staticmethod + def from_dict(data_dict): + return Loader(module_name=data_dict['module_name'], + item_name=data_dict['item_name'], + source=data_dict.get('source', None), + **data_dict.get('kwargs', {}) + ) + + +class ModelLoader(Loader): + def __init__(self, module_name, item_name, source=None, **kwargs): + if source is None: + # add prefix for moduele loader + module_name = f'{_Source.MODEL_ZOO}.{module_name}' + super( + ModelLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) + + +class DatasetLoader(Loader): + def __init__(self, module_name, item_name, source=None, **kwargs): + if source is None: + # add prefix for moduele loader + module_name = f'{_Source.DATASET}.{module_name}' + super( + DatasetLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) + + +class CustFuncLoader(Loader): + def __init__(self, module_name, item_name, source=None, **kwargs): + if source is None: + # add prefix for moduele loader + module_name = f'{_Source.CUST_FUNC}.{module_name}' + super( + CustFuncLoader, + self).__init__( + module_name, + item_name, + source, + **kwargs) diff --git a/python/fate/components/components/nn/nn_runner.py b/python/fate/components/components/nn/nn_runner.py new file mode 100644 index 0000000000..0d9f3b69dd --- /dev/null +++ b/python/fate/components/components/nn/nn_runner.py @@ -0,0 +1,211 @@ +import numpy as np +import torch +import pandas as pd +from typing import Union, Optional, Literal 
+from fate.components.core import Role +from fate.arch import Context +from typing import Optional, Union +from transformers.trainer_utils import PredictionOutput +import numpy as np +from fate.arch.dataframe._dataframe import DataFrame +from fate.components.components.utils import consts +import logging +from fate.ml.utils.predict_tools import to_dist_df, array_to_predict_df +from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION, OTHER, LABEL, PREDICT_SCORE + + +logger = logging.getLogger(__name__) + + +def _convert_to_numpy_array( + data: Union[pd.Series, pd.DataFrame, np.ndarray, torch.Tensor]) -> np.ndarray: + if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame): + return data.to_numpy() + elif isinstance(data, torch.Tensor): + return data.cpu().numpy() + else: + return np.array(data) + + +def task_type_infer(predict_result, true_label): + + pred_shape = predict_result.shape + + if true_label.max() == 1.0 and true_label.min() == 0.0: + return consts.BINARY + + if (len(pred_shape) > 1) and (pred_shape[1] > 1): + if np.isclose( + predict_result.sum( + axis=1), np.array( + [1.0])).all(): + return consts.MULTI + else: + return None + elif (len(pred_shape) == 1) or (pred_shape[1] == 1): + return consts.REGRESSION + + return None + + +class NNRunner(object): + + def __init__(self) -> None: + + self._role = None + self._party_id = None + self._ctx: Context = None + + def set_context(self, context: Context): + assert isinstance(context, Context) + self._ctx = context + + def get_context(self) -> Context: + return self._ctx + + def set_role(self, role: Role): + assert isinstance(role, Role) + self._role = role + + def is_client(self) -> bool: + return self._role.is_guest or self._role.is_host + + def is_server(self) -> bool: + return self._role.is_arbiter + + def set_party_id(self, party_id: int): + assert isinstance(self._party_id, int) + self._party_id = party_id + + def get_fateboard_tracker(self): + pass + + def get_nn_output_dataframe( + self, + ctx, + predictions: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput], + labels: Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput] = None, + match_ids: Union[pd.DataFrame, np.ndarray] = None, + sample_ids: Union[pd.DataFrame, np.ndarray] = None, + match_id_name: str = None, + sample_id_name: str = None, + dataframe_format: Literal['default', 'fate_std'] = 'default', + task_type: Literal['binary', 'multi', 'regression', 'others'] = None, + threshold: float = 0.5, + classes: list = None + ) -> DataFrame: + """ + Constructs a FATE DataFrame from predictions and labels. This Dataframe is able to flow through FATE components. + + Parameters: + ctx (Context): The Context Instance. + predictions (Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput]): The model's predictions, which can be numpy arrays, torch tensors, pandas DataFrames, or PredictionOutputs. + labels (Union[np.ndarray, torch.Tensor, DataFrame, PredictionOutput]): The true labels, which can be numpy arrays, torch tensors, pandas DataFrames, or PredictionOutputs. + match_ids (Union[pd.DataFrame, np.ndarray], optional): Match IDs, if applicable. Defaults to None. If None, will auto generate match_ids. + sample_ids (Union[pd.DataFrame, np.ndarray], optional): Sample IDs, if applicable. Defaults to None. If None, will auto generate sample_ids. + match_id_name (str, optional): Column name for match IDs in the resulting DataFrame. If None, Defaults to 'id'. + sample_id_name (str, optional): Column name for sample IDs in the resulting DataFrame. 
If None, Defaults to 'sample_id'. + dataframe_format (Literal['default', 'fate_std'], optional): Output format of the resulting DataFrame. If 'default', simply combines labels and predictions into a DataFrame. + If 'fate_std', organizes output according to the FATE framework's format. Defaults to 'default'. + task_type (Literal['binary', 'multi', 'regression', 'others'], optional): This parameter is only needed when dataframe_format is 'fate_std'. Defaults to None. + The type of machine learning task, which can be 'binary', 'multi', 'regression', or 'others'. + threshold (float, optional): This parameter is only needed when dataframe_format is 'fate_std' and task_type is 'binary'. Defaults to 0.5. + classes (list, optional): This parameter is only needed when dataframe_format is 'fate_std'. List of classes. + Returns: + DataFrame: A DataFrame that contains the neural network's predictions and the true labels, possibly along with match IDs and sample IDs, formatted according to the specified format. + """ + # check parameters + assert task_type in [BINARY, MULTI, REGRESSION, OTHER], f"task_type {task_type} is not supported" + assert dataframe_format in [ + 'default', 'fate_std'], f"dataframe_format {dataframe_format} is not supported" + + if match_id_name is None: + match_id_name = 'id' + if sample_id_name is None: + sample_id_name = 'sample_id' + + if isinstance(predictions, PredictionOutput): + predictions = predictions.predictions + + if labels is not None: + if isinstance(labels, PredictionOutput): + labels = labels.label_ids + predictions = _convert_to_numpy_array(predictions) + labels = _convert_to_numpy_array(labels) + assert len(predictions) == len( + labels), f"predictions length {len(predictions)} != labels length {len(labels)}" + + # check match ids + if match_ids is not None: + match_ids = _convert_to_numpy_array(match_ids).flatten() + else: + logger.info( + "match_ids is not provided, will auto generate match_ids") + match_ids = np.array( + [i for i in range(len(predictions))]).flatten() + + # check sample ids + if sample_ids is not None: + sample_ids = _convert_to_numpy_array(sample_ids).flatten() + else: + logger.info( + "sample_ids is not provided, will auto generate sample_ids") + sample_ids = np.array( + [i for i in range(len(predictions))]).flatten() + + assert len(match_ids) == len( + predictions), f"match_ids length {len(match_ids)} != predictions length {len(predictions)}" + assert len(sample_ids) == len( + predictions), f"sample_ids length {len(sample_ids)} != predictions length {len(predictions)}" + + # match id name and sample id name must be str + assert isinstance( + match_id_name, str), f"match_id_name must be str, but got {type(match_id_name)}" + assert isinstance( + sample_id_name, str), f"sample_id_name must be str, but got {type(sample_id_name)}" + + if dataframe_format == 'default' or ( + dataframe_format == 'fate_std' and task_type == OTHER): + df = pd.DataFrame() + if labels is not None: + df[LABEL] = labels.to_list() + df[PREDICT_SCORE] = predictions.to_list() + df[match_id_name] = match_ids.flatten() + df[sample_id_name] = sample_ids.flatten() + df = to_dist_df(ctx, sample_id_name, match_id_name, df) + return df + elif dataframe_format == 'fate_std' and task_type in [BINARY, MULTI, REGRESSION]: + df = array_to_predict_df(ctx, task_type, predictions, match_ids, sample_ids, match_id_name, sample_id_name, labels, threshold, classes) + return df + + def train(self, + train_data: Optional[Union[str, + DataFrame]] = None, + validate_data: Optional[Union[str, + 
DataFrame]] = None, + output_dir: str = None, + saved_model_path: str = None) -> None: + """ + Train interface. + + Parameters: + train_data (Union[str, DataFrame]): The training data, which can be a FATE DataFrame containing the data, or a string path representing the bound data.Train data is Optional on the server side. + validate_data (Optional[Union[str, DataFrame]]): The validation data, which can be a FATE DataFrame containing the data, or a string path representing the bound data . This argument is optional. + output_dir (str, optional): The path to the directory where the trained model should be saved. If this class is running in the fate pipeline, this path will provided by FATE framework. + saved_model_path (str, optional): The path to the saved model that should be loaded before training starts.If this class is running in the fate pipeline, this path will provided by FATE framework. + """ + pass + + def predict(self, + test_data: Optional[Union[str, + DataFrame]] = None, + output_dir: str = None, + saved_model_path: str = None) -> DataFrame: + """ + Predict interface. + + Parameters: + test_data (Union[str, DataFrame]): The data to predict, which can be a FATE DataFrame containing the data, or a string path representing the bound data.Test data is Optional on the server side. + output_dir (str, optional): The path to the directory where the trained model should be saved. If this class is running in the fate pipeline, this path will provided by FATE framework. + saved_model_path (str, optional): The path to the saved model that should be loaded before training starts.If this class is running in the fate pipeline, this path will provided by FATE framework. + """ diff --git a/python/fate/arch/context/io/__init__.py b/python/fate/components/components/nn/runner/__init__.py similarity index 100% rename from python/fate/arch/context/io/__init__.py rename to python/fate/components/components/nn/runner/__init__.py diff --git a/python/fate/components/components/nn/runner/default_runner.py b/python/fate/components/components/nn/runner/default_runner.py new file mode 100644 index 0000000000..ed83c702b3 --- /dev/null +++ b/python/fate/components/components/nn/runner/default_runner.py @@ -0,0 +1,382 @@ +import torch as t +import os +from fate.components.components.nn.nn_runner import NNRunner +from fate.ml.nn.algo.homo.fedavg import FedAVG, FedAVGArguments, FedAVGCLient, FedAVGServer, TrainingArguments +from typing import Optional, Dict, Union +from fate.components.components.nn.loader import Loader +import torch.nn as nn +import torch.optim as optim +import torch.utils.data as data_utils +from torch.optim.lr_scheduler import _LRScheduler +from fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments, FedTrainerClient, FedTrainerServer +from typing import Union, Type, Callable, Optional +from transformers.trainer_utils import get_last_checkpoint +from fate.ml.nn.dataset.table import TableDataset +from typing import Literal +import logging +from fate.components.components.utils import consts +from fate.ml.nn.dataset.table import TableDataset +from fate.arch.dataframe import DataFrame + + +logger = logging.getLogger(__name__) + + +SUPPORTED_ALGO = ['fedavg'] + + +def load_model_dict_from_path(path): + # Ensure that the path is a string + assert isinstance( + path, str), "Path must be a string, but got {}".format( + type(path)) + + # Append the filename to the path + model_path = os.path.join(path, 'pytorch_model.bin') + + # Check if the file exists + if not 
os.path.isfile(model_path): + raise FileNotFoundError( + f"No 'pytorch_model.bin' file found at {model_path}, no saved model found") + + # Load the state dict from the specified path + model_dict = t.load(model_path) + + return model_dict + + +def dir_warning(train_args): + if 'output_dir' in train_args or 'logging_dir' in train_args or 'resume_from_checkpoint' in train_args: + logger.warning( + "The output_dir, logging_dir, and resume_from_checkpoint arguments are not supported in the " + "DefaultRunner when running the Pipeline. These arguments will be replaced by FATE provided paths.") + + +class SetupReturn: + """ + Class to encapsulate the return objects from the setup. + """ + + def __init__(self, + trainer: Union[Type[FedTrainerClient], + Type[FedTrainerServer]] = None, + model: Type[nn.Module] = None, + optimizer: Type[optim.Optimizer] = None, + loss: Callable = None, + scheduler: Type[_LRScheduler] = None, + train_args: TrainingArguments = None, + fed_args: FedArguments = None, + data_collator: Callable = None) -> None: + + if trainer is not None and not ( + issubclass( + type(trainer), + FedTrainerClient) or issubclass( + type(trainer), + FedTrainerServer)): + raise TypeError( + f"SetupReturn Error: trainer must be a subclass of either FedTrainerClient or FedTrainerServer but got {type(trainer)}") + + if model is not None and not issubclass(type(model), nn.Module): + raise TypeError( + f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}") + + if optimizer is not None and not issubclass( + type(optimizer), optim.Optimizer): + raise TypeError( + f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}") + + if loss is not None and not callable(loss): + raise TypeError( + f"SetupReturn Error: loss must be callable but got {type(loss)}") + + if scheduler is not None and not issubclass( + type(scheduler), _LRScheduler): + raise TypeError( + f"SetupReturn Error: scheduler must be a subclass of torch.optim.lr_scheduler._LRScheduler but got {type(scheduler)}") + + if train_args is not None and not isinstance( + train_args, TrainingArguments): + raise TypeError( + f"SetupReturn Error: train_args must be an instance of TrainingArguments but got {type(train_args)}") + + if fed_args is not None and not isinstance(fed_args, FedArguments): + raise TypeError( + f"SetupReturn Error: fed_args must be an instance of FedArguments but got {type(fed_args)}") + + if data_collator is not None and not callable(data_collator): + raise TypeError( + f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}") + + self.trainer = trainer + self.model = model + self.optimizer = optimizer + self.loss = loss + self.scheduler = scheduler + self.train_args = train_args + self.fed_args = fed_args + self.data_collator = data_collator + + def __getitem__(self, item): + return getattr(self, item) + + def __repr__(self): + repr_string = "SetupReturn(\n" + for key, value in self.__dict__.items(): + repr_string += f" {key}={type(value)},\n" + repr_string = repr_string.rstrip(',\n') + repr_string += "\n)" + return repr_string + + +class DefaultRunner(NNRunner): + + def __init__(self, + algo: str = 'fedavg', + model_conf: Optional[Dict] = None, + dataset_conf: Optional[Dict] = None, + optimizer_conf: Optional[Dict] = None, + training_args_conf: Optional[Dict] = None, + fed_args_conf: Optional[Dict] = None, + loss_conf: Optional[Dict] = None, + data_collator_conf: Optional[Dict] = None, + tokenizer_conf: Optional[Dict] = 
None, + task_type: Literal['binary', + 'multi', + 'regression', + 'others'] = 'binary', + threshold: float = 0.5, + local_mode: bool = False) -> None: + + super().__init__() + self.algo = algo + self.model_conf = model_conf + self.dataset_conf = dataset_conf + self.optimizer_conf = optimizer_conf + self.training_args_conf = training_args_conf + self.fed_args_conf = fed_args_conf + self.loss_conf = loss_conf + self.data_collator_conf = data_collator_conf + self.local_mode = local_mode + self.tokenizer_conf = tokenizer_conf + self.task_type = task_type + self.threshold = threshold + + # check param + if self.algo not in SUPPORTED_ALGO: + raise ValueError('algo should be one of [fedavg]') + if self.task_type not in ['binary', 'multi', 'regression', 'others']: + raise ValueError( + 'task_type should be one of [binary, multi, regression, others]') + assert self.threshold >= 0 and self.threshold <= 1, 'threshold should be in [0, 1]' + assert isinstance(self.local_mode, bool), 'local should be bool' + + # setup var + self.trainer = None + + def _loader_load_from_conf(self, conf, return_class=False): + if conf is None: + return None + if return_class: + return Loader.from_dict(conf).load_item() + return Loader.from_dict(conf).call_item() + + def _prepare_data(self, data, data_name) -> SetupReturn: + + if data is None: + return None + if isinstance(data, DataFrame) and self.dataset_conf is None: + logger.info( + 'Input data {} is FATE DataFrame and dataset conf is None, will automatically handle the input data'.format(data_name)) + if self.task_type == consts.MULTI: + dataset = TableDataset( + flatten_label=True, + label_dtype='long', + to_tensor=True) + else: + dataset = TableDataset(to_tensor=True) + dataset.load(data) + else: + dataset = self._loader_load_from_conf(self.dataset_conf) + if hasattr(dataset, 'load'): + dataset.load(data) + else: + raise ValueError( + f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \ + Please implement this method in your dataset class. 
You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \ + for the necessary interfaces to implement.") + if dataset is not None and not issubclass( + type(dataset), data_utils.Dataset): + raise TypeError( + f"SetupReturn Error: {data_name}_set must be a subclass of torch.utils.data.Dataset but got {type(dataset)}") + + return dataset + + def client_setup( + self, + train_set=None, + validate_set=None, + output_dir=None, + saved_model=None, + stage='train'): + + if stage == 'predict': + self.local_mode = True + + if self.algo == 'fedavg': + client_class: FedAVGCLient = FedAVG.client + else: + raise ValueError(f"algo {self.algo} not supported") + + ctx = self.get_context() + print(self.model_conf) + model = self._loader_load_from_conf(self.model_conf) + if model is None: + raise ValueError( + f"model is None, cannot load model from conf {self.model_conf}") + + if output_dir is None: + output_dir = './' + + resume_path = None + if saved_model is not None: + model_dict = load_model_dict_from_path(saved_model) + model.load_state_dict(model_dict) + logger.info(f"loading model dict from {saved_model} to model done") + if get_last_checkpoint(saved_model) is not None: + resume_path = saved_model + logger.info( + f"checkpoint detected, resume_path set to {resume_path}") + # load optimizer + optimizer_loader = Loader.from_dict(self.optimizer_conf) + optimizer_ = optimizer_loader.load_item() + optimizer_params = optimizer_loader.kwargs + optimizer = optimizer_(model.parameters(), **optimizer_params) + # load loss + loss = self._loader_load_from_conf(self.loss_conf) + # load collator func + data_collator = self._loader_load_from_conf(self.data_collator_conf) + # load tokenizer if import conf provided + tokenizer = self._loader_load_from_conf(self.tokenizer_conf) + # args + dir_warning(self.training_args_conf) + training_args = TrainingArguments(**self.training_args_conf) + # reset to default, saving to arbitrary path is not allowed in + # DefaultRunner + training_args.output_dir = output_dir + training_args.resume_from_checkpoint = resume_path # resume path + fed_args = FedAVGArguments(**self.fed_args_conf) + + # prepare trainer + trainer = client_class( + ctx=ctx, + model=model, + loss_fn=loss, + optimizer=optimizer, + training_args=training_args, + fed_args=fed_args, + data_collator=data_collator, + tokenizer=tokenizer, + train_set=train_set, + val_set=validate_set, + local_mode=self.local_mode) + + return SetupReturn( + trainer=trainer, + model=model, + optimizer=optimizer, + loss=loss, + train_args=training_args, + fed_args=fed_args, + data_collator=data_collator) + + def server_setup(self, stage='train'): + + if stage == 'predict': + self.local_mode = True + if self.algo == 'fedavg': + server_class: FedAVGServer = FedAVG.server + else: + raise ValueError(f"algo {self.algo} not supported") + ctx = self.get_context() + trainer = server_class(ctx=ctx, local_mode=self.local_mode) + return SetupReturn(trainer=trainer) + + def train(self, + train_data: Optional[Union[str, + DataFrame]] = None, + validate_data: Optional[Union[str, + DataFrame]] = None, + output_dir: str = None, + saved_model_path: str = None): + + if self.is_client(): + train_set = self._prepare_data(train_data, 'train_data') + validate_set = self._prepare_data(validate_data, 'val_data') + setup = self.client_setup( + train_set=train_set, + validate_set=validate_set, + output_dir=output_dir, + saved_model=saved_model_path) + trainer = setup['trainer'] + self.trainer = trainer + trainer.train() + if output_dir is not 
None: + trainer.save_model(output_dir) + elif self.is_server(): + setup = self.server_setup() + trainer = setup['trainer'] + trainer.train() + + def _run_dataset_func(self, dataset, func_name): + + if hasattr(dataset, func_name): + output = getattr(dataset, func_name)() + if output is None: + logger.info( + f'dataset {type(dataset)}: {func_name} returns None, this will influence the output of predict') + return output + else: + logger.info( + f'dataset {type(dataset)} not implemented {func_name}, classes set to None, this will influence the output of predict') + return None + + def predict(self, + test_data: Union[str, + DataFrame], + saved_model_path: str = None) -> Union[DataFrame, + None]: + + if self.is_client(): + test_set = self._prepare_data(test_data, 'test_data') + if self.trainer is not None: + trainer = self.trainer + logger.info('trainer found, skip setting up') + else: + setup = self.client_setup( + saved_model=saved_model_path, stage='predict') + trainer = setup['trainer'] + + classes = self._run_dataset_func(test_set, 'get_classes') + match_ids = self._run_dataset_func(test_set, 'get_match_ids') + sample_ids = self._run_dataset_func(test_set, 'get_sample_ids') + match_id_name = self._run_dataset_func( + test_set, 'get_match_id_name') + sample_id_name = self._run_dataset_func( + test_set, 'get_sample_id_name') + pred_rs = trainer.predict(test_set) + rs_df = self.get_nn_output_dataframe( + self.get_context(), + pred_rs.predictions, + pred_rs.label_ids if hasattr(pred_rs, 'label_ids') else None, + match_ids, + sample_ids, + match_id_name=match_id_name, + sample_id_name=sample_id_name, + dataframe_format='fate_std', + task_type=self.task_type, + classes=classes) + return rs_df + else: + # server not predict + return diff --git a/python/fate/components/components/nn/source.yaml b/python/fate/components/components/nn/source.yaml new file mode 100644 index 0000000000..12a549be55 --- /dev/null +++ b/python/fate/components/components/nn/source.yaml @@ -0,0 +1 @@ +workspace: '' \ No newline at end of file diff --git a/python/fate/components/components/nn/torch/__init__.py b/python/fate/components/components/nn/torch/__init__.py new file mode 100644 index 0000000000..ef471ba686 --- /dev/null +++ b/python/fate/components/components/nn/torch/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
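
The DefaultRunner above is driven entirely by configuration dicts: model_conf, optimizer_conf and loss_conf are Loader records ({'module_name', 'item_name', 'kwargs', 'source'}) resolved at runtime, while training_args_conf and fed_args_conf are plain keyword dicts forwarded to TrainingArguments and FedAVGArguments. A rough sketch of such a configuration; the layers, learning rate and epoch count are illustrative, the argument names for TrainingArguments assume the transformers-style fields, and in a pipeline the same content would arrive through the homo_nn runner_conf parameter:

runner = DefaultRunner(
    algo="fedavg",
    model_conf={                      # resolved via Loader.from_dict(...).call_item()
        "module_name": "torch.nn",
        "item_name": "Linear",
        "kwargs": {"in_features": 30, "out_features": 1},
    },
    optimizer_conf={                  # class is loaded, kwargs go to its constructor
        "module_name": "torch.optim",
        "item_name": "SGD",
        "kwargs": {"lr": 0.01},
    },
    loss_conf={"module_name": "torch.nn", "item_name": "BCEWithLogitsLoss", "kwargs": {}},
    training_args_conf={"num_train_epochs": 5},   # forwarded to TrainingArguments(**...)
    fed_args_conf={},                              # forwarded to FedAVGArguments(**...)
    task_type="binary",
)
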
+# \ No newline at end of file diff --git a/python/fate/components/components/nn/torch/base.py b/python/fate/components/components/nn/torch/base.py new file mode 100644 index 0000000000..0bd5b650d9 --- /dev/null +++ b/python/fate/components/components/nn/torch/base.py @@ -0,0 +1,114 @@ +import torch as t +from fate.components.components.nn.loader import Loader +from torch.nn import Sequential as tSequential +import json + + +def convert_tuples_to_lists(data): + if isinstance(data, tuple): + return list(data) + elif isinstance(data, list): + return [convert_tuples_to_lists(item) for item in data] + elif isinstance(data, dict): + return {key: convert_tuples_to_lists( + value) for key, value in data.items()} + else: + return data + + +class TorchModule(object): + + def __init__(self): + t.nn.Module.__init__(self) + self.param_dict = dict() + self.optimizer = None + + def to_dict(self): + ret_dict = { + 'module_name': 'torch.nn', + 'item_name': str(type(self).__name__), + 'kwargs': convert_tuples_to_lists(self.param_dict) + } + return ret_dict + + +class TorchOptimizer(object): + + def __init__(self): + self.param_dict = dict() + self.torch_class = None + + def to_dict(self): + ret_dict = { + 'module_name': 'torch.optim', + 'item_name': type(self).__name__, + 'kwargs': convert_tuples_to_lists(self.param_dict) + } + return ret_dict + + def check_params(self, params): + + if isinstance( + params, + TorchModule) or isinstance( + params, + Sequential): + params.add_optimizer(self) + params = params.parameters() + else: + params = params + + l_param = list(params) + if len(l_param) == 0: + # fake parameters, for the case that there are only cust model + return [t.nn.Parameter(t.Tensor([0]))] + + return l_param + + def register_optimizer(self, input_): + + if input_ is None: + return + if isinstance( + input_, + TorchModule) or isinstance( + input_, + Sequential): + input_.add_optimizer(self) + + def to_torch_instance(self, parameters): + return self.torch_class(parameters, **self.param_dict) + + +def load_seq(seq_conf: dict) -> None: + + confs = list(dict(sorted(seq_conf.items())).values()) + model_list = [] + for conf in confs: + layer = Loader.from_dict(conf)() + model_list.append(layer) + + return tSequential(*model_list) + + +class Sequential(tSequential): + + def to_dict(self): + """ + get the structure of current sequential + """ + layer_confs = {} + idx = 0 + for k in self._modules: + ordered_name = idx + layer_confs[ordered_name] = self._modules[k].to_dict() + idx += 1 + ret_dict = { + 'module_name': 'fate.components.components.nn.torch.base', + 'item_name': load_seq.__name__, + 'kwargs': {'seq_conf': layer_confs} + } + return ret_dict + + def to_json(self): + return json.dumps(self.to_dict(), indent=4) diff --git a/python/fate/components/components/nn/torch/nn.py b/python/fate/components/components/nn/torch/nn.py new file mode 100644 index 0000000000..add91eed1a --- /dev/null +++ b/python/fate/components/components/nn/torch/nn.py @@ -0,0 +1,2430 @@ +from torch import nn +from fate.components.components.nn.torch.base import TorchModule + + +class Bilinear(nn.modules.linear.Bilinear, TorchModule): + + def __init__( + self, + in1_features, + in2_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in1_features'] = in1_features + self.param_dict['in2_features'] = in2_features + self.param_dict['out_features'] = 
out_features + self.param_dict.update(kwargs) + nn.modules.linear.Bilinear.__init__(self, **self.param_dict) + + +class Identity(nn.modules.linear.Identity, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.linear.Identity.__init__(self, **self.param_dict) + + +class LazyLinear(nn.modules.linear.LazyLinear, TorchModule): + + def __init__( + self, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_features'] = out_features + self.param_dict.update(kwargs) + nn.modules.linear.LazyLinear.__init__(self, **self.param_dict) + + +class Linear(nn.modules.linear.Linear, TorchModule): + + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_features'] = in_features + self.param_dict['out_features'] = out_features + self.param_dict.update(kwargs) + nn.modules.linear.Linear.__init__(self, **self.param_dict) + + +class NonDynamicallyQuantizableLinear( + nn.modules.linear.NonDynamicallyQuantizableLinear, + TorchModule): + + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_features'] = in_features + self.param_dict['out_features'] = out_features + self.param_dict.update(kwargs) + nn.modules.linear.NonDynamicallyQuantizableLinear.__init__( + self, **self.param_dict) + + +class GRU(nn.modules.rnn.GRU, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.rnn.GRU.__init__(self, **self.param_dict) + + +class GRUCell(nn.modules.rnn.GRUCell, TorchModule): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['input_size'] = input_size + self.param_dict['hidden_size'] = hidden_size + self.param_dict.update(kwargs) + nn.modules.rnn.GRUCell.__init__(self, **self.param_dict) + + +class LSTM(nn.modules.rnn.LSTM, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.rnn.LSTM.__init__(self, **self.param_dict) + + +class LSTMCell(nn.modules.rnn.LSTMCell, TorchModule): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['input_size'] = input_size + self.param_dict['hidden_size'] = hidden_size + self.param_dict.update(kwargs) + nn.modules.rnn.LSTMCell.__init__(self, **self.param_dict) + + +class RNN(nn.modules.rnn.RNN, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.rnn.RNN.__init__(self, **self.param_dict) + + +class RNNBase(nn.modules.rnn.RNNBase, TorchModule): + + def __init__( + self, + mode, + input_size, + hidden_size, + 
num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + proj_size=0, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['num_layers'] = num_layers + self.param_dict['bias'] = bias + self.param_dict['batch_first'] = batch_first + self.param_dict['dropout'] = dropout + self.param_dict['bidirectional'] = bidirectional + self.param_dict['proj_size'] = proj_size + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['mode'] = mode + self.param_dict['input_size'] = input_size + self.param_dict['hidden_size'] = hidden_size + self.param_dict.update(kwargs) + nn.modules.rnn.RNNBase.__init__(self, **self.param_dict) + + +class RNNCell(nn.modules.rnn.RNNCell, TorchModule): + + def __init__( + self, + input_size, + hidden_size, + bias=True, + nonlinearity='tanh', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['bias'] = bias + self.param_dict['nonlinearity'] = nonlinearity + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['input_size'] = input_size + self.param_dict['hidden_size'] = hidden_size + self.param_dict.update(kwargs) + nn.modules.rnn.RNNCell.__init__(self, **self.param_dict) + + +class RNNCellBase(nn.modules.rnn.RNNCellBase, TorchModule): + + def __init__( + self, + input_size, + hidden_size, + bias, + num_chunks, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['input_size'] = input_size + self.param_dict['hidden_size'] = hidden_size + self.param_dict['bias'] = bias + self.param_dict['num_chunks'] = num_chunks + self.param_dict.update(kwargs) + nn.modules.rnn.RNNCellBase.__init__(self, **self.param_dict) + + +class Embedding(nn.modules.sparse.Embedding, TorchModule): + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['padding_idx'] = padding_idx + self.param_dict['max_norm'] = max_norm + self.param_dict['norm_type'] = norm_type + self.param_dict['scale_grad_by_freq'] = scale_grad_by_freq + self.param_dict['sparse'] = sparse + self.param_dict['_weight'] = _weight + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_embeddings'] = num_embeddings + self.param_dict['embedding_dim'] = embedding_dim + self.param_dict.update(kwargs) + nn.modules.sparse.Embedding.__init__(self, **self.param_dict) + + +class EmbeddingBag(nn.modules.sparse.EmbeddingBag, TorchModule): + + def __init__( + self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + mode='mean', + sparse=False, + _weight=None, + include_last_offset=False, + padding_idx=None, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['max_norm'] = max_norm + self.param_dict['norm_type'] = norm_type + self.param_dict['scale_grad_by_freq'] = scale_grad_by_freq + self.param_dict['mode'] = mode + self.param_dict['sparse'] = sparse + self.param_dict['_weight'] = _weight + self.param_dict['include_last_offset'] = include_last_offset + self.param_dict['padding_idx'] = padding_idx + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_embeddings'] = num_embeddings + 
self.param_dict['embedding_dim'] = embedding_dim + self.param_dict.update(kwargs) + nn.modules.sparse.EmbeddingBag.__init__(self, **self.param_dict) + + +class AlphaDropout(nn.modules.dropout.AlphaDropout, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.AlphaDropout.__init__(self, **self.param_dict) + + +class Dropout(nn.modules.dropout.Dropout, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.Dropout.__init__(self, **self.param_dict) + + +class Dropout1d(nn.modules.dropout.Dropout1d, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.Dropout1d.__init__(self, **self.param_dict) + + +class Dropout2d(nn.modules.dropout.Dropout2d, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.Dropout2d.__init__(self, **self.param_dict) + + +class Dropout3d(nn.modules.dropout.Dropout3d, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.Dropout3d.__init__(self, **self.param_dict) + + +class FeatureAlphaDropout(nn.modules.dropout.FeatureAlphaDropout, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout.FeatureAlphaDropout.__init__( + self, **self.param_dict) + + +class _DropoutNd(nn.modules.dropout._DropoutNd, TorchModule): + + def __init__(self, p=0.5, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.dropout._DropoutNd.__init__(self, **self.param_dict) + + +class CELU(nn.modules.activation.CELU, TorchModule): + + def __init__(self, alpha=1.0, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['alpha'] = alpha + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.CELU.__init__(self, **self.param_dict) + + +class ELU(nn.modules.activation.ELU, TorchModule): + + def __init__(self, alpha=1.0, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['alpha'] = alpha + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.ELU.__init__(self, **self.param_dict) + + +class GELU(nn.modules.activation.GELU, TorchModule): + + def __init__(self, approximate='none', **kwargs): + TorchModule.__init__(self) + self.param_dict['approximate'] = approximate + self.param_dict.update(kwargs) + nn.modules.activation.GELU.__init__(self, **self.param_dict) + + +class GLU(nn.modules.activation.GLU, TorchModule): + + def __init__(self, dim=-1, **kwargs): + TorchModule.__init__(self) + self.param_dict['dim'] = dim + self.param_dict.update(kwargs) + nn.modules.activation.GLU.__init__(self, **self.param_dict) 
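Every wrapper in this file follows the same pattern: capture the constructor arguments in `self.param_dict`, then forward them to the corresponding `torch.nn` module's `__init__`. The sketch below is a minimal, self-contained illustration of that pattern rather than code from this patch: `_DemoTorchModule` is a hypothetical stand-in for the real `TorchModule` mixin defined elsewhere in this patch, and `DemoReLU`/`to_dict` are names used only for the example.

```python
# Minimal sketch of the capture-and-forward pattern used by the wrappers above.
# _DemoTorchModule is a hypothetical stand-in for the real TorchModule mixin
# (defined elsewhere in this patch, not reproduced here).
import torch.nn as nn


class _DemoTorchModule:
    def __init__(self):
        self.param_dict = {}  # constructor kwargs captured for later serialization

    def to_dict(self):
        # enough information to recreate the layer from its name and kwargs
        return {"module": type(self).__name__, "kwargs": dict(self.param_dict)}


class DemoReLU(nn.ReLU, _DemoTorchModule):
    def __init__(self, inplace=False, **kwargs):
        _DemoTorchModule.__init__(self)
        self.param_dict["inplace"] = inplace
        self.param_dict.update(kwargs)
        nn.ReLU.__init__(self, **self.param_dict)


if __name__ == "__main__":
    act = DemoReLU(inplace=True)
    print(act.to_dict())  # {'module': 'DemoReLU', 'kwargs': {'inplace': True}}
```

The same three steps (initialize the mixin, fill `param_dict`, call the torch constructor with `**self.param_dict`) repeat verbatim in every class that follows.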
+ + +class Hardshrink(nn.modules.activation.Hardshrink, TorchModule): + + def __init__(self, lambd=0.5, **kwargs): + TorchModule.__init__(self) + self.param_dict['lambd'] = lambd + self.param_dict.update(kwargs) + nn.modules.activation.Hardshrink.__init__(self, **self.param_dict) + + +class Hardsigmoid(nn.modules.activation.Hardsigmoid, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.Hardsigmoid.__init__(self, **self.param_dict) + + +class Hardswish(nn.modules.activation.Hardswish, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.Hardswish.__init__(self, **self.param_dict) + + +class Hardtanh(nn.modules.activation.Hardtanh, TorchModule): + + def __init__( + self, + min_val=-1.0, + max_val=1.0, + inplace=False, + min_value=None, + max_value=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['min_val'] = min_val + self.param_dict['max_val'] = max_val + self.param_dict['inplace'] = inplace + self.param_dict['min_value'] = min_value + self.param_dict['max_value'] = max_value + self.param_dict.update(kwargs) + nn.modules.activation.Hardtanh.__init__(self, **self.param_dict) + + +class LeakyReLU(nn.modules.activation.LeakyReLU, TorchModule): + + def __init__(self, negative_slope=0.01, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['negative_slope'] = negative_slope + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.LeakyReLU.__init__(self, **self.param_dict) + + +class LogSigmoid(nn.modules.activation.LogSigmoid, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.activation.LogSigmoid.__init__(self, **self.param_dict) + + +class LogSoftmax(nn.modules.activation.LogSoftmax, TorchModule): + + def __init__(self, dim=None, **kwargs): + TorchModule.__init__(self) + self.param_dict['dim'] = dim + self.param_dict.update(kwargs) + nn.modules.activation.LogSoftmax.__init__(self, **self.param_dict) + + +class Mish(nn.modules.activation.Mish, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.Mish.__init__(self, **self.param_dict) + + +class MultiheadAttention(nn.modules.activation.MultiheadAttention, TorchModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['dropout'] = dropout + self.param_dict['bias'] = bias + self.param_dict['add_bias_kv'] = add_bias_kv + self.param_dict['add_zero_attn'] = add_zero_attn + self.param_dict['kdim'] = kdim + self.param_dict['vdim'] = vdim + self.param_dict['batch_first'] = batch_first + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['embed_dim'] = embed_dim + self.param_dict['num_heads'] = num_heads + self.param_dict.update(kwargs) + nn.modules.activation.MultiheadAttention.__init__( + self, **self.param_dict) + + +class PReLU(nn.modules.activation.PReLU, TorchModule): + + def __init__( + self, + num_parameters=1, + init=0.25, + device=None, + dtype=None, + 
**kwargs): + TorchModule.__init__(self) + self.param_dict['num_parameters'] = num_parameters + self.param_dict['init'] = init + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.activation.PReLU.__init__(self, **self.param_dict) + + +class RReLU(nn.modules.activation.RReLU, TorchModule): + + def __init__( + self, + lower=0.125, + upper=0.3333333333333333, + inplace=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['lower'] = lower + self.param_dict['upper'] = upper + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.RReLU.__init__(self, **self.param_dict) + + +class ReLU(nn.modules.activation.ReLU, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.ReLU.__init__(self, **self.param_dict) + + +class ReLU6(nn.modules.activation.ReLU6, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.ReLU6.__init__(self, **self.param_dict) + + +class SELU(nn.modules.activation.SELU, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.SELU.__init__(self, **self.param_dict) + + +class SiLU(nn.modules.activation.SiLU, TorchModule): + + def __init__(self, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict.update(kwargs) + nn.modules.activation.SiLU.__init__(self, **self.param_dict) + + +class Sigmoid(nn.modules.activation.Sigmoid, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.activation.Sigmoid.__init__(self, **self.param_dict) + + +class Softmax(nn.modules.activation.Softmax, TorchModule): + + def __init__(self, dim=None, **kwargs): + TorchModule.__init__(self) + self.param_dict['dim'] = dim + self.param_dict.update(kwargs) + nn.modules.activation.Softmax.__init__(self, **self.param_dict) + + +class Softmax2d(nn.modules.activation.Softmax2d, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.activation.Softmax2d.__init__(self, **self.param_dict) + + +class Softmin(nn.modules.activation.Softmin, TorchModule): + + def __init__(self, dim=None, **kwargs): + TorchModule.__init__(self) + self.param_dict['dim'] = dim + self.param_dict.update(kwargs) + nn.modules.activation.Softmin.__init__(self, **self.param_dict) + + +class Softplus(nn.modules.activation.Softplus, TorchModule): + + def __init__(self, beta=1, threshold=20, **kwargs): + TorchModule.__init__(self) + self.param_dict['beta'] = beta + self.param_dict['threshold'] = threshold + self.param_dict.update(kwargs) + nn.modules.activation.Softplus.__init__(self, **self.param_dict) + + +class Softshrink(nn.modules.activation.Softshrink, TorchModule): + + def __init__(self, lambd=0.5, **kwargs): + TorchModule.__init__(self) + self.param_dict['lambd'] = lambd + self.param_dict.update(kwargs) + nn.modules.activation.Softshrink.__init__(self, **self.param_dict) + + +class Softsign(nn.modules.activation.Softsign, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + 
nn.modules.activation.Softsign.__init__(self, **self.param_dict) + + +class Tanh(nn.modules.activation.Tanh, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.activation.Tanh.__init__(self, **self.param_dict) + + +class Tanhshrink(nn.modules.activation.Tanhshrink, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.activation.Tanhshrink.__init__(self, **self.param_dict) + + +class Threshold(nn.modules.activation.Threshold, TorchModule): + + def __init__(self, threshold, value, inplace=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['inplace'] = inplace + self.param_dict['threshold'] = threshold + self.param_dict['value'] = value + self.param_dict.update(kwargs) + nn.modules.activation.Threshold.__init__(self, **self.param_dict) + + +class Conv1d(nn.modules.conv.Conv1d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.Conv1d.__init__(self, **self.param_dict) + + +class Conv2d(nn.modules.conv.Conv2d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.Conv2d.__init__(self, **self.param_dict) + + +class Conv3d(nn.modules.conv.Conv3d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.Conv3d.__init__(self, **self.param_dict) + + +class ConvTranspose1d(nn.modules.conv.ConvTranspose1d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + 
padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.ConvTranspose1d.__init__(self, **self.param_dict) + + +class ConvTranspose2d(nn.modules.conv.ConvTranspose2d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.ConvTranspose2d.__init__(self, **self.param_dict) + + +class ConvTranspose3d(nn.modules.conv.ConvTranspose3d, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.ConvTranspose3d.__init__(self, **self.param_dict) + + +class LazyConv1d(nn.modules.conv.LazyConv1d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConv1d.__init__(self, **self.param_dict) + + +class LazyConv2d(nn.modules.conv.LazyConv2d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + 
padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConv2d.__init__(self, **self.param_dict) + + +class LazyConv3d(nn.modules.conv.LazyConv3d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConv3d.__init__(self, **self.param_dict) + + +class LazyConvTranspose1d(nn.modules.conv.LazyConvTranspose1d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConvTranspose1d.__init__(self, **self.param_dict) + + +class LazyConvTranspose2d(nn.modules.conv.LazyConvTranspose2d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConvTranspose2d.__init__(self, **self.param_dict) + + +class LazyConvTranspose3d(nn.modules.conv.LazyConvTranspose3d, TorchModule): + + def __init__( + self, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode='zeros', + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + 
self.param_dict['padding'] = padding + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['dilation'] = dilation + self.param_dict['padding_mode'] = padding_mode + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.conv.LazyConvTranspose3d.__init__(self, **self.param_dict) + + +class _ConvNd(nn.modules.conv._ConvNd, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['transposed'] = transposed + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict.update(kwargs) + nn.modules.conv._ConvNd.__init__(self, **self.param_dict) + + +class _ConvTransposeMixin(nn.modules.conv._ConvTransposeMixin, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.conv._ConvTransposeMixin.__init__(self, **self.param_dict) + + +class _ConvTransposeNd(nn.modules.conv._ConvTransposeNd, TorchModule): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['in_channels'] = in_channels + self.param_dict['out_channels'] = out_channels + self.param_dict['kernel_size'] = kernel_size + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['transposed'] = transposed + self.param_dict['output_padding'] = output_padding + self.param_dict['groups'] = groups + self.param_dict['bias'] = bias + self.param_dict['padding_mode'] = padding_mode + self.param_dict.update(kwargs) + nn.modules.conv._ConvTransposeNd.__init__(self, **self.param_dict) + + +class _LazyConvXdMixin(nn.modules.conv._LazyConvXdMixin, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.conv._LazyConvXdMixin.__init__(self, **self.param_dict) + + +class Transformer(nn.modules.transformer.Transformer, TorchModule): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + custom_encoder=None, + custom_decoder=None, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['d_model'] = d_model + self.param_dict['nhead'] = nhead + self.param_dict['num_encoder_layers'] = num_encoder_layers + self.param_dict['num_decoder_layers'] = num_decoder_layers + self.param_dict['dim_feedforward'] = 
dim_feedforward + self.param_dict['dropout'] = dropout + self.param_dict['custom_encoder'] = custom_encoder + self.param_dict['custom_decoder'] = custom_decoder + self.param_dict['layer_norm_eps'] = layer_norm_eps + self.param_dict['batch_first'] = batch_first + self.param_dict['norm_first'] = norm_first + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.transformer.Transformer.__init__(self, **self.param_dict) + + +class TransformerDecoder(nn.modules.transformer.TransformerDecoder, TorchModule): + + def __init__(self, decoder_layer, num_layers, norm=None, **kwargs): + TorchModule.__init__(self) + self.param_dict['norm'] = norm + self.param_dict['decoder_layer'] = decoder_layer + self.param_dict['num_layers'] = num_layers + self.param_dict.update(kwargs) + nn.modules.transformer.TransformerDecoder.__init__( + self, **self.param_dict) + + +class TransformerDecoderLayer( + nn.modules.transformer.TransformerDecoderLayer, + TorchModule): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['dim_feedforward'] = dim_feedforward + self.param_dict['dropout'] = dropout + self.param_dict['layer_norm_eps'] = layer_norm_eps + self.param_dict['batch_first'] = batch_first + self.param_dict['norm_first'] = norm_first + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['d_model'] = d_model + self.param_dict['nhead'] = nhead + self.param_dict.update(kwargs) + nn.modules.transformer.TransformerDecoderLayer.__init__( + self, **self.param_dict) + + +class TransformerEncoder(nn.modules.transformer.TransformerEncoder, TorchModule): + + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True, + **kwargs): + TorchModule.__init__(self) + self.param_dict['norm'] = norm + self.param_dict['enable_nested_tensor'] = enable_nested_tensor + self.param_dict['mask_check'] = mask_check + self.param_dict['encoder_layer'] = encoder_layer + self.param_dict['num_layers'] = num_layers + self.param_dict.update(kwargs) + nn.modules.transformer.TransformerEncoder.__init__( + self, **self.param_dict) + + +class TransformerEncoderLayer( + nn.modules.transformer.TransformerEncoderLayer, + TorchModule): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + layer_norm_eps=1e-05, + batch_first=False, + norm_first=False, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['dim_feedforward'] = dim_feedforward + self.param_dict['dropout'] = dropout + self.param_dict['layer_norm_eps'] = layer_norm_eps + self.param_dict['batch_first'] = batch_first + self.param_dict['norm_first'] = norm_first + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['d_model'] = d_model + self.param_dict['nhead'] = nhead + self.param_dict.update(kwargs) + nn.modules.transformer.TransformerEncoderLayer.__init__( + self, **self.param_dict) + + +class AdaptiveAvgPool1d(nn.modules.pooling.AdaptiveAvgPool1d, TorchModule): + + def __init__(self, output_size, **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveAvgPool1d.__init__(self, **self.param_dict) + + +class AdaptiveAvgPool2d(nn.modules.pooling.AdaptiveAvgPool2d, 
TorchModule): + + def __init__(self, output_size, **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveAvgPool2d.__init__(self, **self.param_dict) + + +class AdaptiveAvgPool3d(nn.modules.pooling.AdaptiveAvgPool3d, TorchModule): + + def __init__(self, output_size, **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveAvgPool3d.__init__(self, **self.param_dict) + + +class AdaptiveMaxPool1d(nn.modules.pooling.AdaptiveMaxPool1d, TorchModule): + + def __init__(self, output_size, return_indices=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['return_indices'] = return_indices + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveMaxPool1d.__init__(self, **self.param_dict) + + +class AdaptiveMaxPool2d(nn.modules.pooling.AdaptiveMaxPool2d, TorchModule): + + def __init__(self, output_size, return_indices=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['return_indices'] = return_indices + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveMaxPool2d.__init__(self, **self.param_dict) + + +class AdaptiveMaxPool3d(nn.modules.pooling.AdaptiveMaxPool3d, TorchModule): + + def __init__(self, output_size, return_indices=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['return_indices'] = return_indices + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling.AdaptiveMaxPool3d.__init__(self, **self.param_dict) + + +class AvgPool1d(nn.modules.pooling.AvgPool1d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['count_include_pad'] = count_include_pad + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.AvgPool1d.__init__(self, **self.param_dict) + + +class AvgPool2d(nn.modules.pooling.AvgPool2d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['count_include_pad'] = count_include_pad + self.param_dict['divisor_override'] = divisor_override + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.AvgPool2d.__init__(self, **self.param_dict) + + +class AvgPool3d(nn.modules.pooling.AvgPool3d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['count_include_pad'] = count_include_pad + self.param_dict['divisor_override'] = divisor_override + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.AvgPool3d.__init__(self, **self.param_dict) + + +class 
FractionalMaxPool2d(nn.modules.pooling.FractionalMaxPool2d, TorchModule): + + def __init__( + self, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict['output_ratio'] = output_ratio + self.param_dict['return_indices'] = return_indices + self.param_dict['_random_samples'] = _random_samples + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.FractionalMaxPool2d.__init__( + self, **self.param_dict) + + +class FractionalMaxPool3d(nn.modules.pooling.FractionalMaxPool3d, TorchModule): + + def __init__( + self, + kernel_size, + output_size=None, + output_ratio=None, + return_indices=False, + _random_samples=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict['output_ratio'] = output_ratio + self.param_dict['return_indices'] = return_indices + self.param_dict['_random_samples'] = _random_samples + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.FractionalMaxPool3d.__init__( + self, **self.param_dict) + + +class LPPool1d(nn.modules.pooling.LPPool1d, TorchModule): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['norm_type'] = norm_type + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.LPPool1d.__init__(self, **self.param_dict) + + +class LPPool2d(nn.modules.pooling.LPPool2d, TorchModule): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['norm_type'] = norm_type + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.LPPool2d.__init__(self, **self.param_dict) + + +class MaxPool1d(nn.modules.pooling.MaxPool1d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['return_indices'] = return_indices + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxPool1d.__init__(self, **self.param_dict) + + +class MaxPool2d(nn.modules.pooling.MaxPool2d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['return_indices'] = return_indices + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxPool2d.__init__(self, **self.param_dict) + + +class MaxPool3d(nn.modules.pooling.MaxPool3d, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): + 
TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['return_indices'] = return_indices + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxPool3d.__init__(self, **self.param_dict) + + +class MaxUnpool1d(nn.modules.pooling.MaxUnpool1d, TorchModule): + + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxUnpool1d.__init__(self, **self.param_dict) + + +class MaxUnpool2d(nn.modules.pooling.MaxUnpool2d, TorchModule): + + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxUnpool2d.__init__(self, **self.param_dict) + + +class MaxUnpool3d(nn.modules.pooling.MaxUnpool3d, TorchModule): + + def __init__(self, kernel_size, stride=None, padding=0, **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling.MaxUnpool3d.__init__(self, **self.param_dict) + + +class _AdaptiveAvgPoolNd(nn.modules.pooling._AdaptiveAvgPoolNd, TorchModule): + + def __init__(self, output_size, **kwargs): + TorchModule.__init__(self) + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling._AdaptiveAvgPoolNd.__init__(self, **self.param_dict) + + +class _AdaptiveMaxPoolNd(nn.modules.pooling._AdaptiveMaxPoolNd, TorchModule): + + def __init__(self, output_size, return_indices=False, **kwargs): + TorchModule.__init__(self) + self.param_dict['return_indices'] = return_indices + self.param_dict['output_size'] = output_size + self.param_dict.update(kwargs) + nn.modules.pooling._AdaptiveMaxPoolNd.__init__(self, **self.param_dict) + + +class _AvgPoolNd(nn.modules.pooling._AvgPoolNd, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.pooling._AvgPoolNd.__init__(self, **self.param_dict) + + +class _LPPoolNd(nn.modules.pooling._LPPoolNd, TorchModule): + + def __init__( + self, + norm_type, + kernel_size, + stride=None, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['norm_type'] = norm_type + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling._LPPoolNd.__init__(self, **self.param_dict) + + +class _MaxPoolNd(nn.modules.pooling._MaxPoolNd, TorchModule): + + def __init__( + self, + kernel_size, + stride=None, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['stride'] = stride + self.param_dict['padding'] = padding + self.param_dict['dilation'] = dilation + self.param_dict['return_indices'] = return_indices + self.param_dict['ceil_mode'] = ceil_mode + self.param_dict['kernel_size'] = kernel_size + self.param_dict.update(kwargs) + nn.modules.pooling._MaxPoolNd.__init__(self, **self.param_dict) + + 
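The pooling, batch-norm, padding and loss wrappers that follow repeat the identical capture-and-forward pattern. A plausible reason for recording the kwargs (an assumption about intent, not something stated in this diff) is that a layer can then be rebuilt from nothing but its class name and the captured arguments; the helper below is illustrative only and is not an API introduced by this patch.

```python
# Hypothetical sketch: recreating a torch layer from a recorded (name, kwargs) pair.
# get_layer() is illustrative only; it is not part of this patch.
import torch
import torch.nn as nn


def get_layer(name: str, kwargs: dict) -> nn.Module:
    cls = getattr(nn, name)   # resolve the class from torch.nn by name
    return cls(**kwargs)      # re-apply the captured constructor arguments


if __name__ == "__main__":
    recorded = ("MaxPool2d", {"kernel_size": 2, "stride": 2, "padding": 0})
    pool = get_layer(*recorded)
    x = torch.randn(1, 3, 8, 8)
    print(pool(x).shape)      # torch.Size([1, 3, 4, 4])
```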
+class _MaxUnpoolNd(nn.modules.pooling._MaxUnpoolNd, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.pooling._MaxUnpoolNd.__init__(self, **self.param_dict) + + +class BatchNorm1d(nn.modules.batchnorm.BatchNorm1d, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm.BatchNorm1d.__init__(self, **self.param_dict) + + +class BatchNorm2d(nn.modules.batchnorm.BatchNorm2d, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm.BatchNorm2d.__init__(self, **self.param_dict) + + +class BatchNorm3d(nn.modules.batchnorm.BatchNorm3d, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm.BatchNorm3d.__init__(self, **self.param_dict) + + +class LazyBatchNorm1d(nn.modules.batchnorm.LazyBatchNorm1d, TorchModule): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.batchnorm.LazyBatchNorm1d.__init__(self, **self.param_dict) + + +class LazyBatchNorm2d(nn.modules.batchnorm.LazyBatchNorm2d, TorchModule): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.batchnorm.LazyBatchNorm2d.__init__(self, **self.param_dict) + + +class LazyBatchNorm3d(nn.modules.batchnorm.LazyBatchNorm3d, TorchModule): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + 
track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.batchnorm.LazyBatchNorm3d.__init__(self, **self.param_dict) + + +class SyncBatchNorm(nn.modules.batchnorm.SyncBatchNorm, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + process_group=None, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['process_group'] = process_group + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm.SyncBatchNorm.__init__(self, **self.param_dict) + + +class _BatchNorm(nn.modules.batchnorm._BatchNorm, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm._BatchNorm.__init__(self, **self.param_dict) + + +class _LazyNormBase(nn.modules.batchnorm._LazyNormBase, TorchModule): + + def __init__( + self, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict.update(kwargs) + nn.modules.batchnorm._LazyNormBase.__init__(self, **self.param_dict) + + +class _NormBase(nn.modules.batchnorm._NormBase, TorchModule): + + def __init__( + self, + num_features, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['eps'] = eps + self.param_dict['momentum'] = momentum + self.param_dict['affine'] = affine + self.param_dict['track_running_stats'] = track_running_stats + self.param_dict['device'] = device + self.param_dict['dtype'] = dtype + self.param_dict['num_features'] = num_features + self.param_dict.update(kwargs) + nn.modules.batchnorm._NormBase.__init__(self, **self.param_dict) + + +class ConstantPad1d(nn.modules.padding.ConstantPad1d, TorchModule): + + def __init__(self, padding, value, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict['value'] = value + self.param_dict.update(kwargs) + nn.modules.padding.ConstantPad1d.__init__(self, **self.param_dict) + + +class ConstantPad2d(nn.modules.padding.ConstantPad2d, TorchModule): + + def __init__(self, padding, value, **kwargs): + TorchModule.__init__(self) + 
self.param_dict['padding'] = padding + self.param_dict['value'] = value + self.param_dict.update(kwargs) + nn.modules.padding.ConstantPad2d.__init__(self, **self.param_dict) + + +class ConstantPad3d(nn.modules.padding.ConstantPad3d, TorchModule): + + def __init__(self, padding, value, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict['value'] = value + self.param_dict.update(kwargs) + nn.modules.padding.ConstantPad3d.__init__(self, **self.param_dict) + + +class ReflectionPad1d(nn.modules.padding.ReflectionPad1d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReflectionPad1d.__init__(self, **self.param_dict) + + +class ReflectionPad2d(nn.modules.padding.ReflectionPad2d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReflectionPad2d.__init__(self, **self.param_dict) + + +class ReflectionPad3d(nn.modules.padding.ReflectionPad3d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReflectionPad3d.__init__(self, **self.param_dict) + + +class ReplicationPad1d(nn.modules.padding.ReplicationPad1d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReplicationPad1d.__init__(self, **self.param_dict) + + +class ReplicationPad2d(nn.modules.padding.ReplicationPad2d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReplicationPad2d.__init__(self, **self.param_dict) + + +class ReplicationPad3d(nn.modules.padding.ReplicationPad3d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ReplicationPad3d.__init__(self, **self.param_dict) + + +class ZeroPad2d(nn.modules.padding.ZeroPad2d, TorchModule): + + def __init__(self, padding, **kwargs): + TorchModule.__init__(self) + self.param_dict['padding'] = padding + self.param_dict.update(kwargs) + nn.modules.padding.ZeroPad2d.__init__(self, **self.param_dict) + + +class _ConstantPadNd(nn.modules.padding._ConstantPadNd, TorchModule): + + def __init__(self, value, **kwargs): + TorchModule.__init__(self) + self.param_dict['value'] = value + self.param_dict.update(kwargs) + nn.modules.padding._ConstantPadNd.__init__(self, **self.param_dict) + + +class _ReflectionPadNd(nn.modules.padding._ReflectionPadNd, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.padding._ReflectionPadNd.__init__(self, **self.param_dict) + + +class _ReplicationPadNd(nn.modules.padding._ReplicationPadNd, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.padding._ReplicationPadNd.__init__(self, **self.param_dict) + + +class BCELoss(nn.modules.loss.BCELoss, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + 
self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.BCELoss.__init__(self, **self.param_dict) + + +class BCEWithLogitsLoss(nn.modules.loss.BCEWithLogitsLoss, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + pos_weight=None, + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict['pos_weight'] = pos_weight + self.param_dict.update(kwargs) + nn.modules.loss.BCEWithLogitsLoss.__init__(self, **self.param_dict) + + +class CTCLoss(nn.modules.loss.CTCLoss, TorchModule): + + def __init__( + self, + blank=0, + reduction='mean', + zero_infinity=False, + **kwargs): + TorchModule.__init__(self) + self.param_dict['blank'] = blank + self.param_dict['reduction'] = reduction + self.param_dict['zero_infinity'] = zero_infinity + self.param_dict.update(kwargs) + nn.modules.loss.CTCLoss.__init__(self, **self.param_dict) + + +class CosineEmbeddingLoss(nn.modules.loss.CosineEmbeddingLoss, TorchModule): + + def __init__( + self, + margin=0.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['margin'] = margin + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.CosineEmbeddingLoss.__init__(self, **self.param_dict) + + +class CrossEntropyLoss(nn.modules.loss.CrossEntropyLoss, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + label_smoothing=0.0, + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['ignore_index'] = ignore_index + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict['label_smoothing'] = label_smoothing + self.param_dict.update(kwargs) + nn.modules.loss.CrossEntropyLoss.__init__(self, **self.param_dict) + + +class GaussianNLLLoss(nn.modules.loss.GaussianNLLLoss, TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.loss.GaussianNLLLoss.__init__(self, **self.param_dict) + + +class HingeEmbeddingLoss(nn.modules.loss.HingeEmbeddingLoss, TorchModule): + + def __init__( + self, + margin=1.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['margin'] = margin + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.HingeEmbeddingLoss.__init__(self, **self.param_dict) + + +class HuberLoss(nn.modules.loss.HuberLoss, TorchModule): + + def __init__(self, reduction='mean', delta=1.0, **kwargs): + TorchModule.__init__(self) + self.param_dict['reduction'] = reduction + self.param_dict['delta'] = delta + self.param_dict.update(kwargs) + nn.modules.loss.HuberLoss.__init__(self, **self.param_dict) + + +class KLDivLoss(nn.modules.loss.KLDivLoss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + log_target=False, + **kwargs): + TorchModule.__init__(self) + 
self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict['log_target'] = log_target + self.param_dict.update(kwargs) + nn.modules.loss.KLDivLoss.__init__(self, **self.param_dict) + + +class L1Loss(nn.modules.loss.L1Loss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.L1Loss.__init__(self, **self.param_dict) + + +class MSELoss(nn.modules.loss.MSELoss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.MSELoss.__init__(self, **self.param_dict) + + +class MarginRankingLoss(nn.modules.loss.MarginRankingLoss, TorchModule): + + def __init__( + self, + margin=0.0, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['margin'] = margin + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.MarginRankingLoss.__init__(self, **self.param_dict) + + +class MultiLabelMarginLoss(nn.modules.loss.MultiLabelMarginLoss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.MultiLabelMarginLoss.__init__(self, **self.param_dict) + + +class MultiLabelSoftMarginLoss( + nn.modules.loss.MultiLabelSoftMarginLoss, + TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.MultiLabelSoftMarginLoss.__init__( + self, **self.param_dict) + + +class MultiMarginLoss(nn.modules.loss.MultiMarginLoss, TorchModule): + + def __init__( + self, + p=1, + margin=1.0, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['p'] = p + self.param_dict['margin'] = margin + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.MultiMarginLoss.__init__(self, **self.param_dict) + + +class NLLLoss(nn.modules.loss.NLLLoss, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['ignore_index'] = ignore_index + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + 
nn.modules.loss.NLLLoss.__init__(self, **self.param_dict) + + +class NLLLoss2d(nn.modules.loss.NLLLoss2d, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['ignore_index'] = ignore_index + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.NLLLoss2d.__init__(self, **self.param_dict) + + +class PoissonNLLLoss(nn.modules.loss.PoissonNLLLoss, TorchModule): + + def __init__( + self, + log_input=True, + full=False, + size_average=None, + eps=1e-08, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['log_input'] = log_input + self.param_dict['full'] = full + self.param_dict['size_average'] = size_average + self.param_dict['eps'] = eps + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.PoissonNLLLoss.__init__(self, **self.param_dict) + + +class SmoothL1Loss(nn.modules.loss.SmoothL1Loss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + beta=1.0, + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict['beta'] = beta + self.param_dict.update(kwargs) + nn.modules.loss.SmoothL1Loss.__init__(self, **self.param_dict) + + +class SoftMarginLoss(nn.modules.loss.SoftMarginLoss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.SoftMarginLoss.__init__(self, **self.param_dict) + + +class TripletMarginLoss(nn.modules.loss.TripletMarginLoss, TorchModule): + + def __init__( + self, + margin=1.0, + p=2.0, + eps=1e-06, + swap=False, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['margin'] = margin + self.param_dict['p'] = p + self.param_dict['eps'] = eps + self.param_dict['swap'] = swap + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss.TripletMarginLoss.__init__(self, **self.param_dict) + + +class TripletMarginWithDistanceLoss( + nn.modules.loss.TripletMarginWithDistanceLoss, + TorchModule): + + def __init__(self, **kwargs): + TorchModule.__init__(self) + self.param_dict.update(kwargs) + nn.modules.loss.TripletMarginWithDistanceLoss.__init__( + self, **self.param_dict) + + +class _Loss(nn.modules.loss._Loss, TorchModule): + + def __init__( + self, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + TorchModule.__init__(self) + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss._Loss.__init__(self, **self.param_dict) + + +class _WeightedLoss(nn.modules.loss._WeightedLoss, TorchModule): + + def __init__( + self, + weight=None, + size_average=None, + reduce=None, + reduction='mean', + **kwargs): + 
TorchModule.__init__(self) + self.param_dict['weight'] = weight + self.param_dict['size_average'] = size_average + self.param_dict['reduce'] = reduce + self.param_dict['reduction'] = reduction + self.param_dict.update(kwargs) + nn.modules.loss._WeightedLoss.__init__(self, **self.param_dict) diff --git a/python/fate/components/components/nn/torch/optim.py b/python/fate/components/components/nn/torch/optim.py new file mode 100644 index 0000000000..7069af916a --- /dev/null +++ b/python/fate/components/components/nn/torch/optim.py @@ -0,0 +1,495 @@ +from torch import optim +from fate.components.components.nn.torch.base import TorchOptimizer + + +class ASGD(optim.ASGD, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.01, + lambd=0.0001, + alpha=0.75, + t0=1000000.0, + weight_decay=0, + foreach=None, + maximize=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['lambd'] = lambd + self.param_dict['alpha'] = alpha + self.param_dict['t0'] = t0 + self.param_dict['weight_decay'] = weight_decay + self.param_dict['foreach'] = foreach + self.param_dict['maximize'] = maximize + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.ASGD.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer ASGD without initiated parameters'.format( + type(self).__name__) + + +class Adadelta(optim.Adadelta, TorchOptimizer): + + def __init__( + self, + params=None, + lr=1.0, + rho=0.9, + eps=1e-06, + weight_decay=0, + foreach=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['rho'] = rho + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['foreach'] = foreach + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.Adadelta.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer Adadelta without initiated parameters'.format( + type(self).__name__) + + +class Adagrad(optim.Adagrad, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.01, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + foreach=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['lr_decay'] = lr_decay + self.param_dict['weight_decay'] = weight_decay + self.param_dict['initial_accumulator_value'] = initial_accumulator_value + self.param_dict['eps'] = eps + self.param_dict['foreach'] = foreach + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.Adagrad.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer Adagrad without initiated parameters'.format( + type(self).__name__) + + +class Adam(optim.Adam, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + amsgrad=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + 
self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['amsgrad'] = amsgrad + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.Adam.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer Adam without initiated parameters'.format( + type(self).__name__) + + +class AdamW(optim.AdamW, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0.01, + amsgrad=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['amsgrad'] = amsgrad + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.AdamW.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer AdamW without initiated parameters'.format( + type(self).__name__) + + +class Adamax(optim.Adamax, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.002, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + foreach=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['foreach'] = foreach + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.Adamax.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer Adamax without initiated parameters'.format( + type(self).__name__) + + +class LBFGS(optim.LBFGS, TorchOptimizer): + + def __init__( + self, + params=None, + lr=1, + max_iter=20, + max_eval=None, + tolerance_grad=1e-07, + tolerance_change=1e-09, + history_size=100, + line_search_fn=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['max_iter'] = max_iter + self.param_dict['max_eval'] = max_eval + self.param_dict['tolerance_grad'] = tolerance_grad + self.param_dict['tolerance_change'] = tolerance_change + self.param_dict['history_size'] = history_size + self.param_dict['line_search_fn'] = line_search_fn + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.LBFGS.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer LBFGS without initiated parameters'.format( + type(self).__name__) + + +class NAdam(optim.NAdam, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.002, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + momentum_decay=0.004, + foreach=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay 
+ self.param_dict['momentum_decay'] = momentum_decay + self.param_dict['foreach'] = foreach + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.NAdam.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer NAdam without initiated parameters'.format( + type(self).__name__) + + +class RAdam(optim.RAdam, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + weight_decay=0, + foreach=None, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['foreach'] = foreach + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.RAdam.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer RAdam without initiated parameters'.format( + type(self).__name__) + + +class RMSprop(optim.RMSprop, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.01, + alpha=0.99, + eps=1e-08, + weight_decay=0, + momentum=0, + centered=False, + foreach=None, + maximize=False, + differentiable=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['alpha'] = alpha + self.param_dict['eps'] = eps + self.param_dict['weight_decay'] = weight_decay + self.param_dict['momentum'] = momentum + self.param_dict['centered'] = centered + self.param_dict['foreach'] = foreach + self.param_dict['maximize'] = maximize + self.param_dict['differentiable'] = differentiable + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.RMSprop.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer RMSprop without initiated parameters'.format( + type(self).__name__) + + +class Rprop(optim.Rprop, TorchOptimizer): + + def __init__( + self, params=None, lr=0.01, etas=( + 0.5, 1.2), step_sizes=( + 1e-06, 50), foreach=None, maximize=False, ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['etas'] = etas + self.param_dict['step_sizes'] = step_sizes + self.param_dict['foreach'] = foreach + self.param_dict['maximize'] = maximize + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.Rprop.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer Rprop without initiated parameters'.format( + type(self).__name__) + + +class SGD(optim.SGD, TorchOptimizer): + + def __init__( + self, + lr, + params=None, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['momentum'] = momentum + self.param_dict['dampening'] = dampening + self.param_dict['weight_decay'] = 
weight_decay + self.param_dict['nesterov'] = nesterov + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.SGD.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer SGD without initiated parameters'.format( + type(self).__name__) + + +class SparseAdam(optim.SparseAdam, TorchOptimizer): + + def __init__( + self, + params=None, + lr=0.001, + betas=( + 0.9, + 0.999), + eps=1e-08, + maximize=False, + ): + TorchOptimizer.__init__(self) + self.param_dict['lr'] = lr + self.param_dict['betas'] = betas + self.param_dict['eps'] = eps + self.param_dict['maximize'] = maximize + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.SparseAdam.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except BaseException: + return 'Optimizer SparseAdam without initiated parameters'.format( + type(self).__name__) diff --git a/python/fate/arch/context/io/data/__init__.py b/python/fate/components/components/nn/utils/__init__.py similarity index 100% rename from python/fate/arch/context/io/data/__init__.py rename to python/fate/components/components/nn/utils/__init__.py diff --git a/python/fate/components/components/nn/utils/extract_pytorch_optim.py b/python/fate/components/components/nn/utils/extract_pytorch_optim.py new file mode 100644 index 0000000000..93e238b0db --- /dev/null +++ b/python/fate/components/components/nn/utils/extract_pytorch_optim.py @@ -0,0 +1,78 @@ +import inspect +from torch import optim +from fate.components.components.nn.utils.extract_torch_modules import extract_init_param, Required +from torch.optim.optimizer import required + + +def code_assembly(param, nn_class): + para_str = "" + non_default_param = "" + init_str = """""" + special_param = '' + for k, v in param.items(): + + if k == 'params': + k = 'params' + v = None + special_param = k + '=' + str(v) + ', ' + continue + else: + new_para = "\n self.param_dict['{}'] = {}".format(k, k) + init_str += new_para + + if isinstance(v, Required) or v == required: + non_default_param += str(k) + non_default_param += ', ' + continue + + para_str += str(k) + if isinstance(v, str): + para_str += "='{}'".format(v) + else: + para_str += "={}".format(str(v)) + para_str += ', ' + + para_str = non_default_param + special_param + para_str + + init_ = """ + def __init__(self, {}): + FateTorchOptimizer.__init__(self){} + self.torch_class = type(self).__bases__[0] + + if params is None: + return + + params = self.check_params(params) + + self.torch_class.__init__(self, params, **self.param_dict) + + # optim.{}.__init__(self, **self.param_dict) + + def __repr__(self): + try: + return type(self).__bases__[0].__repr__(self) + except: + return 'Optimizer {} without initiated parameters'.format(type(self).__name__) + + """.format(para_str, init_str, nn_class, nn_class) + + code = """ +class {}(optim.{}, FateTorchOptimizer): + {} + """.format(nn_class, nn_class, init_) + + return code + + +if __name__ == '__main__': + + memb = inspect.getmembers(optim) + + module_str = """""" + for k, v in memb: + if inspect.isclass(v) and k != 'Optimizer': + param = extract_init_param(v) + code = code_assembly(param, k) + module_str += code + + 
open('../torch/optim.py', 'w').write(module_str) diff --git a/python/fate/components/components/nn/utils/extract_torch_modules.py b/python/fate/components/components/nn/utils/extract_torch_modules.py new file mode 100644 index 0000000000..023a7e8583 --- /dev/null +++ b/python/fate/components/components/nn/utils/extract_torch_modules.py @@ -0,0 +1,123 @@ +import inspect +from torch.nn.modules import linear, activation, rnn, dropout, sparse, pooling, conv, transformer, batchnorm +from torch.nn.modules import padding, pixelshuffle +from torch.nn.modules import loss + + +class Required(object): + + def __init__(self): + pass + + def __repr__(self): + return '(Required Parameter)' + + +def get_all_class_obj(module, key_word=''): + members = inspect.getmembers(module) + rs = [] + module_name = None + for name, obj in members: + if inspect.isclass(obj): + if 'modules.' + key_word in obj.__module__: + rs.append(obj) + # print(obj) + module_name = obj.__module__.split('.')[-1] + + return rs, module_name + + +def extract_init_param(class_): + args = inspect.getfullargspec(class_.__init__) + print(class_) + print(args) + keys = args[0][1:] + if len(keys) == 0: + return {} + defaults = args[3] + args_map = {} + print(keys) + print(defaults) + if defaults is not None: + for idx, i in enumerate(keys[-len(defaults):]): + print(args_map) + print(defaults) + args_map[i] = defaults[idx] + + for i in keys: + if i not in args_map: + args_map[i] = Required() + + return args_map + + +def code_assembly(param, nn_class, module_name): + if module_name == 'loss': + parent_class = 'FateTorch' + else: + parent_class = 'FateTorch' + + para_str = "" + non_default_param = "" + init_str = """""" + for k, v in param.items(): + + new_para = "\n self.param_dict['{}'] = {}".format(k, k) + init_str += new_para + if isinstance(v, Required): + non_default_param += str(k) + non_default_param += ', ' + continue + + para_str += str(k) + if isinstance(v, str): + para_str += "='{}'".format(v) + else: + para_str += "={}".format(str(v)) + para_str += ', ' + + para_str = non_default_param + para_str + + init_ = """ + def __init__(self, {}**kwargs): + {}.__init__(self){} + self.param_dict.update(kwargs) + nn.modules.{}.{}.__init__(self, **self.param_dict) + """.format(para_str, parent_class, init_str, module_name, nn_class) + + code = """ +class {}({}, {}): + {} + """.format(nn_class, 'nn.modules.{}.{}'.format(module_name, nn_class), parent_class, init_) + + return code + + +if __name__ == '__main__': + + rs1 = get_all_class_obj(linear, 'linear') + rs2 = get_all_class_obj(rnn, 'rnn') + rs3 = get_all_class_obj(sparse, 'sparse') + rs4 = get_all_class_obj(dropout, 'dropout') + rs5 = get_all_class_obj(activation, 'activation') + rs6 = get_all_class_obj(conv, 'conv') + rs7 = get_all_class_obj(transformer, 'transformer') + rs8 = get_all_class_obj(pooling, 'pooling') + rs9 = get_all_class_obj(batchnorm, 'batchnorm') + rs10 = get_all_class_obj(padding, 'padding') + rs11 = get_all_class_obj(pixelshuffle, 'pixielshuffle') + rs12 = get_all_class_obj(loss, 'loss') + + module_str = """""" + module_str += "from torch import nn\n\n" + for rs in [rs1, rs2, rs3, rs4, rs5, rs6, rs7, rs8, rs9, rs10, rs11, rs12]: + module_name = rs[1] + for i in rs[0]: + # print(i) + param = extract_init_param(i) + class_str = code_assembly(param, i.__name__, module_name) + module_str += class_str + + module_str = module_str + + open('../torch/nn.py', 'w').write(module_str) diff --git a/python/fate/arch/context/metric/_metrics.py b/python/fate/components/components/psi.py 
similarity index 51% rename from python/fate/arch/context/metric/_metrics.py rename to python/fate/components/components/psi.py index f29c642a1f..af270d7b3b 100644 --- a/python/fate/arch/context/metric/_metrics.py +++ b/python/fate/components/components/psi.py @@ -12,26 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from fate.components.core import GUEST, HOST, Role, cpn +from fate.arch.protocol.psi import psi_run -from ._type import Metrics +@cpn.component(roles=[GUEST, HOST], provider="fate") +def psi( + ctx, + role: Role, + input_data: cpn.dataframe_input(roles=[GUEST, HOST]), + protocol: cpn.parameter(type=str, default="ecdh_psi", optional=True), + curve_type: cpn.parameter(type=str, default="curve25519", optional=True), + output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): -class ROCMetrics(Metrics): - type = "roc" - - def __init__(self, name, data) -> None: - self.name = name - self.data = data - self.nemaspace: Optional[str] = None - self.groups: Dict[str, str] = {} - - def dict(self) -> dict: - return dict( - name=self.name, - namespace=self.nemaspace, - groups=self.groups, - type=self.type, - metadata={}, - data=self.data, - ) + intersect_data = psi_run(ctx, input_data.read(), protocol, curve_type) + output_data.write(intersect_data) diff --git a/python/fate/components/components/reader.py b/python/fate/components/components/reader.py index 60ced177fa..c35e377ffd 100644 --- a/python/fate/components/components/reader.py +++ b/python/fate/components/components/reader.py @@ -12,42 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
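The torch loss and optimizer wrappers earlier in this diff are not written by hand: extract_torch_modules.py and extract_pytorch_optim.py introspect the constructor of each torch class and emit a wrapper that records its arguments in param_dict. A dependency-light sketch of that introspection step, assuming only torch is installed (the helper name init_defaults is illustrative, not part of the diff):

```python
import inspect

from torch import optim


def init_defaults(cls):
    """Collect __init__ parameters and their defaults, mirroring extract_init_param."""
    out = {}
    for name, p in inspect.signature(cls.__init__).parameters.items():
        if name in ("self", "params"):
            continue
        if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
            continue
        out[name] = "(required)" if p.default is inspect.Parameter.empty else p.default
    return out


# e.g. {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, ...}
print(init_defaults(optim.Adam))
```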
-from fate.arch.unify import URI -from fate.components import GUEST, HOST, DatasetArtifact, Output, Role, cpn +from fate.components.core import GUEST, HOST, Role, cpn @cpn.component(roles=[GUEST, HOST]) -@cpn.parameter("path", type=str, default=None, optional=False) -@cpn.parameter("format", type=str, default="csv", optional=False) -@cpn.parameter("id_name", type=str, default="id", optional=True) -@cpn.parameter("delimiter", type=str, default=",", optional=True) -@cpn.parameter("label_name", type=str, default=None, optional=True) -@cpn.parameter("label_type", type=str, default="float32", optional=True) -@cpn.parameter("dtype", type=str, default="float32", optional=True) -@cpn.artifact("output_data", type=Output[DatasetArtifact], roles=[GUEST, HOST]) def reader( ctx, role: Role, - path, - format, - id_name, - delimiter, - label_name, - label_type, - dtype, - output_data, + path: cpn.parameter(type=str, default=None, optional=False), + format: cpn.parameter(type=str, default="csv", optional=False), + sample_id_name: cpn.parameter(type=str, default=None, optional=True), + match_id_name: cpn.parameter(type=str, default=None, optional=True), + delimiter: cpn.parameter(type=str, default=",", optional=True), + label_name: cpn.parameter(type=str, default=None, optional=True), + label_type: cpn.parameter(type=str, default="float32", optional=True), + dtype: cpn.parameter(type=str, default="float32", optional=True), + output_data: cpn.dataframe_output(roles=[GUEST, HOST]), ): - read_data(ctx, path, format, id_name, delimiter, label_name, label_type, dtype, output_data) - - -def read_data(ctx, path, format, id_name, delimiter, label_name, label_type, dtype, output_data): if format == "csv": - data_meta = DatasetArtifact( + data_meta = DataframeArtifact( uri=path, name="data", metadata=dict( format=format, - id_name=id_name, + sample_id_name=sample_id_name, + match_id_name=match_id_name, delimiter=delimiter, label_name=label_name, label_type=label_type, @@ -55,7 +44,7 @@ def read_data(ctx, path, format, id_name, delimiter, label_name, label_type, dty ), ) elif format == "raw_table": - data_meta = DatasetArtifact(uri=path, name="data", metadata=dict(format=format)) + data_meta = DataframeArtifact(uri=path, name="data", metadata=dict(format=format)) else: raise ValueError(f"Reader does not support format={format}") diff --git a/python/fate/components/components/sample.py b/python/fate/components/components/sample.py new file mode 100644 index 0000000000..50f46704ee --- /dev/null +++ b/python/fate/components/components/sample.py @@ -0,0 +1,76 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
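The psi component and the reworked reader above both use the same annotation-based declaration: parameters and artifacts are typed directly in the function signature instead of being stacked as separate decorators. A minimal sketch of a custom component in that style, assuming a FATE 2.0 environment (my_passthrough is a hypothetical component, not part of this diff):

```python
from fate.components.core import GUEST, HOST, Role, cpn


@cpn.component(roles=[GUEST, HOST], provider="fate")
def my_passthrough(
    ctx,
    role: Role,
    input_data: cpn.dataframe_input(roles=[GUEST, HOST]),
    output_data: cpn.dataframe_output(roles=[GUEST, HOST]),
):
    # read the input DataFrame and write it back unchanged
    df = input_data.read()
    output_data.write(df)
```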
+ +from typing import Union, Mapping + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.model_selection.sample import SampleModuleGuest, SampleModuleHost + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def sample( + ctx: Context, + role: Role, + input_data: cpn.dataframe_input(roles=[GUEST, HOST]), + replace: cpn.parameter(type=bool, default=False, + desc="whether allow sampling with replacement, default False"), + frac: cpn.parameter(type=Union[params.confloat(gt=0.0), + Mapping[Union[params.conint(), params.confloat()], params.confloat(gt=0.0)]], + default=None, optional=True, + desc="if mode equals to random, it should be a float number greater than 0," + "otherwise a dict of pairs like [label_i, sample_rate_i]," + "e.g. {0: 0.5, 1: 0.8, 2: 0.3}, any label unspecified in dict will not be sampled," + "default: 1.0, cannot be used with n"), + n: cpn.parameter(type=params.conint(gt=0), default=None, optional=True, + desc="exact sample size, it should be an int greater than 0, " + "default: None, cannot be used with frac"), + random_state: cpn.parameter(type=params.conint(ge=0), default=None, + desc="random state"), + hetero_sync: cpn.parameter(type=bool, default=True, + desc="whether guest sync sampled data sids with host, " + "default True for hetero scenario, " + "should set to False for local and homo scenario"), + output_data: cpn.dataframe_output(roles=[GUEST, HOST]) +): + if frac is not None and n is not None: + raise ValueError(f"n and frac cannot be used at the same time") + if frac is not None: + if isinstance(frac, float): + if frac > 1 and not replace: + raise ValueError(f"replace has to be set to True when sampling frac greater than 1.") + elif isinstance(frac, dict): + for v in frac.values(): + if v > 1 and not replace: + raise ValueError(f"replace has to be set to True when sampling frac greater than 1.") + if n is None and frac is None: + frac = 1.0 + # check if local but federated sample + if hetero_sync and len(ctx.parties.ranks) < 2: + raise ValueError(f"federated sample can only be called when both 'guest' and 'host' present. Please check") + sub_ctx = ctx.sub_ctx("train") + if role.is_guest: + module = SampleModuleGuest(replace=replace, frac=frac, n=n, + random_state=random_state, hetero_sync=hetero_sync) + elif role.is_host: + module = SampleModuleHost(replace=replace, frac=frac, n=n, + random_state=random_state, hetero_sync=hetero_sync) + else: + raise ValueError(f"unknown role") + input_data = input_data.read() + + sampled_data = module.fit(sub_ctx, input_data) + + output_data.write(sampled_data) diff --git a/python/fate/components/components/statistics.py b/python/fate/components/components/statistics.py new file mode 100644 index 0000000000..5224b01609 --- /dev/null +++ b/python/fate/components/components/statistics.py @@ -0,0 +1,82 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
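The checks at the top of sample() above enforce a few rules: frac and n are mutually exclusive, any sampling rate above 1 requires replace=True, and frac defaults to 1.0 when neither is given. A dependency-free sketch of that validation, for reference only (check_sample_args is an illustrative name):

```python
def check_sample_args(frac=None, n=None, replace=False):
    """Mirrors the argument validation in the sample component (illustrative only)."""
    if frac is not None and n is not None:
        raise ValueError("n and frac cannot be used at the same time")
    if frac is not None:
        rates = frac.values() if isinstance(frac, dict) else [frac]
        if any(rate > 1 for rate in rates) and not replace:
            raise ValueError("replace has to be set to True when sampling frac greater than 1")
    if frac is None and n is None:
        frac = 1.0
    return frac, n


print(check_sample_args(frac={0: 0.5, 1: 0.8}))  # ({0: 0.5, 1: 0.8}, None)
print(check_sample_args(n=100))                  # (None, 100)
```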
+ +from typing import List, Union + +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn, params + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def statistics( + ctx: Context, + role: Role, + input_data: cpn.dataframe_input(roles=[GUEST, HOST]), + metrics: cpn.parameter( + type=Union[List[Union[params.statistic_metrics_param(), params.legal_percentile()]], + params.statistic_metrics_param(), params.legal_percentile()], + default=["mean", "std", "min", "max"], + desc="metrics to be computed, default ['count', 'mean', 'std', 'min', 'max']", + ), + ddof: cpn.parameter( + type=params.conint(ge=0), default=1, desc="Delta Degrees of Freedom for std and var, default 1" + ), + bias: cpn.parameter( + type=bool, + default=True, + desc="If False, the calculations of skewness and kurtosis are corrected for statistical bias.", + ), + relative_error: cpn.parameter(type=params.confloat(gt=0, le=1), default=1e-3, + desc="float, error rate for quantile"), + skip_col: cpn.parameter( + type=List[str], + default=None, + optional=True, + desc="columns to be skipped, default None; if None, statistics will be computed over all columns", + ), + use_anonymous: cpn.parameter( + type=bool, default=False, desc="bool, whether interpret `skip_col` as anonymous column names" + ), + output_model: cpn.json_model_output(roles=[GUEST, HOST]), +): + from fate.ml.statistics.statistics import FeatureStatistics + sub_ctx = ctx.sub_ctx("train") + input_data = input_data.read() + select_cols = get_to_compute_cols( + input_data.schema.columns, input_data.schema.anonymous_columns, skip_col, use_anonymous + ) + if isinstance(metrics, str): + metrics = [metrics] + if len(metrics) > 1: + for metric in metrics: + if metric == "describe": + raise ValueError(f"'describe' should not be combined with additional metric names.") + stat_computer = FeatureStatistics(list(set(metrics)), ddof, bias, relative_error) + input_data = input_data[select_cols] + stat_computer.fit(sub_ctx, input_data) + + model = stat_computer.get_model() + output_model.write(model, metadata={}) + + +def get_to_compute_cols(columns, anonymous_columns, skip_columns, use_anonymous): + if skip_columns is None: + skip_columns = [] + if use_anonymous and skip_columns is not None: + skip_columns = [anonymous_columns[columns.index(col)] for col in skip_columns] + skip_col_set = set(skip_columns) + select_columns = [col for col in columns if col not in skip_col_set] + + return select_columns diff --git a/python/fate/components/components/toy_example.py b/python/fate/components/components/toy_example.py new file mode 100644 index 0000000000..cdb4faf5de --- /dev/null +++ b/python/fate/components/components/toy_example.py @@ -0,0 +1,49 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
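get_to_compute_cols above decides which columns statistics are computed over: skipped columns (optionally given by their anonymous names) are removed from the schema before FeatureStatistics.fit is called. A tiny standalone sketch of the non-anonymous path (select_stat_cols and the column names are illustrative, not part of the diff):

```python
def select_stat_cols(columns, skip_columns=None):
    """Non-anonymous path of get_to_compute_cols from statistics.py above."""
    skip = set(skip_columns or [])
    return [col for col in columns if col not in skip]


# skip x1 by its real name; statistics run over the remaining columns
print(select_stat_cols(["x0", "x1", "x2"], skip_columns=["x1"]))  # ['x0', 'x2']
```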
+from typing import List, Union + +import pandas as pd +from fate.arch import Context +from fate.arch.dataframe import PandasReader +from fate.components.core import GUEST, HOST, Role, cpn, params + + +@cpn.component(roles=[GUEST, HOST]) +def toy_example( + ctx: Context, + role: Role, + output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + json_model_output: cpn.json_model_output(roles=[GUEST, HOST]), + data_num: cpn.parameter(type=params.conint(gt=1), desc="data_num", optional=False), + partition: cpn.parameter(type=params.conint(gt=1), desc="data_partition", optional=False), +): + pd_df = pd.DataFrame([[str(i), str(i), i] for i in range(data_num)], columns=["sample_id", "match_id", "x0"]) + reader = PandasReader(sample_id_name="sample_id", match_id_name="match_id", dtype="float64", partition=partition) + df = reader.to_frame(ctx, pd_df) + + if role == "guest": + ctx.hosts.put("guest_index", df.get_indexer(target="sample_id")) + host_indexes = ctx.hosts[0].get("host_index") + final_df = df.loc(host_indexes, preserve_order=True) + else: + guest_indexes = ctx.guest.get("guest_index") + final_df = df.loc(guest_indexes) + ctx.guest.put("host_index", final_df.get_indexer(target="sample_id")) + + assert final_df.shape[0] == data_num, f"data num should be {data_num} instead of {final_df}" + + output_data.write(final_df) + + json_model_output.write({"test_role": role}) diff --git a/python/fate/components/components/union.py b/python/fate/components/components/union.py new file mode 100644 index 0000000000..5e19717e88 --- /dev/null +++ b/python/fate/components/components/union.py @@ -0,0 +1,36 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
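For reference, the frame toy_example builds before handing it to PandasReader is ordinary pandas: string-valued sample_id and match_id columns plus one numeric feature. The same construction runs outside a FATE context; converting it to a fate DataFrame additionally needs a runtime ctx, as sketched in the comments:

```python
import pandas as pd

data_num = 5
pd_df = pd.DataFrame(
    [[str(i), str(i), i] for i in range(data_num)],
    columns=["sample_id", "match_id", "x0"],
)
print(pd_df.head())

# with a FATE runtime context available, the same frame is wrapped like this:
# reader = PandasReader(sample_id_name="sample_id", match_id_name="match_id",
#                       dtype="float64", partition=4)
# df = reader.to_frame(ctx, pd_df)
```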
+ +from fate.arch import Context +from fate.components.core import GUEST, HOST, Role, cpn + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def union( + ctx: Context, + role: Role, + input_data_list: cpn.dataframe_inputs(roles=[GUEST, HOST]), + output_data: cpn.dataframe_output(roles=[GUEST, HOST]) +): + from fate.ml.preprocessing import Union + data_list = [] + for data in input_data_list: + data = data.read() + data_list.append(data) + + sub_ctx = ctx.sub_ctx("train") + union_obj = Union() + output_df = union_obj.fit(sub_ctx, data_list) + output_data.write(output_df) diff --git a/python/fate/arch/context/io/metric/__init__.py b/python/fate/components/components/utils/__init__.py similarity index 100% rename from python/fate/arch/context/io/metric/__init__.py rename to python/fate/components/components/utils/__init__.py diff --git a/python/fate/components/components/utils/consts.py b/python/fate/components/components/utils/consts.py new file mode 100644 index 0000000000..e315dd3290 --- /dev/null +++ b/python/fate/components/components/utils/consts.py @@ -0,0 +1,20 @@ +# stage +TRAIN = 'train' +PREDICT = 'predict' + +# DATASET TYPE +TRAIN_SET = 'train_set' +VALIDATE_SET = 'validate_set' +TEST_SET = 'test_set' + +# eval type +BINARY = 'binary' +MULTI = 'multi' +REGRESSION = 'regression' +OTHER = 'other' + +# Task Type +CLASSIFICATION = 'classification' +REGRESSION = 'regression' +CLUSTERING = 'clustering' +OTHER_TASK = 'other_task' \ No newline at end of file diff --git a/python/fate/components/components/utils/tools.py b/python/fate/components/components/utils/tools.py new file mode 100644 index 0000000000..d1e9fb92de --- /dev/null +++ b/python/fate/components/components/utils/tools.py @@ -0,0 +1,20 @@ +from fate.arch.dataframe import DataFrame +from .consts import TRAIN_SET, VALIDATE_SET, TEST_SET + + +TYPE = 'type' + + +def cat_train_and_validate_df(train_df: DataFrame, val_df: DataFrame): + """ + Concatenate train and validate dataframe + """ + return train_df.vstack(val_df) + + +def add_dataset_type(df: DataFrame, dataset_type): + assert dataset_type in [TRAIN_SET, VALIDATE_SET, TEST_SET], f"dataset_type must be one of {TRAIN_SET}, {VALIDATE_SET}, {TEST_SET}" + df[TYPE] = dataset_type + return df + + diff --git a/python/fate/components/core/__init__.py b/python/fate/components/core/__init__.py new file mode 100644 index 0000000000..d94395e763 --- /dev/null +++ b/python/fate/components/core/__init__.py @@ -0,0 +1,27 @@ +from . import _cpn_reexport as cpn +from ._cpn_search import list_components, load_component +from ._load_computing import load_computing +from ._load_device import load_device +from ._load_federation import load_federation +from ._load_metric_handler import load_metric_handler +from .component_desc import Component, ComponentExecutionIO +from .essential import ARBITER, GUEST, HOST, LOCAL, Label, Role, Stage + +__all__ = [ + "Component", + "ComponentExecutionIO", + "cpn", + "load_component", + "list_components", + "load_device", + "load_computing", + "load_federation", + "load_metric_handler", + "Role", + "Stage", + "ARBITER", + "GUEST", + "HOST", + "LOCAL", + "Label", +] diff --git a/python/fate/components/core/_cpn_reexport.py b/python/fate/components/core/_cpn_reexport.py new file mode 100644 index 0000000000..6ed3882141 --- /dev/null +++ b/python/fate/components/core/_cpn_reexport.py @@ -0,0 +1,74 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
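utils/tools.py above tags every row of a DataFrame with its dataset type before train and validate frames are concatenated with vstack. A pandas stand-in that shows the same convention (the real helper operates on fate.arch.dataframe.DataFrame; this copy is illustrative only):

```python
import pandas as pd

# constants mirror utils/consts.py above
TRAIN_SET, VALIDATE_SET, TEST_SET = "train_set", "validate_set", "test_set"
TYPE = "type"


def add_dataset_type(df, dataset_type):
    assert dataset_type in (TRAIN_SET, VALIDATE_SET, TEST_SET)
    df[TYPE] = dataset_type
    return df


train = add_dataset_type(pd.DataFrame({"x0": [0.1, 0.2]}), TRAIN_SET)
validate = add_dataset_type(pd.DataFrame({"x0": [0.3]}), VALIDATE_SET)
print(pd.concat([train, validate]))  # pandas analogue of cat_train_and_validate_df
```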
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Type, TypeVar + +# re-export +from .component_desc import ( + component, + data_directory_input, + data_directory_inputs, + data_directory_output, + data_directory_outputs, + dataframe_input, + dataframe_inputs, + dataframe_output, + dataframe_outputs, + json_model_input, + json_model_inputs, + json_model_output, + json_model_outputs, + model_directory_input, + model_directory_inputs, + model_directory_output, + model_directory_outputs, + parameter, + table_input, + table_inputs, +) +from .essential import Role + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +def union(f1: Callable[..., Type[T1]], f2: Callable[..., Type[T2]]): + def wrapper(roles: Optional[List[Role]] = None, desc="", optional=False) -> "Type[T1] | Type[T2]": + return f1(roles, desc, optional) | f2(optional=optional) + + return wrapper + + +__all__ = [ + "component", + "parameter", + "dataframe_input", + "dataframe_output", + "dataframe_inputs", + "dataframe_outputs", + "table_input", + "table_inputs", + "data_directory_input", + "data_directory_output", + "data_directory_outputs", + "data_directory_inputs", + "json_model_output", + "json_model_outputs", + "json_model_input", + "json_model_inputs", + "model_directory_inputs", + "model_directory_outputs", + "model_directory_output", + "model_directory_input", +] diff --git a/python/fate/components/core/_cpn_search.py b/python/fate/components/core/_cpn_search.py new file mode 100644 index 0000000000..614ae72ae0 --- /dev/null +++ b/python/fate/components/core/_cpn_search.py @@ -0,0 +1,87 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import typing +from typing import Optional + +if typing.TYPE_CHECKING: + from fate.components.core import Component, Stage + +logger = logging.getLogger(__name__) + + +def load_component(cpn_name: str, stage: Optional["Stage"] = None): + from fate.components.components import LazyBuildInComponentsLoader + + # from build in + cpn = None + lazy_build_in_components_loader = LazyBuildInComponentsLoader() + if lazy_build_in_components_loader.contains(cpn_name): + cpn = lazy_build_in_components_loader.load_cpn(cpn_name) + else: + # from entrypoint + import pkg_resources + + for cpn_ep in pkg_resources.iter_entry_points(group="fate.ext.component_desc"): + try: + candidate_cpn: "Component" = cpn_ep.load() + candidate_cpn_name = candidate_cpn.name + except Exception as e: + logger.warning( + f"register cpn from entrypoint(named={cpn_ep.name}, module={cpn_ep.module_name}) failed: {e}" + ) + continue + if candidate_cpn_name == cpn_name: + cpn = candidate_cpn + break + if cpn is None: + raise RuntimeError(f"could not find registered cpn named `{cpn_name}`") + if stage is not None: + cpn = load_stage_component(cpn, stage) + return cpn + + +def load_stage_component(cpn, stage: "Stage"): + if not stage.is_default: + for stage_component in cpn.stage_components: + if stage_component.name == stage.name: + cpn = stage_component + break + else: + supported_stage_names = [stage_component.name for stage_component in cpn.stage_components] + raise ValueError( + f"stage `{stage.name}` not supported for component `{cpn.name}`, use one listed in: {supported_stage_names}" + ) + return cpn + + +def list_components(): + import pkg_resources + from fate.components.components import LazyBuildInComponentsLoader + + build_in_components = LazyBuildInComponentsLoader().list() + third_parties_components = [] + + for cpn_ep in pkg_resources.iter_entry_points(group="fate.ext.component_desc"): + try: + candidate_cpn = cpn_ep.load() + candidate_cpn_name = candidate_cpn.name + third_parties_components.append([candidate_cpn_name]) + except Exception as e: + logger.warning( + f"register cpn from entrypoint(named={cpn_ep.name}, module={cpn_ep.module_name}) failed: {e}" + ) + continue + return dict(buildin=build_in_components, thirdparty=third_parties_components) diff --git a/python/fate/components/loader/computing.py b/python/fate/components/core/_load_computing.py similarity index 83% rename from python/fate/components/loader/computing.py rename to python/fate/components/core/_load_computing.py index 7646ba72d6..3a78acc51f 100644 --- a/python/fate/components/loader/computing.py +++ b/python/fate/components/core/_load_computing.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
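load_component and list_components above form the public lookup API re-exported from fate.components.core. A short usage sketch, assuming a FATE 2.0.0-beta install where the build-in components from this diff (psi, sample, statistics, union, ...) are registered:

```python
from fate.components.core import list_components, load_component

# build-in components plus any third-party components registered via the
# "fate.ext.component_desc" entry point group
print(list_components())

# load a build-in component and dump its declared parameters and artifacts as YAML
psi_cpn = load_component("psi")
print(psi_cpn.dump_yaml())
```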
-def load_computing(computing): - from fate.components.spec.computing import ( - CustomComputingSpec, +def load_computing(computing, logger_config=None): + from fate.components.core.spec.computing import ( EggrollComputingSpec, SparkComputingSpec, StandaloneComputingSpec, @@ -23,7 +22,9 @@ def load_computing(computing): if isinstance(computing, StandaloneComputingSpec): from fate.arch.computing.standalone import CSession - return CSession(computing.metadata.computing_id, options=computing.metadata.options) + return CSession( + computing.metadata.computing_id, logger_config=logger_config, options=computing.metadata.options + ) if isinstance(computing, EggrollComputingSpec): from fate.arch.computing.eggroll import CSession diff --git a/python/fate/components/loader/device.py b/python/fate/components/core/_load_device.py similarity index 85% rename from python/fate/components/loader/device.py rename to python/fate/components/core/_load_device.py index d55782eabb..5b45a19d56 100644 --- a/python/fate/components/loader/device.py +++ b/python/fate/components/core/_load_device.py @@ -14,11 +14,11 @@ # limitations under the License. def load_device(device_spec): from fate.arch.unify import device - from fate.components.spec.device import CPUSpec, GPUSpec + from fate.components.core.spec.device import CPUSpec, GPUSpec if isinstance(device_spec, CPUSpec): return device.CPU if isinstance(device_spec, GPUSpec): return device.CUDA - raise ValueError(f"device `{device_spec}` not implemeted yet") + raise ValueError(f"device `{device_spec}` not implemented yet") diff --git a/python/fate/components/loader/federation.py b/python/fate/components/core/_load_federation.py similarity index 98% rename from python/fate/components/loader/federation.py rename to python/fate/components/core/_load_federation.py index fe2f89aac6..4419b36996 100644 --- a/python/fate/components/loader/federation.py +++ b/python/fate/components/core/_load_federation.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. def load_federation(federation, computing): - from fate.components.spec.federation import ( + from fate.components.core.spec.federation import ( OSXFederationSpec, PulsarFederationSpec, RabbitMQFederationSpec, diff --git a/python/fate/arch/tensor/_exception.py b/python/fate/components/core/_load_metric_handler.py similarity index 53% rename from python/fate/arch/tensor/_exception.py rename to python/fate/components/core/_load_metric_handler.py index 13ea8599c7..2ce73c672f 100644 --- a/python/fate/arch/tensor/_exception.py +++ b/python/fate/components/core/_load_metric_handler.py @@ -12,22 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -class OpsDispatchException(Exception): - ... -class OpsDispatchBadSignatureError(OpsDispatchException): - ... - - -class OpsDispatchBadDtypeError(OpsDispatchException): - ... - - -class OpsDispatchUnsupportedError(OpsDispatchException): - def __init__(self, method, distributed, device, dtype) -> None: - super().__init__(f"method={method}, distributed={distributed}, device={device}, dtype={dtype}") - - -class OpDispatchInvalidDevice(OpsDispatchException): - ... 
+def load_metric_handler(writer): + from fate.components.core.component_desc._metric import ( + ComponentMetricsFileHandler, + ComponentMetricsRestfulHandler, + JsonMetricFileWriter, + JsonMetricRestfulWriter, + ) + + if isinstance(writer, JsonMetricRestfulWriter): + return ComponentMetricsRestfulHandler(writer=writer) + elif isinstance(writer, JsonMetricFileWriter): + return ComponentMetricsFileHandler(writer=writer) + else: + raise ValueError(f"writer `{writer}` not allowed") diff --git a/python/fate/components/core/component_desc/__init__.py b/python/fate/components/core/component_desc/__init__.py new file mode 100644 index 0000000000..c1cec6a28c --- /dev/null +++ b/python/fate/components/core/component_desc/__init__.py @@ -0,0 +1,68 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# re-export +from ._component import Component, component +from ._component_io import ComponentExecutionIO +from ._parameter import parameter +from .artifacts import ( + data_directory_input, + data_directory_inputs, + data_directory_output, + data_directory_outputs, + dataframe_input, + dataframe_inputs, + dataframe_output, + dataframe_outputs, + json_metric_output, + json_metric_outputs, + json_model_input, + json_model_inputs, + json_model_output, + json_model_outputs, + model_directory_input, + model_directory_inputs, + model_directory_output, + model_directory_outputs, + table_input, + table_inputs, +) + +__all__ = [ + "component", + "Component", + "ComponentExecutionIO", + "parameter", + "dataframe_input", + "dataframe_output", + "dataframe_inputs", + "dataframe_outputs", + "table_input", + "table_inputs", + "data_directory_input", + "data_directory_output", + "data_directory_outputs", + "data_directory_inputs", + "json_model_output", + "json_model_outputs", + "json_model_input", + "json_model_inputs", + "model_directory_inputs", + "model_directory_outputs", + "model_directory_output", + "model_directory_input", + "json_metric_output", + "json_metric_outputs", +] diff --git a/python/fate/components/core/component_desc/_component.py b/python/fate/components/core/component_desc/_component.py new file mode 100644 index 0000000000..0651339757 --- /dev/null +++ b/python/fate/components/core/component_desc/_component.py @@ -0,0 +1,386 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Copyright 2014 Pallets + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +""" +use decorators to define component_desc for FATE. +flowing codes modified from [click](https://github.com/pallets/click) project +""" + +import inspect +import logging +from typing import List, Optional + +from fate.components.core.essential import ( + CROSS_VALIDATION, + DEFAULT, + PREDICT, + TRAIN, + Role, + Stage, +) + +from ._component_artifact import ArtifactDescribeAnnotation, ComponentArtifactDescribes +from ._parameter import ComponentParameterDescribes, ParameterDescribeAnnotation + +logger = logging.getLogger(__name__) + + +class Component: + def __init__( + self, + name: str, + roles: List[Role], + provider, + version, + description, + callback, + parameters: ComponentParameterDescribes, + artifacts: ComponentArtifactDescribes, + is_subcomponent: bool = False, + ) -> None: + self.is_subcomponent = is_subcomponent + self.name = name + self.roles = roles + self.provider = provider + self.version = version + self.description = description + self.callback = callback + self.parameters = parameters + if not self.description: + self.description = "" + self.artifacts = artifacts + self.func_args = list(inspect.signature(self.callback).parameters.keys()) + self.stage_components: List[Component] = [] + + def execute(self, ctx, role, **kwargs): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"execution arguments: {kwargs}") + return self.callback(ctx, role, **kwargs) + + def dict(self): + return self._flatten_stages()._dict() + + def _flatten_stages(self) -> "Component": + merged_parameters = self.parameters + merged_artifacts = self.artifacts + for stage_cpn in self.stage_components: + stage_cpn = stage_cpn._flatten_stages() + merged_parameters = merged_parameters.merge(stage_cpn.parameters) + merged_artifacts = merged_artifacts.merge(stage_cpn.artifacts) + + return Component( + name=self.name, + roles=self.roles, + provider=self.provider, + version=self.version, + description=self.description, + callback=self.callback, + parameters=merged_parameters, + 
artifacts=merged_artifacts, + is_subcomponent=self.is_subcomponent, + ) + + def _dict(self): + from fate.components.core.spec.component import ComponentSpec, ComponentSpecV1 + + return ComponentSpecV1( + component=ComponentSpec( + name=self.name, + description=self.description, + provider=self.provider, + version=self.version, + labels=[], + roles=self.roles, + parameters=self.parameters.get_parameters_spec(), + input_artifacts=self.artifacts.get_inputs_spec(), + output_artifacts=self.artifacts.get_outputs_spec(), + ) + ) + + def _runtime_io_dict(self, runtime_role: Role, runtime_stage: Stage): + from fate.components.core.spec.component import ( + ArtifactTypeSpec, + ComponentIOArtifactsTypeSpec, + ComponentIOArtifactTypeSpec, + ComponentIOInputsArtifactsTypeSpec, + ComponentIOOutputsArtifactsTypeSpec, + ) + + def _get_io_artifact_type_spec(v): + return ComponentIOArtifactTypeSpec( + name=v.name, + is_multi=v.is_multi, + optional=v.optional, + types=[ + ArtifactTypeSpec( + type_name=v.get_type().type_name, + path_type=v.get_type().path_type, + uri_types=v.get_type().uri_types, + ) + for v in v.types + ], + ) + + return ComponentIOArtifactsTypeSpec( + inputs=ComponentIOInputsArtifactsTypeSpec( + data=[ + _get_io_artifact_type_spec(v) + for v in self.artifacts.data_inputs.values() + if v.is_active_for(runtime_stage, runtime_role) + ], + model=[ + _get_io_artifact_type_spec(v) + for v in self.artifacts.model_inputs.values() + if v.is_active_for(runtime_stage, runtime_role) + ], + ), + outputs=ComponentIOOutputsArtifactsTypeSpec( + data=[ + _get_io_artifact_type_spec(v) + for v in self.artifacts.data_outputs.values() + if v.is_active_for(runtime_stage, runtime_role) + ], + model=[ + _get_io_artifact_type_spec(v) + for v in self.artifacts.model_outputs.values() + if v.is_active_for(runtime_stage, runtime_role) + ], + metric=[ + _get_io_artifact_type_spec(v) + for v in self.artifacts.metric_outputs.values() + if v.is_active_for(runtime_stage, runtime_role) + ], + ), + ) + + def dump_runtime_io_yaml(self, role: Role, stage: Stage, stream=None): + from io import StringIO + + import ruamel.yaml + + inefficient = False + if stream is None: + inefficient = True + stream = StringIO() + yaml = ruamel.yaml.YAML() + yaml.indent(mapping=2, sequence=4, offset=2) + yaml.dump( + self._flatten_stages()._runtime_io_dict(runtime_role=role, runtime_stage=stage).dict(), stream=stream + ) + if inefficient: + return stream.getvalue() + + def dump_yaml(self, stream=None): + from io import StringIO + + import ruamel.yaml + + spec = self.dict() + inefficient = False + if stream is None: + inefficient = True + stream = StringIO() + yaml = ruamel.yaml.YAML() + yaml.indent(mapping=2, sequence=4, offset=2) + yaml.dump(spec.dict(), stream=stream) + if inefficient: + return stream.getvalue() + + def predict( + self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None + ): + + if roles is None: + roles = [] + + return self.stage(roles=roles, name=PREDICT.name, provider=provider, version=version, description=description) + + def train( + self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None + ): + + if roles is None: + roles = [] + + return self.stage(roles=roles, name=TRAIN.name, provider=provider, version=version, description=description) + + def cross_validation( + self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None + ): + + if roles is None: + roles = [] + + return 
self.stage( + roles=roles, name=CROSS_VALIDATION.name, provider=provider, version=version, description=description + ) + + def stage( + self, + roles: List = None, + name=None, + provider: Optional[str] = None, + version: Optional[str] = None, + description=None, + ): + r"""Creates a new stage component_desc with :class:`_Component` and uses the decorated function as + callback. This will also automatically attach all decorated + :func:`artifact`\s and :func:`parameter`\s as parameters to the component_desc execution. + + The stage name of the component_desc defaults to the name of the function. + If you want to change that, you can + pass the intended name as the first argument. + + Once decorated the function turns into a :class:`Component` instance + that can be invoked as a component_desc execution. + """ + if roles is None: + roles = [] + + def wrap(f): + sub_cpn = _component( + name, roles or self.roles, provider or self.provider, version or self.version, description, True + )(f) + self.stage_components.append(sub_cpn) + return sub_cpn + + return wrap + + +def component( + roles: List[Role], + name: Optional[str] = None, + provider: Optional[str] = None, + version: Optional[str] = None, + description: Optional[str] = None, +): + r"""Creates a new :class:`_Component` and uses the decorated function as + callback. This will also automatically attach all decorated + :func:`artifact`\s and :func:`parameter`\s as parameters to the component_desc execution. + + The name of the component_desc defaults to the name of the function. + If you want to change that, you can + pass the intended name as the first argument. + + Once decorated the function turns into a :class:`Component` instance + that can be invoked as a component_desc execution. + """ + from fate import __provider__, __version__ + + if version is None: + version = __version__ + if provider is None: + provider = __provider__ + return _component( + name=name, + roles=roles, + provider=provider, + version=version, + description=description, + is_subcomponent=False, + ) + + +def _component(name, roles, provider, version, description, is_subcomponent): + def decorator(f): + + cpn_name = name or f.__name__.lower() + if isinstance(f, Component): + raise TypeError("Attempted to convert a callback into a component_desc twice.") + parameters = ComponentParameterDescribes() + artifacts = ComponentArtifactDescribes() + signatures = list(inspect.signature(f).parameters.items()) + # first two arguments are ctx and role + if signatures[0][0] != "ctx": + raise ComponentDeclareError("bad component_desc definition, first argument should be `ctx`") + if signatures[1][0] != "role": + raise ComponentDeclareError("bad component_desc definition, second argument should be `role`") + + # check if all arguments are annotated + for k, v in signatures[2:]: + if isinstance(annotation := v.annotation, ArtifactDescribeAnnotation): + artifacts.add(annotation, k) + elif isinstance(annotation, ParameterDescribeAnnotation): + parameters.add_parameter( + name=k, + type=annotation.type, + default=annotation.default, + desc=annotation.desc, + optional=annotation.optional, + ) + else: + raise ComponentDeclareError(f"bad component_desc definition, argument {k} is not annotated") + + if is_subcomponent: + artifacts.update_roles_and_stages(stages=[Stage.from_str(cpn_name)], roles=roles) + else: + artifacts.update_roles_and_stages(stages=[DEFAULT], roles=roles) + desc = description + if desc is None: + desc = inspect.getdoc(f) + if isinstance(desc, bytes): + desc = 
desc.decode("utf-8") + else: + desc = inspect.cleandoc(desc) + cpn = Component( + name=cpn_name, + roles=roles, + provider=provider, + version=version, + description=desc, + callback=f, + parameters=parameters, + artifacts=artifacts, + is_subcomponent=is_subcomponent, + ) + cpn.__doc__ = f.__doc__ + return cpn + + return decorator + + +class ComponentDeclareError(Exception): + ... diff --git a/python/fate/components/core/component_desc/_component_artifact.py b/python/fate/components/core/component_desc/_component_artifact.py new file mode 100644 index 0000000000..5358d0e85f --- /dev/null +++ b/python/fate/components/core/component_desc/_component_artifact.py @@ -0,0 +1,259 @@ +import typing +from typing import Dict, List, Type, Union + +if typing.TYPE_CHECKING: + from fate.components.core import Role, Stage + + from .artifacts import ArtifactDescribe + from .artifacts.data import DataDirectoryArtifactDescribe, DataframeArtifactDescribe + from .artifacts.metric import JsonMetricArtifactDescribe + from .artifacts.model import ( + JsonModelArtifactDescribe, + ModelDirectoryArtifactDescribe, + ) + +T = typing.TypeVar("T") + + +class AllowArtifactDescribes(typing.Generic[T]): + def __init__(self, name, types: List[Type["ArtifactDescribe"]], roles, stages, desc, is_multi, optional): + self.name = name + self.types = types + self.roles = roles + self.stages = stages + self.desc = desc + self.is_multi = is_multi + self.optional = optional + + def update_roles(self, roles: List["Role"]): + if not self.roles: + self.roles = roles + + def update_stages(self, stages: List["Stage"]): + self.stages = stages + + def is_active_for(self, stage: "Stage", role: "Role"): + return stage in self.stages and role in self.roles + + def get_correct_arti(self, apply_spec) -> T: + for t in self.types: + if apply_spec.type_name is None or t.get_type().type_name == apply_spec.type_name: + return t( + name=self.name, + roles=self.roles, + stages=self.stages, + desc=self.desc, + multi=self.is_multi, + optional=self.optional, + ) + raise ValueError(f"no artifact describe for {apply_spec}") + + def dict(self): + from fate.components.core.spec.component import ArtifactSpec + + return ArtifactSpec( + types=[t.get_type().type_name for t in self.types], + optional=self.optional, + roles=self.roles, + stages=self.stages, + description=self.desc, + is_multi=self.is_multi, + ) + + def merge(self, a: "AllowArtifactDescribes"): + if len(self.types) != len(set(self.types).union(a.types)): + raise ValueError( + f"artifact {self.name} declare multiple times with different types: `{self.types}` vs `{a.types}`" + ) + if set(self.roles) != set(a.roles): + raise ValueError( + f"artifact {self.name} declare multiple times with different roles: `{self.roles}` vs `{a.roles}`" + ) + if self.optional != a.optional: + raise ValueError( + f"artifact {self.name} declare multiple times with different optional: `{self.optional}` vs `{a.optional}`" + ) + stages = set(self.stages) + stages.update(a.stages) + stages = list(stages) + return AllowArtifactDescribes( + name=self.name, + types=self.types, + roles=self.roles, + stages=stages, + desc=self.desc, + optional=self.optional, + is_multi=self.is_multi, + ) + + def __str__(self): + return f"AllowArtifactDescribes(name={self.name}, types={self.types}, roles={self.roles}, stages={self.stages}, desc={self.desc}, optional={self.optional}, is_multi={self.is_multi})" + + +class ArtifactDescribeAnnotation: + def __init__( + self, + describe_type: Type["ArtifactDescribe"], + describe_type_kind: str, + 
is_input: bool, + roles, + stages, + desc, + optional, + multi, + ): + self.is_input = is_input + self.describe_types = [describe_type] + self.describe_type_kind = describe_type_kind + self.roles = roles + self.stages = stages + self.desc = desc + self.optional = optional + self.multi = multi + + def __or__(self, other: "ArtifactDescribeAnnotation"): + if self.is_input != other.is_input: + raise ValueError("input and output can't be mixed") + if other.roles: + raise ValueError("second annotation should not provide roles") + if other.stages: + raise ValueError("second annotation should not provide stages") + if other.desc: + raise ValueError("second annotation should not provide desc") + if other.optional != self.optional: + raise ValueError("optional and non-optional can't be mixed") + if self.multi != other.multi: + raise ValueError("multi and non-multi can't be mixed") + if self.describe_type_kind != other.describe_type_kind: + raise ValueError(f"{self.describe_type_kind} and {other.describe_type_kind} can't be mixed") + self.describe_types.extend(other.describe_types) + return self + + def apply(self, name): + return AllowArtifactDescribes( + name=name, + types=self.describe_types, + roles=self.roles, + stages=self.stages, + desc=self.desc, + is_multi=self.multi, + optional=self.optional, + ) + + +class ComponentArtifactDescribes: + def __init__( + self, + data_inputs: Dict[ + str, AllowArtifactDescribes[Union["DataframeArtifactDescribe", "DataDirectoryArtifactDescribe"]] + ] = None, + model_inputs: Dict[ + str, AllowArtifactDescribes[Union["JsonModelArtifactDescribe", "ModelDirectoryArtifactDescribe"]] + ] = None, + data_outputs: Dict[ + str, AllowArtifactDescribes[Union["DataframeArtifactDescribe", "DataDirectoryArtifactDescribe"]] + ] = None, + model_outputs: Dict[ + str, AllowArtifactDescribes[Union["JsonModelArtifactDescribe", "ModelDirectoryArtifactDescribe"]] + ] = None, + ): + if data_inputs is None: + data_inputs = {} + if model_inputs is None: + model_inputs = {} + if data_outputs is None: + data_outputs = {} + if model_outputs is None: + model_outputs = {} + self.data_inputs = data_inputs + self.model_inputs = model_inputs + self.data_outputs = data_outputs + self.model_outputs = model_outputs + self.metric_outputs = {} + self._keys = ( + self.data_outputs.keys() + | self.model_outputs.keys() + | self.metric_outputs.keys() + | self.data_inputs.keys() + | self.model_inputs.keys() + ) + + # invisible artifact: metrics + from .artifacts import json_metric_output + + self.add(name="metric", annotation=json_metric_output([], desc="metric, invisible for user", optional=False)) + + def keys(self): + return self._keys + + def add(self, annotation: ArtifactDescribeAnnotation, name: str): + if name in self._keys: + raise ValueError(f"artifact {name} already exists") + self._keys.add(name) + if annotation.is_input: + if annotation.describe_type_kind == "data": + self.data_inputs[name] = annotation.apply(name) + elif annotation.describe_type_kind == "model": + self.model_inputs[name] = annotation.apply(name) + else: + raise ValueError(f"unknown artifact type {annotation.describe_type_kind}") + else: + if annotation.describe_type_kind == "data": + self.data_outputs[name] = annotation.apply(name) + elif annotation.describe_type_kind == "model": + self.model_outputs[name] = annotation.apply(name) + elif annotation.describe_type_kind == "metric": + self.metric_outputs[name] = annotation.apply(name) + else: + raise ValueError(f"unknown artifact type {annotation.describe_type_kind}") + + def 
update_roles_and_stages(self, stages, roles): + def _set_all(artifacts: Dict[str, "AllowArtifactDescribes"]): + for _, artifact in artifacts.items(): + artifact.update_stages(stages) + artifact.update_roles(roles) + + _set_all(self.data_inputs) + _set_all(self.model_inputs) + _set_all(self.data_outputs) + _set_all(self.model_outputs) + _set_all(self.metric_outputs) + + def merge(self, stage_artifacts: "ComponentArtifactDescribes"): + def _merge(a: Dict[str, "AllowArtifactDescribes"], b: Dict[str, "AllowArtifactDescribes"]): + result = {} + result.update(a) + for k, v in b.items(): + if k not in result: + result[k] = v + else: + result[k] = result[k].merge(v) + return result + + return ComponentArtifactDescribes( + data_inputs=_merge(self.data_inputs, stage_artifacts.data_inputs), + model_inputs=_merge(self.model_inputs, stage_artifacts.model_inputs), + data_outputs=_merge(self.data_outputs, stage_artifacts.data_outputs), + model_outputs=_merge(self.model_outputs, stage_artifacts.model_outputs), + ) + + def get_inputs_spec(self): + from fate.components.core.spec.component import InputDefinitionsSpec + + return InputDefinitionsSpec( + data={k: v.dict() for k, v in self.data_inputs.items()}, + model={k: v.dict() for k, v in self.model_inputs.items()}, + ) + + def get_outputs_spec(self): + from fate.components.core.spec.component import OutputDefinitionsSpec + + return OutputDefinitionsSpec( + data={k: v.dict() for k, v in self.data_outputs.items()}, + model={k: v.dict() for k, v in self.model_outputs.items()}, + metric={k: v.dict() for k, v in self.metric_outputs.items()}, + ) + + +class ComponentArtifactApplyError(RuntimeError): + ... diff --git a/python/fate/components/core/component_desc/_component_io.py b/python/fate/components/core/component_desc/_component_io.py new file mode 100644 index 0000000000..4c614fe34c --- /dev/null +++ b/python/fate/components/core/component_desc/_component_io.py @@ -0,0 +1,306 @@ +import logging +import typing +from typing import Dict, Generic, List, Optional, Union + +from fate.components.core.essential import Role, Stage + +from .artifacts._base_type import ( + AT, + MM, + ArtifactDescribe, + M, + _ArtifactsType, + _ArtifactType, +) + +if typing.TYPE_CHECKING: + from fate.arch import Context + + from ..spec.artifact import ( + ArtifactInputApplySpec, + ArtifactOutputApplySpec, + DataOutputMetadata, + Metadata, + MetricOutputMetadata, + ModelOutputMetadata, + ) + from ..spec.task import TaskConfigSpec + from ._component import Component + +logger = logging.getLogger(__name__) + + +class ComponentExecutionIO: + class InputPair(Generic[MM]): + def __init__(self, artifact: Optional[Union[_ArtifactsType[MM], _ArtifactType[MM]]], reader): + self.artifact = artifact + self.reader = reader + + class OutputPair(Generic[MM]): + def __init__(self, artifact: Optional[Union[_ArtifactsType[MM], _ArtifactType[MM]]], writer): + self.artifact = artifact + self.writer = writer + + def __init__(self, ctx: "Context", component: "Component", role: Role, stage: Stage, config): + self.cpn = component + self.parameter_artifacts_desc = {} + self.parameter_artifacts_apply = {} + self.input_data: Dict[str, ComponentExecutionIO.InputPair[Metadata]] = {} + self.input_model: Dict[str, ComponentExecutionIO.InputPair[Metadata]] = {} + self.output_data: Dict[str, ComponentExecutionIO.OutputPair[DataOutputMetadata]] = {} + self.output_model: Dict[str, ComponentExecutionIO.OutputPair[ModelOutputMetadata]] = {} + self.output_metric: Dict[str, 
ComponentExecutionIO.OutputPair[MetricOutputMetadata]] = {} + + logging.debug(f"parse and apply component artifacts") + + for arg in component.func_args[2:]: + if not ( + self._handle_parameter(component, arg, config) + or self._handle_input(ctx, component, arg, stage, role, config) + or self._handle_output(ctx, component, arg, stage, role, config) + ): + raise ValueError(f"args `{arg}` not provided") + + self._handle_output(ctx, component, "metric", stage, role, config) + + def _handle_parameter(self, component, arg, config): + if parameter := component.parameters.mapping.get(arg): + apply_spec: ArtifactInputApplySpec = config.parameters.get(arg) + applied_parameter = parameter.apply(apply_spec) + logging.debug(f"apply parameter `{parameter.name}`: {parameter} -> {applied_parameter}") + self.parameter_artifacts_apply[parameter.name] = applied_parameter + return True + return False + + def _handle_input(self, ctx, component, arg, stage, role, config): + from fate.arch import URI + + for input_pair_dict, artifacts in [ + (self.input_data, component.artifacts.data_inputs), + (self.input_model, component.artifacts.model_inputs), + ]: + if allow_artifacts := artifacts.get(arg): + if allow_artifacts.is_active_for(stage, role): + apply_spec: Union[ + ArtifactInputApplySpec, List[ArtifactInputApplySpec] + ] = config.input_artifacts.get(arg) + if apply_spec is not None: + try: + if allow_artifacts.is_multi: + if not isinstance(apply_spec, list): + raise ComponentArtifactApplyError( + f"`{arg}` expected list of artifact, but single artifact get" + ) + readers = [] + for c in apply_spec: + uri = URI.from_string(c.uri) + arti = allow_artifacts.get_correct_arti(c) + readers.append(arti.get_reader(ctx, uri, c.metadata, arti.get_type().type_name)) + input_pair_dict[arg] = ComponentExecutionIO.InputPair( + artifact=_ArtifactsType([r.artifact for r in readers]), reader=readers + ) + else: + uri = URI.from_string(apply_spec.uri) + arti = allow_artifacts.get_correct_arti(apply_spec) + reader = arti.get_reader(ctx, uri, apply_spec.metadata, arti.get_type().type_name) + input_pair_dict[arg] = ComponentExecutionIO.InputPair( + artifact=reader.artifact, reader=reader + ) + except Exception as e: + raise ComponentArtifactApplyError( + f"load as input artifact({allow_artifacts}) error: {e}" + ) from e + elif allow_artifacts.optional: + input_pair_dict[arg] = ComponentExecutionIO.InputPair(artifact=None, reader=None) + else: + raise ComponentArtifactApplyError( + f"load as input artifact({allow_artifacts}) error: `{arg}` is not optional but None got" + ) + logger.debug( + f"apply artifact `{allow_artifacts.name}`: {apply_spec} -> {input_pair_dict[arg].reader}" + ) + return True + else: + logger.debug(f"skip artifact `{allow_artifacts.name}` for stage `{stage}` and role `{role}`") + input_pair_dict[arg] = ComponentExecutionIO.InputPair(artifact=None, reader=None) + return True + return False + + def _handle_output(self, ctx, component, arg, stage, role, config): + from fate.arch import URI + + for output_pair_dict, artifacts in [ + (self.output_data, component.artifacts.data_outputs), + (self.output_model, component.artifacts.model_outputs), + (self.output_metric, component.artifacts.metric_outputs), + ]: + + if allowed_artifacts := artifacts.get(arg): + if allowed_artifacts.is_active_for(stage, role): + apply_spec: ArtifactOutputApplySpec = config.output_artifacts.get(arg) + if apply_spec is not None: + try: + if allowed_artifacts.is_multi: + if not apply_spec.is_template(): + raise ComponentArtifactApplyError( + 
"template uri required for multiple output artifact" + ) + arti = allowed_artifacts.get_correct_arti(apply_spec) + writers = WriterGenerator(component, arg, config, ctx, arti, apply_spec) + output_pair_dict[arg] = ComponentExecutionIO.OutputPair( + artifact=writers.recorder, writer=writers + ) + + else: + if apply_spec.is_template(): + raise ComponentArtifactApplyError( + "template uri is not supported for non-multiple output artifact" + ) + arti = allowed_artifacts.get_correct_arti(apply_spec) + writer = arti.get_writer( + config, ctx, URI.from_string(apply_spec.uri), arti.get_type().type_name + ) + _update_source_meta(writer.artifact.metadata, config, arg) + _maybe_update_model_overview_meta(writer.artifact.metadata, self.cpn, config) + output_pair_dict[arg] = ComponentExecutionIO.OutputPair( + artifact=writer.artifact, writer=writer + ) + except Exception as e: + raise ComponentArtifactApplyError( + f"load as output artifact({allowed_artifacts}) error: {e}" + ) from e + elif allowed_artifacts.optional: + output_pair_dict[arg] = ComponentExecutionIO.OutputPair(artifact=None, writer=None) + else: + raise ComponentArtifactApplyError( + f"load as output artifact({allowed_artifacts}) error: apply_config is None but not optional" + ) + logger.debug( + f"apply artifact `{allowed_artifacts.name}`: {apply_spec} -> {output_pair_dict[arg].writer}" + ) + return True + else: + logger.debug(f"skip artifact `{allowed_artifacts.name}` for stage `{stage}` and role `{role}`") + output_pair_dict[arg] = ComponentExecutionIO.OutputPair(artifact=None, writer=None) + return True + return False + + def get_kwargs(self, with_metrics=False): + kwargs = {**self.parameter_artifacts_apply} + kwargs.update({k: v.reader for k, v in self.input_data.items()}) + kwargs.update({k: v.reader for k, v in self.input_model.items()}) + kwargs.update({k: v.writer for k, v in self.output_data.items()}) + kwargs.update({k: v.writer for k, v in self.output_model.items()}) + if with_metrics: + kwargs.update({k: v.writer for k, v in self.output_metric.items()}) + return kwargs + + def get_metric_writer(self): + return self.output_metric["metric"].writer + + def dump_io_meta(self) -> dict: + from fate.components.core.spec.artifact import IOArtifactMeta + + return IOArtifactMeta( + inputs=IOArtifactMeta.InputMeta( + data={k: v.artifact.dict() for k, v in self.input_data.items() if v.artifact is not None}, + model={k: v.artifact.dict() for k, v in self.input_model.items() if v.artifact is not None}, + ), + outputs=IOArtifactMeta.OutputMeta( + data={k: v.artifact.dict() for k, v in self.output_data.items() if v.artifact is not None}, + model={k: v.artifact.dict() for k, v in self.output_model.items() if v.artifact is not None}, + metric={k: v.artifact.dict() for k, v in self.output_metric.items() if v.artifact is not None}, + ), + ).dict(exclude_none=True) + + +class WriterGenerator: + def __init__( + self, + cpn, + name: str, + config, + ctx: "Context", + artifact_describe: "ArtifactDescribe[AT, M]", + apply_config: "ArtifactOutputApplySpec", + ): + self.name = name + self.cpn = cpn + self.config = config + self.ctx = ctx + self.artifact_describe = artifact_describe + self.apply_config = apply_config + + self.recorder = _ArtifactsType([]) + self.current = 0 + + def get_recorder(self): + return self.recorder + + def __iter__(self): + return self + + def __next__(self): + from fate.arch import URI + + uri = URI.from_string(self.apply_config.uri.format(index=self.current)) + writer = self.artifact_describe.get_writer( + self.config, 
self.ctx, uri, self.artifact_describe.get_type().type_name + ) + _update_source_meta(writer.artifact.metadata, self.config, self.name, self.current) + _maybe_update_model_overview_meta(writer.artifact.metadata, self.cpn, self.config) + self.recorder.artifacts.append(writer.artifact) + self.current += 1 + return writer + + def __str__(self): + return f"{self.__class__.__name__}({self.artifact_describe}, index={self.current}>" + + def __repr__(self): + return str(self) + + +class ComponentArtifactApplyError(RuntimeError): + ... + + +def _update_source_meta(metadata, config: "TaskConfigSpec", output_artifact_key, output_index=None): + from fate.components.core.spec.artifact import ArtifactSource + + metadata.source = ArtifactSource( + task_id=config.task_id, + party_task_id=config.party_task_id, + task_name=config.task_name, + component=config.component, + output_artifact_key=output_artifact_key, + output_index=output_index, + ) + + +def _maybe_update_model_overview_meta(metadata, cpn, config: "TaskConfigSpec"): + from fate.components.core.spec.artifact import ModelOutputMetadata + from fate.components.core.spec.model import ( + MLModelComponentSpec, + MLModelFederatedSpec, + MLModelPartiesSpec, + MLModelPartySpec, + MLModelSpec, + ) + + if not isinstance(metadata, ModelOutputMetadata): + return + + metadata.model_overview = MLModelSpec( + federated=MLModelFederatedSpec( + task_id=config.task_id, + parties=MLModelPartiesSpec( + guest=[p.partyid for p in config.conf.federation.metadata.parties.parties if p.role == "guest"], + host=[p.partyid for p in config.conf.federation.metadata.parties.parties if p.role == "host"], + arbiter=[p.partyid for p in config.conf.federation.metadata.parties.parties if p.role == "arbiter"], + ), + component=MLModelComponentSpec(name=cpn.name, provider=cpn.provider, version=cpn.version, metadata={}), + ), + party=MLModelPartySpec( + party_task_id=config.party_task_id, role=config.role, partyid=config.party_id, models=[] + ), + ) diff --git a/python/fate/components/core/component_desc/_metric.py b/python/fate/components/core/component_desc/_metric.py new file mode 100644 index 0000000000..3bf32251e6 --- /dev/null +++ b/python/fate/components/core/component_desc/_metric.py @@ -0,0 +1,33 @@ +from fate.arch.context._metrics import ( + BaseMetricsHandler, + InMemoryMetricsHandler, + OneTimeMetrics, + StepMetrics, +) + +from .artifacts.metric import JsonMetricFileWriter, JsonMetricRestfulWriter + + +class ComponentMetricsFileHandler(InMemoryMetricsHandler): + def __init__(self, writer: JsonMetricFileWriter) -> None: + self._writer = writer + super().__init__() + + def finalize(self): + self._writer.write(self.get_metrics()) + + +class ComponentMetricsRestfulHandler(BaseMetricsHandler): + def __init__(self, writer: JsonMetricRestfulWriter) -> None: + self._writer = writer + + def _log_step_metrics(self, metrics: "StepMetrics"): + record = metrics.to_record() + self._writer.write(record.dict()) + + def _log_one_time_metrics(self, metrics: "OneTimeMetrics"): + record = metrics.to_record() + self._writer.write(record.dict()) + + def finalize(self): + self._writer.close() diff --git a/python/fate/components/core/component_desc/_parameter.py b/python/fate/components/core/component_desc/_parameter.py new file mode 100644 index 0000000000..7300292db9 --- /dev/null +++ b/python/fate/components/core/component_desc/_parameter.py @@ -0,0 +1,121 @@ +import typing +from typing import Dict, TypeVar + +import pydantic + + +class ParameterDescribe: + def __init__(self, name, type, 
default, optional, desc) -> None: + self.name = name + self.type = type + self.default = default + self.optional = optional + self.desc = desc + + def __str__(self) -> str: + return f"Parameter<name={self.name}, type={self.type}, default={self.default}, optional={self.optional}>" + + def merge(self, p: "ParameterDescribe"): + if self.default != p.default: + raise ComponentParameterDuplicateError( + f"parameter {p.name} declared multiple times with different default: `{self.default}` vs `{p.default}`" + ) + if self.optional != p.optional: + raise ComponentParameterDuplicateError( + f"parameter {p.name} declared multiple times with different optional: `{self.optional}` vs `{p.optional}`" + ) + # if str(self.type) != str(p.type) or self.type.__dict__ != p.type.__dict__: + if str(self.type) != str(p.type): + raise ComponentParameterDuplicateError( + f"parameter {p.name} declared multiple times with different type: `{self.type}({self.type.__dict__})` vs `{p.type}({p.type.__dict__})`" + ) + return self + + def get_parameter_spec(self): + from fate.components.core.params import Parameter + from fate.components.core.spec.component import ParameterSpec + + default = self.default if self.default is not ... else None + if not typing.get_origin(self.type) and issubclass(self.type, Parameter): # recommended + type_name = type(self.type).__name__ + if (schema := self.type.schema()) != NotImplemented: + type_meta = schema + else: + type_meta = pydantic.schema_of(self.type, title=type_name) + else: + type_name = getattr(self.type, "__name__", None) + if type_name is None: + type_name = str(self.type) + type_meta = pydantic.schema_of(self.type, title=type_name) + if self.default is not ...: + type_meta["default"] = self.default + type_meta["description"] = self.desc + + return ParameterSpec( + type=type_name, + type_meta=type_meta, + default=default, + optional=self.optional, + description=self.desc, + ) + + def apply(self, parameter_config): + from fate.components.core import params + + if parameter_config is not None: + try: + return params.parse(self.type, parameter_config) + except Exception as e: + raise ComponentParameterApplyError( + f"apply value `{parameter_config}` to parameter `{self.name}` failed: {e}" + ) from e + else: + if not self.optional: + raise ComponentParameterApplyError(f"parameter `{self.name}` is required, but got: `{parameter_config}`") + else: + return self.default + + +class ComponentParameterDescribes: + def __init__(self, mapping: Dict[str, ParameterDescribe] = None) -> None: + self.mapping = mapping or {} + + def add_parameter(self, name, type, default, optional, desc): + if name in self.mapping: + raise ComponentParameterDuplicateError(f"parameter {name} declared multiple times") + self.mapping[name] = ParameterDescribe(name, type, default, optional, desc) + + def merge(self, pd: "ComponentParameterDescribes"): + parameter_mapping = self.mapping.copy() + for name, p in pd.mapping.items(): + if name not in parameter_mapping: + parameter_mapping[name] = p + else: + parameter_mapping[name].merge(p) + return ComponentParameterDescribes(parameter_mapping) + + def get_parameters_spec(self): + return {name: p.get_parameter_spec() for name, p in self.mapping.items()} + + +class ParameterDescribeAnnotation: + def __init__(self, type, default, optional, desc) -> None: + self.type = type + self.default = default + self.optional = optional + self.desc = desc + + +T = TypeVar("T") + + +def parameter(type: T, default=..., optional=True, desc="") -> T: + return ParameterDescribeAnnotation(type, default, optional, desc) + + +class ComponentParameterApplyError(RuntimeError): +
... + + +class ComponentParameterDuplicateError(RuntimeError): + ... diff --git a/python/fate/components/core/component_desc/artifacts/__init__.py b/python/fate/components/core/component_desc/artifacts/__init__.py new file mode 100644 index 0000000000..0e3ff1f9c6 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/__init__.py @@ -0,0 +1,55 @@ +from ._base_type import ( + ArtifactDescribe, + DataArtifactDescribe, + MetricArtifactDescribe, + ModelArtifactDescribe, + _ArtifactType, +) +from .data import ( + data_directory_input, + data_directory_inputs, + data_directory_output, + data_directory_outputs, + dataframe_input, + dataframe_inputs, + dataframe_output, + dataframe_outputs, + table_input, + table_inputs, +) +from .metric import json_metric_output, json_metric_outputs +from .model import ( + json_model_input, + json_model_inputs, + json_model_output, + json_model_outputs, + model_directory_input, + model_directory_inputs, + model_directory_output, + model_directory_outputs, +) + +__all__ = [ + "_ArtifactType", + "ArtifactDescribe", + "json_model_input", + "json_model_inputs", + "json_model_output", + "json_model_outputs", + "model_directory_input", + "model_directory_inputs", + "model_directory_output", + "model_directory_outputs", + "dataframe_input", + "dataframe_inputs", + "dataframe_output", + "dataframe_outputs", + "table_input", + "table_inputs", + "data_directory_input", + "data_directory_inputs", + "data_directory_output", + "data_directory_outputs", + "json_metric_output", + "json_metric_outputs", +] diff --git a/python/fate/components/core/component_desc/artifacts/_base_type.py b/python/fate/components/core/component_desc/artifacts/_base_type.py new file mode 100644 index 0000000000..86dd2c56ce --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/_base_type.py @@ -0,0 +1,162 @@ +import inspect +import typing +from typing import Generic, List, Optional, Type, TypeVar, Union + +from fate.arch import URI +from fate.components.core.essential import Role, Stage +from fate.components.core.spec.artifact import ( + DataOutputMetadata, + Metadata, + MetricOutputMetadata, + ModelOutputMetadata, +) +from fate.components.core.spec.component import ArtifactSpec + +if typing.TYPE_CHECKING: + from fate.arch import Context + +M = typing.TypeVar("M", bound=Union[DataOutputMetadata, ModelOutputMetadata, MetricOutputMetadata]) + + +class _ArtifactTypeWriter(Generic[M]): + def __init__(self, ctx: "Context", artifact: "_ArtifactType[M]") -> None: + self.ctx = ctx + self.artifact = artifact + + def __str__(self): + return f"{self.__class__.__name__}({self.artifact})" + + def __repr__(self): + return self.__str__() + + +class _ArtifactTypeReader: + def __init__(self, ctx: "Context", artifact: "_ArtifactType[Metadata]") -> None: + self.ctx = ctx + self.artifact = artifact + + def __str__(self): + return f"{self.__class__.__name__}({self.artifact})" + + def __repr__(self): + return self.__str__() + + +MM = TypeVar("MM", bound=Union[Metadata, DataOutputMetadata, ModelOutputMetadata, MetricOutputMetadata]) + + +class _ArtifactType(Generic[MM]): + def __init__(self, uri: "URI", metadata: MM, type_name) -> None: + self.uri = uri + self.metadata = metadata + self.type_name = type_name + self._consumed = False + + def __str__(self): + return f"{self.__class__.__name__}(uri={self.uri}, metadata={self.metadata}, type_name={self.type_name})" + + def __repr__(self): + return self.__str__() + + def consumed(self): + self._consumed = True + return self + + def dict(self): 
+ return { + "metadata": self.metadata, + "uri": self.uri.to_string(), + "type_name": self.type_name, + "consumed": self._consumed, + } + + +class _ArtifactsType(Generic[MM]): + def __init__(self, artifacts: List[_ArtifactType[MM]]): + self.artifacts = artifacts + + def __str__(self): + return f"{self.__class__.__name__}(artifacts={self.artifacts})" + + def __repr__(self): + return self.__str__() + + def dict(self): + return [artifact.dict() for artifact in self.artifacts] + + +AT = TypeVar("AT") + + +class ArtifactDescribe(Generic[AT, M]): + def __init__(self, name: str, roles: List[Role], stages: List[Stage], desc: str, optional: bool, multi: bool): + if roles is None: + roles = [] + if desc: + desc = inspect.cleandoc(desc) + + self.name = name + self.roles = roles + self.stages = stages + self.desc = desc + self.optional = optional + self.multi = multi + + def __str__(self) -> str: + return f"{self.__class__.__name__}(name={self.name}, type={self.get_type()}, roles={self.roles}, stages={self.stages}, optional={self.optional})" + + def dict(self): + return ArtifactSpec( + types=self.get_type().type_name, + optional=self.optional, + roles=self.roles, + stages=self.stages, + description=self.desc, + is_multi=self.multi, + ) + + @classmethod + def get_type(cls) -> AT: + raise NotImplementedError() + + def get_writer(self, config, ctx: "Context", uri: "URI", type_name: str) -> _ArtifactTypeWriter[M]: + raise NotImplementedError() + + def get_reader(self, ctx: "Context", uri: URI, metadata: Metadata, type_name: str) -> _ArtifactTypeReader: + raise NotImplementedError() + + +class DataArtifactDescribe(ArtifactDescribe[AT, M]): + ... + + +class ModelArtifactDescribe(ArtifactDescribe[AT, M]): + ... + + +class MetricArtifactDescribe(ArtifactDescribe[AT, M]): + ... + + +def _create_artifact_annotation( + is_input: bool, is_multi: bool, describe_type: Type[ArtifactDescribe], describe_type_kind: str +): + def f(roles: Optional[List[Role]] = None, desc="", optional=False): + from .._component_artifact import ArtifactDescribeAnnotation + + return ArtifactDescribeAnnotation( + describe_type=describe_type, + describe_type_kind=describe_type_kind, + is_input=is_input, + roles=roles, + stages=[], + desc=desc, + optional=optional, + multi=is_multi, + ) + + return f + + +class ComponentArtifactApplyError(RuntimeError): + ... 
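For orientation, a minimal sketch of how the building blocks above are combined: a function decorated with `component`, whose extra arguments carry artifact annotations (`dataframe_input`, `dataframe_output`, `json_model_output`, ...) and `parameter` annotations. The import surface is an assumption (this patch does not show what `fate.components.core` re-exports), and `toy_scale` is a hypothetical component used purely for illustration, not part of this patch.

# Hypothetical usage sketch; import paths are assumed, not defined in this patch.
from fate.components.core import GUEST, HOST, Role, component, parameter, params
from fate.components.core.component_desc.artifacts import (
    dataframe_input,
    dataframe_output,
    json_model_output,
)


@component(roles=[GUEST, HOST])
def toy_scale(
    ctx,  # first argument must be named `ctx` (enforced by _component)
    role: Role,  # second argument must be named `role`
    train_data: dataframe_input(roles=[GUEST, HOST]),
    method: parameter(type=params.string_choice(["standard", "min_max"]), default="standard"),
    train_output_data: dataframe_output(roles=[GUEST, HOST]),
    output_model: json_model_output(roles=[GUEST, HOST]),
):
    """Toy component: every extra argument must be an annotated parameter or artifact."""
    df = train_data.read()  # DataframeReader injected by ComponentExecutionIO
    # ... scale df according to `method` ...
    train_output_data.write(df, name="toy_scale_out")  # DataframeWriter.write
    output_model.write({"method": method}, metadata={})  # JsonModelWriter.write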
diff --git a/python/fate/components/core/component_desc/artifacts/data/__init__.py b/python/fate/components/core/component_desc/artifacts/data/__init__.py new file mode 100644 index 0000000000..9cf11ce444 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/data/__init__.py @@ -0,0 +1,62 @@ +from typing import Iterator, List, Optional, Type + +from .._base_type import Role, _create_artifact_annotation +from ._dataframe import DataframeArtifactDescribe, DataframeReader, DataframeWriter +from ._directory import ( + DataDirectoryArtifactDescribe, + DataDirectoryReader, + DataDirectoryWriter, +) +from ._table import TableArtifactDescribe, TableReader, TableWriter + + +def dataframe_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataframeReader]: + return _create_artifact_annotation(True, False, DataframeArtifactDescribe, "data")(roles, desc, optional) + + +def dataframe_inputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[List[DataframeReader]]: + return _create_artifact_annotation(True, True, DataframeArtifactDescribe, "data")(roles, desc, optional) + + +def dataframe_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataframeWriter]: + return _create_artifact_annotation(False, False, DataframeArtifactDescribe, "data")(roles, desc, optional) + + +def dataframe_outputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[Iterator[DataframeWriter]]: + return _create_artifact_annotation(False, True, DataframeArtifactDescribe, "data")(roles, desc, optional) + + +def table_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[TableReader]: + return _create_artifact_annotation(True, False, TableArtifactDescribe, "data")(roles, desc, optional) + + +def table_inputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[List[TableReader]]: + return _create_artifact_annotation(True, True, TableArtifactDescribe, "data")(roles, desc, optional) + + +def table_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[TableWriter]: + return _create_artifact_annotation(False, False, TableArtifactDescribe, "data")(roles, desc, optional) + + +def table_outputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[Iterator[TableWriter]]: + return _create_artifact_annotation(False, True, TableArtifactDescribe, "data")(roles, desc, optional) + + +def data_directory_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataDirectoryReader]: + return _create_artifact_annotation(True, False, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) + + +def data_directory_inputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[List[DataDirectoryReader]]: + return _create_artifact_annotation(True, True, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) + + +def data_directory_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataDirectoryWriter]: + return _create_artifact_annotation(False, False, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) + + +def data_directory_outputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[Iterator[DataDirectoryWriter]]: + return _create_artifact_annotation(False, True, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) diff --git a/python/fate/components/core/component_desc/artifacts/data/_dataframe.py 
b/python/fate/components/core/component_desc/artifacts/data/_dataframe.py new file mode 100644 index 0000000000..463936f53d --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/data/_dataframe.py @@ -0,0 +1,92 @@ +import logging +import typing + +from fate.components.core.essential import DataframeArtifactType + +from .._base_type import ( + ArtifactDescribe, + DataOutputMetadata, + Metadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + +if typing.TYPE_CHECKING: + from fate.arch import URI + from fate.arch.dataframe import DataFrame + +logger = logging.getLogger(__name__) + + +class DataframeWriter(_ArtifactTypeWriter[DataOutputMetadata]): + def write(self, df: "DataFrame", name=None, namespace=None): + self.artifact.consumed() + logger.debug(f"start writing dataframe to artifact: {self.artifact}, name={name}, namespace={namespace}") + from fate.arch import dataframe + + if name is not None: + self.artifact.metadata.name = name + if namespace is not None: + self.artifact.metadata.namespace = namespace + + table = dataframe.serialize(self.ctx, df) + if "schema" not in self.artifact.metadata.metadata: + self.artifact.metadata.metadata["schema"] = {} + table.save( + uri=self.artifact.uri, + schema=self.artifact.metadata.metadata["schema"], + options=self.artifact.metadata.metadata.get("options", None), + ) + # save data overview + count = df.count() + samples = df.data_overview() + from fate.components.core.spec.artifact import DataOverview + + self.artifact.metadata.data_overview = DataOverview(count=count, samples=samples) + + logger.debug(f"write dataframe to artifact: {self.artifact}") + + +class DataframeReader(_ArtifactTypeReader): + def read(self) -> "DataFrame": + self.artifact.consumed() + logger.debug(f"start reading dataframe from artifact: {self.artifact}") + # if self.artifact.uri.scheme == "file": + # import inspect + # + # from fate.arch import dataframe + # + # kwargs = {} + # p = inspect.signature(dataframe.CSVReader.__init__).parameters + # parameter_keys = p.keys() + # for k, v in self.artifact.metadata.metadata.items(): + # if k in parameter_keys: + # kwargs[k] = v + # + # return dataframe.CSVReader(**kwargs).to_frame(self.ctx, self.artifact.uri.path) + + from fate.arch import dataframe + + table = self.ctx.computing.load( + uri=self.artifact.uri, + schema=self.artifact.metadata.metadata.get("schema", None), + options=self.artifact.metadata.metadata.get("options", None), + ) + df = dataframe.deserialize(self.ctx, table) + logger.debug(f"read dataframe from artifact: {self.artifact}") + return df + + +class DataframeArtifactDescribe(ArtifactDescribe[DataframeArtifactType, DataOutputMetadata]): + @classmethod + def get_type(cls): + return DataframeArtifactType + + def get_writer(self, config, ctx, uri: "URI", type_name: str) -> DataframeWriter: + from fate.components.core.spec.artifact import DataOutputMetadata + + return DataframeWriter(ctx, _ArtifactType(uri=uri, metadata=DataOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx, uri: "URI", metadata: "Metadata", type_name: str) -> DataframeReader: + return DataframeReader(ctx, _ArtifactType(uri, metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/component_desc/artifacts/data/_directory.py b/python/fate/components/core/component_desc/artifacts/data/_directory.py new file mode 100644 index 0000000000..511a97ce93 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/data/_directory.py @@ -0,0 +1,50 @@ +from pathlib 
import Path + +from fate.components.core.essential import DataDirectoryArtifactType + +from .._base_type import ( + URI, + ArtifactDescribe, + DataOutputMetadata, + Metadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + + +class DataDirectoryWriter(_ArtifactTypeWriter[DataDirectoryArtifactType]): + def get_directory(self) -> Path: + self.artifact.consumed() + path = Path(self.artifact.uri.path) + path.mkdir(parents=True, exist_ok=True) + return path + + def write_metadata(self, metadata: dict, name=None, namespace=None): + self.artifact.metadata.metadata.update(metadata) + if name is not None: + self.artifact.metadata.name = name + if namespace is not None: + self.artifact.metadata.namespace = namespace + + +class DataDirectoryReader(_ArtifactTypeReader): + def get_directory(self) -> Path: + self.artifact.consumed() + path = Path(self.artifact.uri.path) + return path + + def get_metadata(self): + return self.artifact.metadata.metadata + + +class DataDirectoryArtifactDescribe(ArtifactDescribe[DataDirectoryArtifactType, DataOutputMetadata]): + @classmethod + def get_type(cls): + return DataDirectoryArtifactType + + def get_writer(self, config, ctx, uri: URI, type_name: str) -> DataDirectoryWriter: + return DataDirectoryWriter(ctx, _ArtifactType(uri=uri, metadata=DataOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx, uri: "URI", metadata: "Metadata", type_name: str) -> DataDirectoryReader: + return DataDirectoryReader(ctx, _ArtifactType(uri=uri, metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/component_desc/artifacts/data/_table.py b/python/fate/components/core/component_desc/artifacts/data/_table.py new file mode 100644 index 0000000000..d7894fc256 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/data/_table.py @@ -0,0 +1,50 @@ +import typing + +from fate.components.core.essential import TableArtifactType + +from .._base_type import ( + URI, + ArtifactDescribe, + DataOutputMetadata, + Metadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + +if typing.TYPE_CHECKING: + from fate.arch import Context + + +class TableWriter(_ArtifactTypeWriter[DataOutputMetadata]): + def write(self, table): + self.artifact.consumed() + if "schema" not in self.artifact.metadata.metadata: + self.artifact.metadata.metadata["schema"] = {} + table.save( + uri=self.artifact.uri, + schema=self.artifact.metadata.metadata["schema"], + options=self.artifact.metadata.metadata.get("options", None), + ) + + +class TableReader(_ArtifactTypeReader): + def read(self): + self.artifact.consumed() + return self.ctx.computing.load( + uri=self.artifact.uri, + schema=self.artifact.metadata.metadata.get("schema", {}), + options=self.artifact.metadata.metadata.get("options", None), + ) + + +class TableArtifactDescribe(ArtifactDescribe[TableArtifactType, DataOutputMetadata]): + @classmethod + def get_type(cls): + return TableArtifactType + + def get_writer(self, config, ctx: "Context", uri: URI, type_name: str) -> TableWriter: + return TableWriter(ctx, _ArtifactType(uri, DataOutputMetadata(), type_name)) + + def get_reader(self, ctx: "Context", uri: "URI", metadata: "Metadata", type_name: str) -> TableReader: + return TableReader(ctx, _ArtifactType(uri, metadata, type_name)) diff --git a/python/fate/components/core/component_desc/artifacts/metric/__init__.py b/python/fate/components/core/component_desc/artifacts/metric/__init__.py new file mode 100644 index 0000000000..f913bff5a1 --- /dev/null +++ 
b/python/fate/components/core/component_desc/artifacts/metric/__init__.py @@ -0,0 +1,16 @@ +from typing import List, Optional, Type + +from .._base_type import Role, _create_artifact_annotation +from ._json import ( + JsonMetricArtifactDescribe, + JsonMetricFileWriter, + JsonMetricRestfulWriter, +) + + +def json_metric_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonMetricFileWriter]: + return _create_artifact_annotation(False, False, JsonMetricArtifactDescribe, "metric")(roles, desc, optional) + + +def json_metric_outputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonMetricFileWriter]: + return _create_artifact_annotation(False, True, JsonMetricArtifactDescribe, "metric")(roles, desc, optional) diff --git a/python/fate/components/core/component_desc/artifacts/metric/_json.py b/python/fate/components/core/component_desc/artifacts/metric/_json.py new file mode 100644 index 0000000000..a735b89cd5 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/metric/_json.py @@ -0,0 +1,74 @@ +import json +import logging +import typing +from pathlib import Path +from typing import Dict, Optional, Union + +import requests +from fate.components.core.essential import JsonMetricArtifactType + +logger = logging.getLogger(__name__) + +from .._base_type import ( + URI, + ArtifactDescribe, + Metadata, + MetricOutputMetadata, + _ArtifactType, + _ArtifactTypeWriter, +) + +if typing.TYPE_CHECKING: + from fate.arch import Context + + +class JsonMetricFileWriter(_ArtifactTypeWriter[MetricOutputMetadata]): + def write(self, data, metadata: Optional[Dict] = None): + self.artifact.consumed() + path = Path(self.artifact.uri.path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as fw: + json.dump(data, fw) + + if metadata is not None: + self.artifact.metadata.metadata = metadata + + +class JsonMetricRestfulWriter(_ArtifactTypeWriter[MetricOutputMetadata]): + def write(self, data): + self.artifact.consumed() + try: + output = requests.post(url=self.artifact.uri.original_uri, json=dict(data=[data])) + except Exception as e: + logger.error(f"write data `{data}` to {self.artifact.uri.original_uri} failed, error: {e}") + else: + logger.debug(f"write data `{data}` to {self.artifact.uri.original_uri} success, output: {output}") + + def write_metadata(self, metadata: Dict): + self.artifact.metadata.metadata = metadata + + def close(self): + pass + + +class JsonMetricArtifactDescribe(ArtifactDescribe[JsonMetricArtifactType, MetricOutputMetadata]): + @classmethod + def get_type(cls): + return JsonMetricArtifactType + + def get_writer( + self, config, ctx: "Context", uri: URI, type_name: str + ) -> Union[JsonMetricFileWriter, JsonMetricRestfulWriter]: + if uri.scheme == "http" or uri.scheme == "https": + return JsonMetricRestfulWriter( + ctx, _ArtifactType(uri=uri, metadata=MetricOutputMetadata(), type_name=type_name) + ) + elif uri.scheme == "file": + return JsonMetricFileWriter( + ctx, _ArtifactType(uri=uri, metadata=MetricOutputMetadata(), type_name=type_name) + ) + else: + raise ValueError(f"unsupported uri scheme: {uri.scheme}") + + def get_reader(self, ctx: "Context", uri: URI, metadata: Metadata, type_name: str): + raise NotImplementedError() diff --git a/python/fate/components/core/component_desc/artifacts/model/__init__.py b/python/fate/components/core/component_desc/artifacts/model/__init__.py new file mode 100644 index 0000000000..f7ac839f39 --- /dev/null +++ 
b/python/fate/components/core/component_desc/artifacts/model/__init__.py @@ -0,0 +1,45 @@ +from typing import Iterator, List, Optional, Type + +from .._base_type import Role, _create_artifact_annotation +from ._directory import ( + ModelDirectoryArtifactDescribe, + ModelDirectoryReader, + ModelDirectoryWriter, +) +from ._json import JsonModelArtifactDescribe, JsonModelReader, JsonModelWriter + + +def json_model_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonModelReader]: + return _create_artifact_annotation(True, False, JsonModelArtifactDescribe, "model")(roles, desc, optional) + + +def json_model_inputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[List[JsonModelReader]]: + return _create_artifact_annotation(True, True, JsonModelArtifactDescribe, "model")(roles, desc, optional) + + +def json_model_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonModelWriter]: + return _create_artifact_annotation(False, False, JsonModelArtifactDescribe, "model")(roles, desc, optional) + + +def json_model_outputs(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[Iterator[JsonModelWriter]]: + return _create_artifact_annotation(False, True, JsonModelArtifactDescribe, "model")(roles, desc, optional) + + +def model_directory_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[ModelDirectoryReader]: + return _create_artifact_annotation(True, False, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional) + + +def model_directory_inputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[List[ModelDirectoryReader]]: + return _create_artifact_annotation(True, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional) + + +def model_directory_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[ModelDirectoryWriter]: + return _create_artifact_annotation(False, False, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional) + + +def model_directory_outputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[Iterator[ModelDirectoryWriter]]: + return _create_artifact_annotation(False, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional) diff --git a/python/fate/components/core/component_desc/artifacts/model/_directory.py b/python/fate/components/core/component_desc/artifacts/model/_directory.py new file mode 100644 index 0000000000..aea1498dfe --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/model/_directory.py @@ -0,0 +1,64 @@ +import datetime +import typing +from pathlib import Path + +from fate.components.core.essential import ModelDirectoryArtifactType + +from .._base_type import ( + URI, + ArtifactDescribe, + Metadata, + ModelOutputMetadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + +if typing.TYPE_CHECKING: + from fate.arch import Context + + +class ModelDirectoryWriter(_ArtifactTypeWriter[ModelOutputMetadata]): + def get_directory(self): + self.artifact.consumed() + path = Path(self.artifact.uri.path) + path.mkdir(parents=True, exist_ok=True) + + # update model overview + from fate.components.core.spec.model import MLModelModelSpec + + model_overview = self.artifact.metadata.model_overview + model_overview.party.models.append( + MLModelModelSpec( + name="", + created_time=datetime.datetime.now().isoformat(), + file_format=ModelDirectoryArtifactType.type_name, + metadata={}, + ) + ) + return self.artifact.uri.path + + def 
write_metadata(self, metadata: dict): + self.artifact.metadata.metadata = metadata + + +class ModelDirectoryReader(_ArtifactTypeReader): + def get_directory(self): + self.artifact.consumed() + path = Path(self.artifact.uri.path) + return path + + def get_metadata(self): + return self.artifact.metadata.metadata + + +class ModelDirectoryArtifactDescribe(ArtifactDescribe[ModelDirectoryArtifactType, ModelOutputMetadata]): + @classmethod + def get_type(cls): + return ModelDirectoryArtifactType + + def get_writer(self, config, ctx: "Context", uri: URI, type_name: str) -> ModelDirectoryWriter: + return ModelDirectoryWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx: "Context", uri: URI, metadata: Metadata, type_name: str) -> ModelDirectoryReader: + return ModelDirectoryReader(ctx, _ArtifactType(uri=uri, metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/component_desc/artifacts/model/_json.py b/python/fate/components/core/component_desc/artifacts/model/_json.py new file mode 100644 index 0000000000..d5fb69a07a --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/model/_json.py @@ -0,0 +1,71 @@ +import datetime +import json +import typing +from pathlib import Path + +from fate.components.core.essential import JsonModelArtifactType + +from .._base_type import ( + URI, + ArtifactDescribe, + Metadata, + ModelOutputMetadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + +if typing.TYPE_CHECKING: + from fate.arch import Context + + +class JsonModelWriter(_ArtifactTypeWriter[ModelOutputMetadata]): + def write(self, data, metadata: dict = None): + self.artifact.consumed() + if not hasattr(self, "_has_write"): + setattr(self, "_has_write", True) + else: + raise RuntimeError(f"json model writer {self.artifact} has been written, cannot write again") + + path = Path(self.artifact.uri.path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as fw: + json.dump(data, fw) + if metadata is None: + metadata = {} + self.artifact.metadata.metadata = metadata + + # update model overview + from fate.components.core.spec.model import MLModelModelSpec + + model_overview = self.artifact.metadata.model_overview + model_overview.party.models.append( + MLModelModelSpec( + name="", + created_time=datetime.datetime.now().isoformat(), + file_format="json", + metadata=metadata, + ) + ) + + +class JsonModelReader(_ArtifactTypeReader): + def read(self): + self.artifact.consumed() + try: + with open(self.artifact.uri.path, "r") as fr: + return json.load(fr) + except Exception as e: + raise RuntimeError(f"load json model named from {self.artifact} failed: {e}") + + +class JsonModelArtifactDescribe(ArtifactDescribe[JsonModelArtifactType, ModelOutputMetadata]): + @classmethod + def get_type(cls): + return JsonModelArtifactType + + def get_writer(self, config, ctx: "Context", uri: URI, type_name: str) -> JsonModelWriter: + return JsonModelWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx: "Context", uri: URI, metadata: Metadata, type_name: str) -> JsonModelReader: + return JsonModelReader(ctx, _ArtifactType(uri=uri, metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/essential/__init__.py b/python/fate/components/core/essential/__init__.py new file mode 100644 index 0000000000..5918878659 --- /dev/null +++ b/python/fate/components/core/essential/__init__.py @@ -0,0 +1,12 @@ +from 
._artifact_type import ( + ArtifactType, + DataDirectoryArtifactType, + DataframeArtifactType, + JsonMetricArtifactType, + JsonModelArtifactType, + ModelDirectoryArtifactType, + TableArtifactType, +) +from ._label import Label +from ._role import ARBITER, GUEST, HOST, LOCAL, Role +from ._stage import CROSS_VALIDATION, DEFAULT, PREDICT, TRAIN, Stage diff --git a/python/fate/components/core/essential/_artifact_type.py b/python/fate/components/core/essential/_artifact_type.py new file mode 100644 index 0000000000..488032d476 --- /dev/null +++ b/python/fate/components/core/essential/_artifact_type.py @@ -0,0 +1,43 @@ +from typing import List + + +class ArtifactType: + type_name: str + path_type: str + uri_types: List[str] + + +class DataframeArtifactType(ArtifactType): + type_name = "dataframe" + path_type = "distributed" + uri_types = ["eggroll", "hdfs"] + + +class TableArtifactType(ArtifactType): + type_name = "table" + path_type = "distributed" + uri_types = ["eggroll", "hdfs"] + + +class DataDirectoryArtifactType(ArtifactType): + type_name = "data_directory" + path_type = "directory" + uri_types = ["file"] + + +class ModelDirectoryArtifactType(ArtifactType): + type_name = "model_directory" + path_type = "directory" + uri_types = ["file"] + + +class JsonModelArtifactType(ArtifactType): + type_name = "json_model" + path_type = "file" + uri_types = ["file"] + + +class JsonMetricArtifactType(ArtifactType): + type_name = "json_metric" + path_type = "file" + uri_types = ["file"] diff --git a/python/fate/components/loader/other.py b/python/fate/components/core/essential/_label.py similarity index 73% rename from python/fate/components/loader/other.py rename to python/fate/components/core/essential/_label.py index 248ee38040..1eba17662e 100644 --- a/python/fate/components/loader/other.py +++ b/python/fate/components/core/essential/_label.py @@ -12,13 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -def load_role(role: str): - from fate.components import Role - return Role(role) +class Label: + def __init__(self, name): + self.name = name -def load_stage(stage: str): - from fate.components import Stage + def __str__(self): + return f"Label<{self.name}>" - return Stage(stage) + def __repr__(self): + return f"Label<{self.name}>" + + +TRAINABLE = Label("trainable") diff --git a/python/fate/components/core/essential/_role.py b/python/fate/components/core/essential/_role.py new file mode 100644 index 0000000000..00b7b45de2 --- /dev/null +++ b/python/fate/components/core/essential/_role.py @@ -0,0 +1,76 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, Union + +import pydantic + + +class Role(str): + def __init__(self, name) -> None: + self.name = name + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type="string", format="role") + + @classmethod + def __get_validators__(cls) -> "CallableGenerator": + yield cls.validate + + @classmethod + def validate(cls, value: Union[str]) -> str: + return value + + @property + def is_guest(self) -> bool: + return self.name == "guest" + + @property + def is_host(self) -> bool: + return self.name == "host" + + @property + def is_arbiter(self) -> bool: + return self.name == "arbiter" + + @property + def local(self) -> bool: + return self.name == "local" + + @classmethod + def from_str(cls, role: str): + if role == "local": + return LOCAL + if role == "guest": + return GUEST + elif role == "host": + return HOST + elif role == "arbiter": + return ARBITER + else: + raise ValueError(f"role {role} is not supported") + + def __str__(self): + return f"Role<{self.name}>" + + def __repr__(self): + return f"Role<{self.name}>" + + +GUEST = Role("guest") +HOST = Role("host") +ARBITER = Role("arbiter") +LOCAL = Role("local") diff --git a/python/fate/components/core/essential/_stage.py b/python/fate/components/core/essential/_stage.py new file mode 100644 index 0000000000..97bc43cef2 --- /dev/null +++ b/python/fate/components/core/essential/_stage.py @@ -0,0 +1,63 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Stage: + def __init__(self, name: str) -> None: + self.name = name + + @property + def is_train(self): + return self.name == TRAIN.name + + @property + def is_predict(self): + return self.name == PREDICT.name + + @property + def is_train_eval(self): + return self.name == CROSS_VALIDATION.name + + @property + def is_default(self): + return self.name == DEFAULT.name + + def is_cross_validation(self): + return self.name == CROSS_VALIDATION.name + + @classmethod + def from_str(cls, stage: str): + if stage == "train": + return TRAIN + elif stage == "predict": + return PREDICT + elif stage == "cross_validation": + return CROSS_VALIDATION + elif stage == "default": + return DEFAULT + else: + raise ValueError(f"stage {stage} is not supported") + + def __str__(self): + return f"Stage<{self.name}>" + + def __repr__(self): + return f"Stage<{self.name}>" + + +TRAIN = Stage("train") +PREDICT = Stage("predict") +CROSS_VALIDATION = Stage("cross_validation") +DEFAULT = Stage("default") diff --git a/python/fate/components/core/params/__init__.py b/python/fate/components/core/params/__init__.py new file mode 100644 index 0000000000..a7654d88e2 --- /dev/null +++ b/python/fate/components/core/params/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from ._cipher import CipherParamType, PaillierCipherParam +from ._cv_param import CVParam, cv_param +from ._fields import Parameter, confloat, conint, jsonschema, parse, string_choice +from ._filter_param import ( + IVFilterParam, + ManualFilterParam, + StatisticFilterParam, + iv_filter_param, + manual_filter_param, + statistic_filter_param, +) +from ._he_param import he_param, HEParam +from ._init_param import InitParam, init_param +from ._learning_rate import LRSchedulerParam, lr_scheduler_param +from ._metrics import metrics_param, statistic_metrics_param, legal_percentile +from ._optimizer import OptimizerParam, optimizer_param +from ._penalty import penalty_param diff --git a/python/fate/components/core/params/_cipher.py b/python/fate/components/core/params/_cipher.py new file mode 100644 index 0000000000..601f5789c7 --- /dev/null +++ b/python/fate/components/core/params/_cipher.py @@ -0,0 +1,15 @@ +from typing import Literal, Union + +import pydantic + + +class PaillierCipherParam(pydantic.BaseModel): + method: Literal["paillier"] = "paillier" + key_length: pydantic.conint(gt=1024) = 1024 + + +class NoopCipher(pydantic.BaseModel): + method: Literal[None] + + +CipherParamType = Union[PaillierCipherParam, NoopCipher] diff --git a/python/fate/components/core/params/_cv_param.py b/python/fate/components/core/params/_cv_param.py new file mode 100644 index 0000000000..92b4c093b4 --- /dev/null +++ b/python/fate/components/core/params/_cv_param.py @@ -0,0 +1,14 @@ +import pydantic + +from ._fields import conint + + +class CVParam(pydantic.BaseModel): + n_splits: conint(gt=1) + shuffle: bool = False + random_state: int = None + + +def cv_param(): + namespace = {} + return type("CVParam", (CVParam,), namespace) diff --git a/python/fate/components/core/params/_fields.py b/python/fate/components/core/params/_fields.py new file mode 100644 index 0000000000..472ccc9b83 --- /dev/null +++ b/python/fate/components/core/params/_fields.py @@ -0,0 +1,102 @@ +import typing +from typing import Any, Optional, Type, TypeVar + +import pydantic + + +class Parameter: + @classmethod + def parse(cls, obj: Any): + return pydantic.parse_obj_as(cls, obj) + + @classmethod + def schema(cls): + return NotImplemented + + +T = TypeVar("T") + + +class _SmartUnion(pydantic.BaseModel.Config): + smart_union = True + + +def parse(type_: Type[T], obj: Any) -> T: + if not isinstance(type_, typing._GenericAlias) and issubclass(type_, Parameter): + return type_.parse(obj) + else: + # create_model to inject config + model = pydantic.create_model("parameter", __config__=_SmartUnion, p=(type_, ...)) + return pydantic.parse_obj_as(model, {"p": obj}).p + + +def jsonschema(type_: Type[T]): + return pydantic.schema_json_of(type_, indent=2) + + +class ConstrainedInt(pydantic.ConstrainedInt, Parameter): + ... 
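The parse helper above dispatches on the target type: Parameter subclasses validate themselves via Parameter.parse, while any other annotation is wrapped in an ad-hoc pydantic model with smart_union enabled. A small usage sketch under those assumptions:

from fate.components.core.params import conint, cv_param, parse

# constrained scalar: conint(gt=1) builds a one-off pydantic ConstrainedInt subclass
n_splits_type = conint(gt=1)
print(parse(n_splits_type, 5))   # 5
# parse(n_splits_type, 1) raises a validation error: 1 is not > 1

# model-style parameter: cv_param() returns a fresh CVParam subclass
cv = parse(cv_param(), {"n_splits": 5, "shuffle": True})
print(cv.n_splits, cv.shuffle, cv.random_state)   # 5 True None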
+ + +def conint( + *, + strict: bool = False, + gt: int = None, + ge: int = None, + lt: int = None, + le: int = None, + multiple_of: int = None, +) -> Type[int]: + namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) + return type("ConstrainedIntValue", (ConstrainedInt,), namespace) + + +class ConstrainedFloat(pydantic.ConstrainedFloat, Parameter): + ... + + +def confloat( + *, + strict: bool = False, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + allow_inf_nan: Optional[bool] = None, +) -> Type[float]: + namespace = dict( + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + ) + return type("ConstrainedFloatValue", (ConstrainedFloat,), namespace) + + +class StringChoice(str, Parameter): + choice = set() + lower = True + + @classmethod + def __get_validators__(cls): + yield cls.string_choice_validator + + @classmethod + def string_choice_validator(cls, v): + allowed = {c.lower() for c in cls.choice} if cls.lower else cls.choice + provided = v.lower() if cls.lower else v + if provided in allowed: + return provided + raise ValueError(f"provided `{provided}` not in `{allowed}`") + + +def string_choice(choice, lower=True) -> Type[str]: + namespace = dict( + choice=choice, + lower=lower, + ) + return type("StringChoice", (StringChoice,), namespace) diff --git a/python/fate/components/core/params/_filter_param.py b/python/fate/components/core/params/_filter_param.py new file mode 100644 index 0000000000..e5174de3ef --- /dev/null +++ b/python/fate/components/core/params/_filter_param.py @@ -0,0 +1,103 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List + +import pydantic + +from ._fields import string_choice, Parameter, conint, confloat +from ._metrics import statistic_metrics_param, legal_percentile + + +class StandardFilterParam(pydantic.BaseModel, Parameter): + metrics: List[str] + + filter_type: List[string_choice({'threshold', 'top_k', 'top_percentile'})] = ['threshold'] + threshold: List[Union[confloat(ge=0.0, le=1.0), conint(ge=1)]] = [1.0] + take_high: List[bool] = [True] + + @pydantic.validator('metrics', 'filter_type', 'threshold', 'take_high', pre=True, allow_reuse=True) + def to_list(cls, v): + return v if isinstance(v, list) else [v] + + @pydantic.root_validator(pre=False) + def check_filter_param_length(cls, values): + max_length = max([len(x) for k, x in values.items()]) + for k, v in values.items(): + if len(v) == 1: + v *= max_length + assert len(v) == max_length, f"Length of {k}: {v} does not match " \ + f"max length {max_length} of (metrics, filter_type, threshold, take_high)." 
+ return values + + +class FederatedStandardFilterParam(StandardFilterParam, Parameter): + host_filter_type: List[string_choice({'threshold', 'top_k', 'top_percentile'})] = ['threshold'] + host_threshold: List[Union[confloat(ge=0.0, le=1.0), conint(ge=1)]] = [1.0] + host_take_high: List[bool] = [True] + + select_federated: bool = True + + @pydantic.validator('host_filter_type', 'host_threshold', 'host_take_high', pre=True, allow_reuse=True) + def to_list(cls, v): + return v if isinstance(v, list) else [v] + + @pydantic.root_validator(pre=False) + def check_filter_param_length(cls, values): + select_values = {k: v for k, v in values.items() if k != 'select_federated'} + max_length = max([len(x) for k, x in select_values.items()]) + for k, v in select_values.items(): + if len(v) == 1: + v *= max_length + assert len(v) == max_length, f"Length of {k}: {v} does not match " \ + f"max length {max_length} of (metrics, filter_type, threshold, take_high)." + return values + + +class IVFilterParam(FederatedStandardFilterParam, Parameter): + metrics: List[string_choice({'iv'})] = ['iv'] + + +class StatisticFilterParam(StandardFilterParam, Parameter): + metrics: List[Union[statistic_metrics_param(), legal_percentile()]] = ["mean"] + + +class ManualFilterParam(pydantic.BaseModel, Parameter): + keep_col: List[str] = [] + filter_out_col: List[str] = [] + + @pydantic.root_validator(pre=False) + def no_intersection(cls, values): + filter_out_col = values.get('filter_out_col', []) + keep_col = values.get('keep_col', []) + intersection = set(filter_out_col).intersection(set(keep_col)) + if intersection: + raise ValueError(f"`keep_col` and `filter_out_col` share common elements: {intersection}") + return values + + +def iv_filter_param(): + namespace = {} + return type("IVFilterParam", (IVFilterParam,), namespace) + + +def statistic_filter_param(): + namespace = {} + return type("StatisticFilterParam", (StatisticFilterParam,), namespace) + + +def manual_filter_param(): + namespace = {} + return type("ManualFilterParam", (ManualFilterParam,), namespace) diff --git a/python/fate/components/core/params/_he_param.py b/python/fate/components/core/params/_he_param.py new file mode 100644 index 0000000000..d7abd1cebb --- /dev/null +++ b/python/fate/components/core/params/_he_param.py @@ -0,0 +1,29 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
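The filter-param models above accept scalars or lists: the to_list pre-validator wraps scalars, and the root validator broadcasts length-1 entries so that metrics, filter_type, threshold and take_high line up. A hedged sketch of that behaviour:

from fate.components.core.params import manual_filter_param, statistic_filter_param

p = statistic_filter_param()(metrics="mean", filter_type="top_k", threshold=3)
print(p.metrics, p.filter_type, p.threshold, p.take_high)
# ['mean'] ['top_k'] [3] [True]

# keep_col and filter_out_col must be disjoint; this raises a validation error
manual_filter_param()(keep_col=["x0"], filter_out_col=["x0"])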
+# + +import pydantic + +from ._fields import string_choice + + +class HEParam(pydantic.BaseModel): + kind: string_choice(["paillier", "ou", "mock"]) + key_length: int = 1024 + + +def he_param(): + namespace = {} + return type("HEParam", (HEParam,), namespace) diff --git a/python/fate/components/core/params/_init_param.py b/python/fate/components/core/params/_init_param.py new file mode 100644 index 0000000000..fbbf81770f --- /dev/null +++ b/python/fate/components/core/params/_init_param.py @@ -0,0 +1,39 @@ +from typing import Union + +import pydantic + +from ._fields import string_choice + + +class InitParam(pydantic.BaseModel): + method: string_choice(['zeros', 'ones', 'consts', 'random', 'random_uniform']) = 'zeros' + fill_val: Union[int, float] = 0.0 + fit_intercept: bool = True + random_state: int = None + + +def init_param(): + namespace = {} + return type("InitParam", (InitParam,), namespace) + + +""" +class OnesInitParam(pydantic.BaseModel): + method: Literal['ones'] + fit_intercept: bool = True + + +class ConstsInitParam(pydantic.BaseModel): + method: Literal['consts'] + fill_val: Union[int, float] + fit_intercept: bool = True + + +class RandomInitParam(pydantic.BaseModel): + method: Literal['random'] + fit_intercept: bool = True + + +InitParam = Union[ZerosInitParam, OnesInitParam, ConstsInitParam, RandomInitParam] + +""" diff --git a/python/fate/components/core/params/_learning_rate.py b/python/fate/components/core/params/_learning_rate.py new file mode 100644 index 0000000000..45b32b969f --- /dev/null +++ b/python/fate/components/core/params/_learning_rate.py @@ -0,0 +1,17 @@ +import pydantic + +from ._fields import StringChoice + + +class LRSchedulerType(StringChoice): + choice = {'constant', 'linear', 'step'} + + +class LRSchedulerParam(pydantic.BaseModel): + method: LRSchedulerType = 'constant' + scheduler_params: dict = None + + +def lr_scheduler_param(): + namespace = {} + return type("LRSchedulerParam", (LRSchedulerParam,), namespace) diff --git a/python/fate/components/core/params/_metrics.py b/python/fate/components/core/params/_metrics.py new file mode 100644 index 0000000000..0ce7d006e3 --- /dev/null +++ b/python/fate/components/core/params/_metrics.py @@ -0,0 +1,92 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
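string_choice fields such as HEParam.kind and InitParam.method are case-insensitive by default and normalise values to lower case, rejecting anything outside the declared choice set. A brief sketch, assuming the factories exported from fate.components.core.params:

from fate.components.core.params import he_param, init_param

he = he_param()(kind="Paillier")            # accepted, stored lower-cased
print(he.kind, he.key_length)               # paillier 1024

init = init_param()(method="random", fill_val=1.0, random_state=42)
print(init.method, init.fit_intercept)      # random True

# init_param()(method="gaussian") raises a validation error: not in the choice set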
+ +import re +from typing import Type + +from ._fields import StringChoice, Parameter + + +class Metrics(StringChoice): + choice = {} + + +class StatisticMetrics(StringChoice): + choice = {} + + +def statistic_metrics_param( + count=True, + sum=True, + min=True, + max=True, + mean=True, + median=True, + std=True, + var=True, + coe=True, + missing_count=True, + missing_ratio=True, + skewness=True, + kurtosis=True +) -> Type[str]: + choice = { + "count": count, + "sum": sum, + "max": max, + "min": min, + "mean": mean, + "median": median, + "std": std, + "var": var, + "coefficient_of_variation": coe, + "missing_count": missing_count, + "missing_ratio": missing_ratio, + "skewness": skewness, + "kurtosis": kurtosis, + } + namespace = dict( + choice={k for k, v in choice.items() if v}, + ) + return type("StatisticMetrics", (StatisticMetrics,), namespace) + + +def metrics_param(auc=True, ks=True, accuracy=True, mse=True) -> Type[str]: + choice = {"auc": auc, "ks": ks, "accuracy": accuracy, "mse": mse} + namespace = dict( + choice={k for k, v in choice.items() if v}, + ) + return type("Metrics", (Metrics,), namespace) + + +class LegalPercentile(str, Parameter): + legal_percentile = r"^(100|\d{1,2})%$" + + @classmethod + def __get_validators__(cls): + yield cls.percentile_validator + + @classmethod + def percentile_validator(cls, v): + if re.match(cls.legal_percentile, v): + return v + raise ValueError(f"provided `{v}` not in legal percentile format") + + +def legal_percentile() -> Type[str]: + namespace = dict( + legal_percentile=LegalPercentile.legal_percentile, + ) + return type("LegalPercentile", (LegalPercentile,), namespace) diff --git a/python/fate/components/core/params/_optimizer.py b/python/fate/components/core/params/_optimizer.py new file mode 100644 index 0000000000..6c8a7aec9e --- /dev/null +++ b/python/fate/components/core/params/_optimizer.py @@ -0,0 +1,20 @@ +from typing import Type + +import pydantic + +from ._fields import string_choice +from ._penalty import penalty_param + + +class OptimizerParam(pydantic.BaseModel): + method: string_choice( + ["sgd", "adadelta", "adagrad", "adam", "adamax", "adamw", "asgd", "nadam", "radam", "rmsprop", "rprop"] + ) = "sgd" + penalty: penalty_param(l1=True, l2=True, none=True) = "l2" + alpha: float = 1.0 + optimizer_params: dict + + +def optimizer_param() -> Type[OptimizerParam]: + namespace = {} + return type("OptimizerParam", (OptimizerParam,), namespace) diff --git a/python/fate/components/core/params/_penalty.py b/python/fate/components/core/params/_penalty.py new file mode 100644 index 0000000000..72b3b98a5f --- /dev/null +++ b/python/fate/components/core/params/_penalty.py @@ -0,0 +1,15 @@ +from typing import Type + +from ._fields import StringChoice + + +class Penalty(StringChoice): + choice = {} + + +def penalty_param(l1=True, l2=True, none=True) -> Type[str]: + choice = {"L1": l1, "L2": l2, "none": none} + namespace = dict( + choice={k for k, v in choice.items() if v}, + ) + return type("PenaltyValue", (Penalty,), namespace) diff --git a/python/fate/arch/context/io/model/__init__.py b/python/fate/components/core/spec/__init__.py similarity index 100% rename from python/fate/arch/context/io/model/__init__.py rename to python/fate/components/core/spec/__init__.py diff --git a/python/fate/components/core/spec/artifact.py b/python/fate/components/core/spec/artifact.py new file mode 100644 index 0000000000..426b2c0643 --- /dev/null +++ b/python/fate/components/core/spec/artifact.py @@ -0,0 +1,127 @@ +# +# Copyright 2019 The FATE Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import re +from typing import Dict, List, Optional, Union + +import pydantic + +from .model import ( + MLModelComponentSpec, + MLModelFederatedSpec, + MLModelModelSpec, + MLModelPartiesSpec, + MLModelPartySpec, + MLModelSpec, +) + +# see https://www.rfc-editor.org/rfc/rfc3986#appendix-B +# scheme = $2 +# authority = $4 +# path = $5 +# query = $7 +# fragment = $9 +_uri_regex = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") + + +class DataOverview(pydantic.BaseModel): + count: Optional[int] = None + samples: Optional[List] = None + + +class ArtifactSource(pydantic.BaseModel): + task_id: str + party_task_id: str + task_name: str + component: str + output_artifact_key: str + output_index: Optional[int] = None + + +class Metadata(pydantic.BaseModel): + metadata: dict = pydantic.Field(default_factory=dict) + name: Optional[str] = None + namespace: Optional[str] = None + source: Optional[ArtifactSource] = None + + +class ModelOutputMetadata(pydantic.BaseModel): + metadata: dict = pydantic.Field(default_factory=dict) + name: Optional[str] = None + namespace: Optional[str] = None + source: Optional[ArtifactSource] = None + model_overview: MLModelSpec = None + + class Config: + extra = "forbid" + + +class DataOutputMetadata(pydantic.BaseModel): + metadata: dict = pydantic.Field(default_factory=dict) + name: Optional[str] = None + namespace: Optional[str] = None + source: Optional[ArtifactSource] = None + data_overview: Optional[DataOverview] = None + + class Config: + extra = "forbid" + + +class MetricOutputMetadata(pydantic.BaseModel): + metadata: dict = pydantic.Field(default_factory=dict) + name: Optional[str] = None + namespace: Optional[str] = None + source: Optional[ArtifactSource] = None + + class Config: + extra = "forbid" + + +class ArtifactInputApplySpec(pydantic.BaseModel): + uri: str + metadata: Metadata + type_name: Optional[str] = None + + +class ArtifactOutputApplySpec(pydantic.BaseModel): + uri: str + _is_template: Optional[bool] = None + type_name: Optional[str] = None + + def is_template(self) -> bool: + return "{index}" in self.uri + + def _check_is_template(self) -> bool: + return "{index}" in self.uri + + @pydantic.validator("uri") + def _check_uri(cls, v, values) -> str: + if not _uri_regex.match(v): + raise pydantic.ValidationError(f"`{v}` is not valid uri") + return v + + +class IOArtifactMeta(pydantic.BaseModel): + class InputMeta(pydantic.BaseModel): + data: Dict[str, Union[List[Dict], Dict]] + model: Dict[str, Union[List[Dict], Dict]] + + class OutputMeta(pydantic.BaseModel): + data: Dict[str, Union[List[Dict], Dict]] + model: Dict[str, Union[List[Dict], Dict]] + metric: Dict[str, Union[List[Dict], Dict]] + + inputs: InputMeta + outputs: OutputMeta diff --git a/python/fate/components/core/spec/component.py b/python/fate/components/core/spec/component.py new file mode 100644 index 0000000000..c802833bbc --- /dev/null +++ 
b/python/fate/components/core/spec/component.py @@ -0,0 +1,110 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Literal, Optional + +from fate.components.core.essential import Label, Role, Stage +from pydantic import BaseModel + + +class ParameterSpec(BaseModel): + type: str + default: Optional[Any] + optional: bool + description: str = "" + type_meta: dict = {} + + +class ArtifactSpec(BaseModel): + types: List[str] + optional: bool + stages: Optional[List[Stage]] + roles: List[Role] + description: str = "" + is_multi: bool + + class Config: + arbitrary_types_allowed = True + + def dict(self, *args, **kwargs): + object_dict = super().dict(*args, **kwargs) + object_dict["roles"] = [r.name for r in self.roles] + object_dict["stages"] = [s.name for s in self.stages] + return object_dict + + +class InputDefinitionsSpec(BaseModel): + data: Dict[str, ArtifactSpec] + model: Dict[str, ArtifactSpec] + + +class OutputDefinitionsSpec(BaseModel): + data: Dict[str, ArtifactSpec] + model: Dict[str, ArtifactSpec] + metric: Dict[str, ArtifactSpec] + + +class ComponentSpec(BaseModel): + class Config: + arbitrary_types_allowed = True + + name: str + description: str + provider: str + version: str + labels: List[Label] + roles: List[Role] + parameters: Dict[str, ParameterSpec] + input_artifacts: InputDefinitionsSpec + output_artifacts: OutputDefinitionsSpec + + def dict(self, *args, **kwargs): + object_dict = super().dict(*args, **kwargs) + object_dict["roles"] = [r.name for r in self.roles] + object_dict["labels"] = [l.name for l in self.labels] + return object_dict + + +class ComponentSpecV1(BaseModel): + component: ComponentSpec + schema_version: str = "v1" + + +class ArtifactTypeSpec(BaseModel): + type_name: str + uri_types: List[str] + path_type: Literal["file", "directory", "distributed"] + + +class ComponentIOArtifactTypeSpec(BaseModel): + name: str + is_multi: bool + optional: bool + types: List[ArtifactTypeSpec] + + +class ComponentIOInputsArtifactsTypeSpec(BaseModel): + data: List[ComponentIOArtifactTypeSpec] + model: List[ComponentIOArtifactTypeSpec] + + +class ComponentIOOutputsArtifactsTypeSpec(BaseModel): + data: List[ComponentIOArtifactTypeSpec] + model: List[ComponentIOArtifactTypeSpec] + metric: List[ComponentIOArtifactTypeSpec] + + +class ComponentIOArtifactsTypeSpec(BaseModel): + inputs: ComponentIOInputsArtifactsTypeSpec + outputs: ComponentIOOutputsArtifactsTypeSpec diff --git a/python/fate/components/spec/computing.py b/python/fate/components/core/spec/computing.py similarity index 100% rename from python/fate/components/spec/computing.py rename to python/fate/components/core/spec/computing.py diff --git a/python/fate/components/spec/device.py b/python/fate/components/core/spec/device.py similarity index 100% rename from python/fate/components/spec/device.py rename to python/fate/components/core/spec/device.py diff --git a/python/fate/components/spec/federation.py 
b/python/fate/components/core/spec/federation.py similarity index 98% rename from python/fate/components/spec/federation.py rename to python/fate/components/core/spec/federation.py index 46e5c221d9..0eee05a157 100644 --- a/python/fate/components/spec/federation.py +++ b/python/fate/components/core/spec/federation.py @@ -18,7 +18,7 @@ class PartySpec(pydantic.BaseModel): - role: Literal["guest", "host", "arbiter"] + role: Literal["guest", "host", "arbiter", "local"] partyid: str def tuple(self): diff --git a/python/fate/components/core/spec/logger.py b/python/fate/components/core/spec/logger.py new file mode 100644 index 0000000000..b6608e6d5b --- /dev/null +++ b/python/fate/components/core/spec/logger.py @@ -0,0 +1,65 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import logging.config +import os +from typing import Optional + +import pydantic + + +class LoggerConfig(pydantic.BaseModel): + config: Optional[dict] = None + + def install(self, debug=False): + if debug or self.config is None: + level = os.getenv("DEBUG_MODE_LOG_LEVEL", "DEBUG") + try: + import rich.logging + + logging_class = "rich.logging.RichHandler" + logging_formatters = {} + handlers = { + "console": { + "class": logging_class, + "level": level, + "filters": [], + } + } + except ImportError: + logging_class = "logging.StreamHandler" + logging_formatters = { + "console": { + "format": "[%(levelname)s][%(asctime)-8s][%(process)s][%(module)s.%(funcName)s][line:%(lineno)d]: %(message)s" + } + } + handlers = { + "console": { + "class": logging_class, + "level": level, + "formatter": "console", + } + } + self.config = dict( + version=1, + formatters=logging_formatters, + handlers=handlers, + filters={}, + loggers={}, + root=dict(handlers=["console"], level="DEBUG"), + disable_existing_loggers=False, + ) + logging.config.dictConfig(self.config) diff --git a/python/fate/components/core/spec/metric.py b/python/fate/components/core/spec/metric.py new file mode 100644 index 0000000000..24e236fce6 --- /dev/null +++ b/python/fate/components/core/spec/metric.py @@ -0,0 +1,11 @@ +from typing import Dict, List, Optional, Union + +import pydantic + + +class Metric(pydantic.BaseModel): + name: str + type: Optional[str] + groups: List[Dict[str, int]] + step_axis: Optional[str] + data: Union[Dict, List] diff --git a/python/fate/components/spec/model.py b/python/fate/components/core/spec/model.py similarity index 98% rename from python/fate/components/spec/model.py rename to python/fate/components/core/spec/model.py index 29882d833b..e60cc28265 100644 --- a/python/fate/components/spec/model.py +++ b/python/fate/components/core/spec/model.py @@ -40,7 +40,7 @@ class MLModelFederatedSpec(pydantic.BaseModel): class MLModelModelSpec(pydantic.BaseModel): name: str - created_time: datetime + created_time: str file_format: str metadata: dict diff --git a/python/fate/components/spec/task.py b/python/fate/components/core/spec/task.py similarity index 63% rename from 
python/fate/components/spec/task.py rename to python/fate/components/core/spec/task.py index 48f36e1141..5b349f7678 100644 --- a/python/fate/components/spec/task.py +++ b/python/fate/components/core/spec/task.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +import os +from typing import Any, Dict, List, Optional, Union import pydantic -from .artifact import ArtifactSpec +from .artifact import ArtifactInputApplySpec, ArtifactOutputApplySpec from .computing import EggrollComputingSpec, SparkComputingSpec, StandaloneComputingSpec from .device import CPUSpec, GPUSpec from .federation import ( @@ -26,16 +27,10 @@ RollSiteFederationSpec, StandaloneFederationSpec, ) -from .logger import CustomLogger, FlowLogger, PipelineLogger -from .mlmd import CustomMLMDSpec, FlowMLMDSpec, NoopMLMDSpec, PipelineMLMDSpec -from .output import OutputPoolConf +from .logger import LoggerConfig class TaskConfigSpec(pydantic.BaseModel): - class TaskInputsSpec(pydantic.BaseModel): - parameters: Dict[str, Any] = {} - artifacts: Dict[str, Union[ArtifactSpec, List[ArtifactSpec]]] = {} - class TaskConfSpec(pydantic.BaseModel): device: Union[CPUSpec, GPUSpec] computing: Union[StandaloneComputingSpec, EggrollComputingSpec, SparkComputingSpec] @@ -46,15 +41,28 @@ class TaskConfSpec(pydantic.BaseModel): PulsarFederationSpec, OSXFederationSpec, ] - logger: Union[PipelineLogger, FlowLogger, CustomLogger] - mlmd: Union[PipelineMLMDSpec, FlowMLMDSpec, NoopMLMDSpec, CustomMLMDSpec] - output: OutputPoolConf + logger: LoggerConfig + task_final_meta_path: pydantic.FilePath = pydantic.Field(default_factory=lambda: os.path.abspath(os.getcwd())) task_id: str party_task_id: str + task_name: str component: str role: str party_id: str stage: str = "default" - inputs: TaskInputsSpec = TaskInputsSpec(parameters={}, artifacts={}) + parameters: Dict[str, Any] = {} + input_artifacts: Dict[str, Optional[Union[List[ArtifactInputApplySpec], ArtifactInputApplySpec]]] = {} + output_artifacts: Dict[str, Optional[ArtifactOutputApplySpec]] = {} conf: TaskConfSpec + + +class TaskCleanupConfigSpec(pydantic.BaseModel): + computing: Union[StandaloneComputingSpec, EggrollComputingSpec, SparkComputingSpec] + federation: Union[ + StandaloneFederationSpec, + RollSiteFederationSpec, + RabbitMQFederationSpec, + PulsarFederationSpec, + OSXFederationSpec, + ] diff --git a/python/fate/components/cpn.py b/python/fate/components/cpn.py deleted file mode 100644 index 3cfceb61c6..0000000000 --- a/python/fate/components/cpn.py +++ /dev/null @@ -1,570 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2014 Pallets - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# 1. 
Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# - -""" -use decorators to define component for FATE. -flowing codes modified from [click](https://github.com/pallets/click) project -""" - -import inspect -import logging -from typing import Any, Dict, List, Optional - -import pydantic -from fate.components import T_ROLE, T_STAGE, MetricArtifact, Role, Stage - - -class ComponentDeclarError(Exception): - ... - - -class ComponentApplyError(Exception): - ... - - -logger = logging.getLogger(__name__) - - -class _Component: - def __init__( - self, - name: str, - roles: List[T_ROLE], - provider, - version, - description, - callback, - parameters: List["_ParameterDeclareClass"], - artifacts: "_ComponentArtifacts", - is_subcomponent: bool = False, - ) -> None: - self.is_subcomponent = is_subcomponent - self.name = name - self.roles = roles - self.provider = provider - self.version = version - self.description = description - self.callback = callback - self.parameters = parameters - if not self.description: - self.description = "" - self.artifacts = artifacts - self.func_args = list(inspect.signature(self.callback).parameters.keys()) - self.stage_components: List[_Component] = [] - - def validate_declare(self): - # validate - if self.func_args[0] != "ctx": - raise ComponentDeclarError("bad component definition, first argument should be `ctx`") - if self.func_args[1] != "role": - raise ComponentDeclarError("bad component definition, second argument should be `role`") - - # assert parameters defined once - _defined = set() - for p in self.parameters: - if p.name in _defined: - raise ComponentDeclarError(f"parameter named `{p.name}` declared multiple times") - _defined.add(p.name) - - # validate func arguments - undeclared_func_parameters = set(self.func_args[2:]) - - def _check_and_remove(name, arg_type): - if name not in undeclared_func_parameters: - raise ComponentDeclarError( - f"{arg_type} named `{name}` declar in decorator, but not found in function's argument" - ) - undeclared_func_parameters.remove(name) - - for parameter in self.parameters: - _check_and_remove(parameter.name, "parameter") - for name in self.artifacts.get_artifacts(): - _check_and_remove(name, "artifact") - if undeclared_func_parameters: - raise 
ComponentDeclarError( - f"function's arguments `{undeclared_func_parameters}` lack of corresponding parameter or artifact decorator" - ) - - def execute(self, ctx, role, **kwargs): - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"execution arguments: {kwargs}") - return self.callback(ctx, role, **kwargs) - - def dict(self): - return self._flatten_stages()._dict() - - def _flatten_stages(self) -> "_Component": - parameter_mapping = {parameter.name: parameter for parameter in self.parameters} - merged_artifacts = self.artifacts - for stage_cpn in self.stage_components: - stage_cpn = stage_cpn._flatten_stages() - # merge parameters - for parameter in stage_cpn.parameters: - # update or error - if parameter.name not in parameter_mapping: - parameter_mapping[parameter.name] = parameter - else: - parameter_mapping[parameter.name].merge(parameter) - merged_artifacts = merged_artifacts.merge(stage_cpn.artifacts) - - return _Component( - name=self.name, - roles=self.roles, - provider=self.provider, - version=self.version, - description=self.description, - callback=self.callback, - parameters=list(parameter_mapping.values()), - artifacts=merged_artifacts, - is_subcomponent=self.is_subcomponent, - ) - - def _dict(self): - from fate.components import InputAnnotated, OutputAnnotated - from fate.components.spec.component import ( - ArtifactSpec, - ComponentSpec, - ComponentSpecV1, - InputDefinitionsSpec, - OutputDefinitionsSpec, - ParameterSpec, - ) - - input_artifacts = {} - output_artifacts = {} - for _, artifact in self.artifacts.get_artifacts().items(): - annotated = getattr(artifact.type, "__metadata__", [None])[0] - roles = artifact.roles or self.roles - if annotated == OutputAnnotated: - output_artifacts[artifact.name] = ArtifactSpec( - type=artifact.type.type, - optional=artifact.optional, - roles=roles, - stages=artifact.stages, - description=artifact.desc, - ) - elif annotated == InputAnnotated: - input_artifacts[artifact.name] = ArtifactSpec( - type=artifact.type.type, - optional=artifact.optional, - roles=roles, - stages=artifact.stages, - description=artifact.desc, - ) - else: - raise ValueError(f"bad artifact: {artifact}") - - input_parameters = {} - from fate.components.params import Parameter - - for parameter in self.parameters: - if isinstance(parameter.type, Parameter): # recomanded - type_name = type(parameter.type).__name__ - type_meta = parameter.type.dict() - else: - type_name = parameter.type.__name__ - type_meta = {} - - input_parameters[parameter.name] = ParameterSpec( - type=type_name, - type_meta=type_meta, - default=parameter.default, - optional=parameter.optional, - description=parameter.desc, - ) - - input_definition = InputDefinitionsSpec(parameters=input_parameters, artifacts=input_artifacts) - output_definition = OutputDefinitionsSpec(artifacts=output_artifacts) - component = ComponentSpec( - name=self.name, - description=self.description, - provider=self.provider, - version=self.version, - labels=[], - roles=self.roles, - input_definitions=input_definition, - output_definitions=output_definition, - ) - return ComponentSpecV1(component=component) - - def dump_yaml(self, stream=None): - from io import StringIO - - import ruamel.yaml - - spec = self.dict() - inefficient = False - if stream is None: - inefficient = True - stream = StringIO() - yaml = ruamel.yaml.YAML() - yaml.indent(mapping=2, sequence=4, offset=2) - yaml.dump(spec.dict(), stream=stream) - if inefficient: - return stream.getvalue() - - def predict(self, roles=[], provider: Optional[str] = None, 
version: Optional[str] = None, description=None): - from fate.components import PREDICT - - return self.stage(roles=roles, name=PREDICT.name, provider=provider, version=version, description=description) - - def train(self, roles=[], provider: Optional[str] = None, version: Optional[str] = None, description=None): - from fate.components import TRAIN - - return self.stage(roles=roles, name=TRAIN.name, provider=provider, version=version, description=description) - - def stage( - self, roles=[], name=None, provider: Optional[str] = None, version: Optional[str] = None, description=None - ): - r"""Creates a new stage component with :class:`_Component` and uses the decorated function as - callback. This will also automatically attach all decorated - :func:`artifact`\s and :func:`parameter`\s as parameters to the component execution. - - The stage name of the component defaults to the name of the function. - If you want to change that, you can - pass the intended name as the first argument. - - Once decorated the function turns into a :class:`Component` instance - that can be invoked as a component execution. - - :param name: the name of the component. This defaults to the function - name. - """ - - def wrap(f): - sub_cpn = _component( - name, roles or self.roles, provider or self.provider, version or self.version, description, True - )(f) - self.stage_components.append(sub_cpn) - return sub_cpn - - return wrap - - -def component( - roles: List[Role], - name: Optional[str] = None, - provider: Optional[str] = None, - version: Optional[str] = None, - description: Optional[str] = None, -): - r"""Creates a new :class:`_Component` and uses the decorated function as - callback. This will also automatically attach all decorated - :func:`artifact`\s and :func:`parameter`\s as parameters to the component execution. - - The name of the component defaults to the name of the function. - If you want to change that, you can - pass the intended name as the first argument. - - Once decorated the function turns into a :class:`Component` instance - that can be invoked as a component execution. - - :param name: the name of the component. This defaults to the function - name. 
- """ - from fate import __provider__, __version__ - - if version is None: - version = __version__ - if provider is None: - provider = __provider__ - component_roles = [r.name for r in roles] - return _component( - name=name, - roles=component_roles, - provider=provider, - version=version, - description=description, - is_subcomponent=False, - ) - - -def _component(name, roles, provider, version, description, is_subcomponent): - from fate.components import DEFAULT - - def decorator(f): - cpn_name = name or f.__name__.lower() - if isinstance(f, _Component): - raise TypeError("Attempted to convert a callback into a component twice.") - try: - parameters = f.__component_parameters__ - parameters.reverse() - del f.__component_parameters__ - except AttributeError: - parameters = [] - try: - artifacts = f.__component_artifacts__ - del f.__component_artifacts__ - except AttributeError: - artifacts = _ComponentArtifacts() - - if is_subcomponent: - artifacts.set_stages([cpn_name]) - else: - artifacts.set_stages([DEFAULT.name]) - desc = description - if desc is None: - desc = inspect.getdoc(f) - if isinstance(desc, bytes): - desc = desc.decode("utf-8") - else: - desc = inspect.cleandoc(desc) - cpn = _Component( - name=cpn_name, - roles=roles, - provider=provider, - version=version, - description=desc, - callback=f, - parameters=parameters, - artifacts=artifacts, - is_subcomponent=is_subcomponent, - ) - cpn.__doc__ = f.__doc__ - cpn.validate_declare() - return cpn - - return decorator - - -class _ArtifactDeclareClass(pydantic.BaseModel): - name: str - type: Any - roles: List[T_ROLE] - stages: List[T_STAGE] - desc: str - optional: bool - - def is_active_for(self, stage: Stage, role: Role): - if self.stages is not None and stage.name not in self.stages: - return False - if self.roles and role.name not in self.roles: - return False - return True - - def __str__(self) -> str: - return f"ArtifactDeclare" - - def merge(self, a: "_ArtifactDeclareClass"): - if set(self.roles) != set(a.roles): - raise ComponentDeclarError( - f"artifact {self.name} declare multiple times with different roles: `{self.roles}` vs `{a.roles}`" - ) - if self.optional != a.optional: - raise ComponentDeclarError( - f"artifact {self.name} declare multiple times with different optional: `{self.optional}` vs `{a.optional}`" - ) - if self.type != a.type: - raise ComponentDeclarError( - f"artifact {self.name} declare multiple times with different optional: `{self.type}` vs `{a.type}`" - ) - stages = set(self.stages) - stages.update(a.stages) - stages = list(stages) - return _ArtifactDeclareClass( - name=self.name, type=self.type, roles=self.roles, stages=stages, desc=self.desc, optional=self.optional - ) - - -class _ComponentArtifacts(pydantic.BaseModel): - class Artifacts(pydantic.BaseModel): - data_artifact: Dict[str, _ArtifactDeclareClass] = pydantic.Field(default_factory=dict) - model_artifact: Dict[str, _ArtifactDeclareClass] = pydantic.Field(default_factory=dict) - metric_artifact: Dict[str, _ArtifactDeclareClass] = pydantic.Field(default_factory=dict) - - def add_data(self, artifact): - self.data_artifact[artifact.name] = artifact - - def add_model(self, artifact): - self.model_artifact[artifact.name] = artifact - - def add_metric(self, artifact): - self.metric_artifact[artifact.name] = artifact - - def get_artifact(self, name): - return self.data_artifact.get(name) or self.model_artifact.get(name) or self.metric_artifact.get(name) - - def merge(self, stage_artifacts): - def _merge(a, b): - result = {} - result.update(a) - for k, v 
in b.items(): - if k not in result: - result[k] = v - else: - result[k] = result[k].merge(v) - return result - - data_artifact = _merge(self.data_artifact, stage_artifacts.data_artifact) - model_artifact = _merge(self.model_artifact, stage_artifacts.model_artifact) - metric_artifact = _merge(self.metric_artifact, stage_artifacts.metric_artifact) - return _ComponentArtifacts.Artifacts( - data_artifact=data_artifact, model_artifact=model_artifact, metric_artifact=metric_artifact - ) - - inputs: Artifacts = pydantic.Field(default_factory=Artifacts) - outputs: Artifacts = pydantic.Field(default_factory=Artifacts) - - def set_stages(self, stages): - def _set_all(artifacts: Dict[str, _ArtifactDeclareClass]): - for _, artifact in artifacts.items(): - artifact.stages = stages - - _set_all(self.inputs.data_artifact) - _set_all(self.inputs.model_artifact) - _set_all(self.inputs.metric_artifact) - _set_all(self.outputs.data_artifact) - _set_all(self.outputs.model_artifact) - _set_all(self.outputs.metric_artifact) - - def get_artifacts(self) -> Dict[str, _ArtifactDeclareClass]: - artifacts = {} - artifacts.update(self.inputs.data_artifact) - artifacts.update(self.inputs.model_artifact) - artifacts.update(self.inputs.metric_artifact) - artifacts.update(self.outputs.data_artifact) - artifacts.update(self.outputs.model_artifact) - artifacts.update(self.outputs.metric_artifact) - return artifacts - - def merge(self, stage_artifacts: "_ComponentArtifacts"): - return _ComponentArtifacts( - inputs=self.inputs.merge(stage_artifacts.inputs), outputs=self.outputs.merge(stage_artifacts.outputs) - ) - - -def artifact(name, type, roles: Optional[List[Role]] = None, desc="", optional=False): - """attaches an artifact to the component.""" - if roles is None: - artifact_roles = [] - else: - artifact_roles = [r.name for r in roles] - - def decorator(f): - description = desc - if description: - description = inspect.cleandoc(description) - if not hasattr(f, "__component_artifacts__"): - f.__component_artifacts__ = _ComponentArtifacts() - - from fate.components import ( - DatasetArtifact, - InputAnnotated, - ModelArtifact, - OutputAnnotated, - ) - - annotates = getattr(type, "__metadata__", [None]) - origin_type = getattr(type, "__origin__") - artifact_dec = _ArtifactDeclareClass( - name=name, type=type, roles=artifact_roles, stages=[], desc=description, optional=optional - ) - if InputAnnotated in annotates: - if issubclass(origin_type, DatasetArtifact): - f.__component_artifacts__.inputs.add_data(artifact_dec) - elif issubclass(origin_type, ModelArtifact): - f.__component_artifacts__.inputs.add_model(artifact_dec) - elif issubclass(origin_type, MetricArtifact): - f.__component_artifacts__.inputs.add_metric(artifact_dec) - else: - raise ValueError(f"bad artifact, name: `{name}`, type: `{type}`") - - elif OutputAnnotated in annotates: - if issubclass(origin_type, DatasetArtifact): - f.__component_artifacts__.outputs.add_data(artifact_dec) - elif issubclass(origin_type, ModelArtifact): - f.__component_artifacts__.outputs.add_model(artifact_dec) - elif issubclass(origin_type, MetricArtifact): - f.__component_artifacts__.outputs.add_metric(artifact_dec) - else: - raise ValueError(f"bad artifact, name: `{name}`, type: `{type}`") - else: - raise ValueError(f"bad artifact, name: `{name}`, type: `{type}`") - return f - - return decorator - - -class _ParameterDeclareClass: - def __init__(self, name, type, default, optional, desc) -> None: - self.name = name - self.type = type - self.default = default - self.optional = optional 
- self.desc = desc - - def __str__(self) -> str: - return f"Parameter" - - def merge(self, p: "_ParameterDeclareClass"): - if self.default != p.default: - raise ComponentDeclarError( - f"parameter {p.name} declare multiple times with different default: `{self.default}` vs `{p.default}`" - ) - if self.optional != p.optional: - raise ComponentDeclarError( - f"parameter {parameter.name} declare multiple times with different optional: `{self.optional}` vs `{p.optional}`" - ) - if self.type != p.type: - raise ComponentDeclarError( - f"parameter {parameter.name} declare multiple times with different type: `{self.type}` vs `{self.type}`" - ) - return self - - -def parameter(name, type, default=None, optional=True, desc=""): - """attaches an parameter to the component.""" - - def decorator(f): - description = desc - if description is not None: - description = inspect.cleandoc(description) - if not hasattr(f, "__component_parameters__"): - f.__component_parameters__ = [] - f.__component_parameters__.append(_ParameterDeclareClass(name, type, default, optional, desc)) - return f - - return decorator diff --git a/python/fate/components/entrypoint/cli/__init__.py b/python/fate/components/entrypoint/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/components/entrypoint/cli/component/__init__.py b/python/fate/components/entrypoint/cli/component/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/components/entrypoint/cli/component/__main__.py b/python/fate/components/entrypoint/cli/component/__main__.py new file mode 100644 index 0000000000..f292387e94 --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/__main__.py @@ -0,0 +1,20 @@ +import click +from fate.components.entrypoint.cli.component import ( + artifact_type_cli, + cleanup_cli, + desc_cli, + execute_cli, + list_cli, + task_schema_cli, +) + +component = click.Group(name="component") +component.add_command(execute_cli.execute) +component.add_command(cleanup_cli.cleanup) +component.add_command(desc_cli.desc) +component.add_command(list_cli.list) +component.add_command(artifact_type_cli.artifact_type) +component.add_command(task_schema_cli.task_schema) + +if __name__ == "__main__": + component(prog_name="python -m fate.components.entrypoint.cli.component") diff --git a/python/fate/components/entrypoint/cli/component/artifact_type_cli.py b/python/fate/components/entrypoint/cli/component/artifact_type_cli.py new file mode 100644 index 0000000000..f66743c092 --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/artifact_type_cli.py @@ -0,0 +1,22 @@ +import click + + +@click.command() +@click.option("--name", type=str, required=True, help="component name") +@click.option("--role", type=str, required=True, help="component name") +@click.option("--stage", type=str, required=True, help="component name") +@click.option("--output-path", type=click.File("w", lazy=True), help="output path") +def artifact_type(name, role, stage, output_path): + from fate.components.core import Role, Stage, load_component + + role = Role.from_str(role) + stage = Stage.from_str(stage) + cpn = load_component(name, stage=stage) + if output_path: + cpn.dump_runtime_io_yaml(role, stage, output_path) + else: + print(cpn.dump_runtime_io_yaml(role, stage, output_path)) + + +if __name__ == "__main__": + artifact_type() diff --git a/python/fate/components/entrypoint/cli/component/cleanup_cli.py b/python/fate/components/entrypoint/cli/component/cleanup_cli.py new file mode 100644 index 
0000000000..0a15bd890c --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/cleanup_cli.py @@ -0,0 +1,41 @@ +import click + + +@click.command() +@click.option("--process-tag", required=False, help="unique id to identify this execution process") +@click.option("--config", required=False, type=click.File(), help="config path") +@click.option("--env-name", required=False, type=str, help="env name for config") +def cleanup(process_tag, config, env_name): + """cleanup""" + import traceback + + from fate.arch import Context + from fate.components.core import load_computing, load_federation + from fate.components.core.spec.task import TaskCleanupConfigSpec + from fate.components.entrypoint.utils import ( + load_config_from_env, + load_config_from_file, + ) + + configs = {} + configs = load_config_from_env(configs, env_name) + load_config_from_file(configs, config) + config = TaskCleanupConfigSpec.parse_obj(configs) + + try: + print("start cleanup") + computing = load_computing(config.computing) + federation = load_federation(config.federation, computing) + ctx = Context( + computing=computing, + federation=federation, + ) + ctx.destroy() + print("cleanup done") + except Exception as e: + traceback.print_exc() + raise e + + +if __name__ == "__main__": + cleanup() diff --git a/python/fate/components/entrypoint/cli/component/desc_cli.py b/python/fate/components/entrypoint/cli/component/desc_cli.py new file mode 100644 index 0000000000..d43443ce5c --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/desc_cli.py @@ -0,0 +1,19 @@ +import click + + +@click.command() +@click.option("--name", required=True, help="name of component_desc") +@click.option("--save", type=click.File(mode="w", lazy=True), help="save desc output to specified file in yaml format") +def desc(name, save): + "generate component_desc describe config" + from fate.components.core import load_component + + cpn = load_component(name) + if save: + cpn.dump_yaml(save) + else: + print(cpn.dump_yaml()) + + +if __name__ == "__main__": + desc() diff --git a/python/fate/components/entrypoint/cli/component/execute_cli.py b/python/fate/components/entrypoint/cli/component/execute_cli.py new file mode 100644 index 0000000000..7dcf16fc6a --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/execute_cli.py @@ -0,0 +1,161 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
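The click group above is meant to be invoked as python -m fate.components.entrypoint.cli.component <subcommand>. A rough in-process sketch using click's test runner; the component name below is a placeholder, not a name confirmed by this changeset:

from click.testing import CliRunner
from fate.components.entrypoint.cli.component.desc_cli import desc
from fate.components.entrypoint.cli.component.list_cli import list as list_cmd

runner = CliRunner()
print(runner.invoke(list_cmd, []).output)                  # registered component names
result = runner.invoke(desc, ["--name", "feature_scale"])  # placeholder component name
print(result.output)                                       # YAML description if the component exists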
+import os +import typing + +import click + +if typing.TYPE_CHECKING: + from fate.components.core.spec.task import TaskConfigSpec + + +@click.command() +@click.option("--process-tag", required=False, help="unique id to identify this execution process") +@click.option("--config", required=False, type=click.File(), help="config path") +@click.option("--config-entrypoint", required=False, help="entrypoint to get config") +@click.option("--properties", "-p", multiple=True, help="properties config") +@click.option("--env-prefix", "-e", type=str, default="runtime.component_desc.", help="prefix for env config") +@click.option("--env-name", required=False, type=str, help="env name for config") +@click.option( + "--execution-final-meta-path", + type=click.Path(exists=False, dir_okay=False, writable=True, resolve_path=True), + default=os.path.join(os.getcwd(), "execution_final_meta.json"), + show_default=True, + help="path for execution meta generated by component when execution finished", +) +@click.option("--debug", is_flag=True, help="enable debug mode") +def execute( + process_tag, config, config_entrypoint, properties, env_prefix, env_name, execution_final_meta_path, debug +): + """ + execute component + """ + import logging + + from fate.components.core.spec.task import TaskConfigSpec + from fate.components.entrypoint.utils import ( + load_config_from_entrypoint, + load_config_from_env, + load_config_from_file, + load_config_from_properties, + load_properties, + load_properties_from_env, + ) + + "execute component_desc" + if config is None and config_entrypoint is None and not properties and env_name is None: + raise click.UsageError("at least one of config, config-entrypoint, properties, env-name should be provided") + + # parse properties + properties_items = {} + properties_items.update(load_properties(properties)) + properties_items.update(load_properties_from_env(env_prefix)) + + # parse config + configs = {} + load_config_from_env(configs, env_name) + load_config_from_entrypoint(configs, config_entrypoint) + load_config_from_file(configs, config) + load_config_from_properties(configs, properties_items) + + task_config = TaskConfigSpec.parse_obj(configs) + + # install logger + task_config.conf.logger.install(debug=debug) + logger = logging.getLogger(__name__) + logger.debug("logger installed") + logger.debug(f"task config: {task_config}") + + os.makedirs(os.path.dirname(execution_final_meta_path), exist_ok=True) + execute_component_from_config(task_config, execution_final_meta_path) + + +def execute_component_from_config(config: "TaskConfigSpec", output_path): + import json + import logging + import traceback + + from fate.arch import CipherKit, Context + from fate.arch.computing import profile_ends, profile_start + from fate.components.core import ( + ComponentExecutionIO, + Role, + Stage, + load_component, + load_computing, + load_device, + load_federation, + load_metric_handler, + ) + + logger = logging.getLogger(__name__) + logger.debug(f"logging final status to `{output_path}`") + try: + party_task_id = config.party_task_id + device = load_device(config.conf.device) + computing = load_computing(config.conf.computing, config.conf.logger.config) + federation = load_federation(config.conf.federation, computing) + cipher = CipherKit(device=device) + + ctx = Context( + device=device, + computing=computing, + federation=federation, + cipher=cipher, + ) + role = Role.from_str(config.role) + stage = Stage.from_str(config.stage) + logger.debug(f"component={config.component}, context={ctx}") + 
logger.debug("running...") + + # get correct component_desc/subcomponent handle stage + component = load_component(config.component, stage) + + # enable profiling + profile_start() + + # prepare + execution_io = ComponentExecutionIO(ctx, component, role, stage, config) + + # register metric handler + metrics_handler = load_metric_handler(execution_io.get_metric_writer()) + ctx.set_metric_handler(metrics_handler) + + # execute + component.execute(ctx, role, **execution_io.get_kwargs()) + + # finalize metric handler + metrics_handler.finalize() + # final execution io meta + execution_io_meta = execution_io.dump_io_meta() + try: + with open(output_path, "w") as fw: + json.dump(dict(status=dict(code=0), io_meta=execution_io_meta), fw, indent=4) + except Exception as e: + raise RuntimeError(f"failed to dump execution io meta to `{output_path}`: meta={execution_io_meta}") from e + + profile_ends() + logger.debug("done without error, waiting signal to terminate") + logger.debug("terminating, bye~") + + except Exception as e: + logger.error(e, exc_info=True) + with open(output_path, "w") as fw: + json.dump(dict(status=dict(code=-1, exceptions=traceback.format_exc())), fw) + raise e + + +if __name__ == "__main__": + execute() diff --git a/python/fate/components/entrypoint/cli/component/list_cli.py b/python/fate/components/entrypoint/cli/component/list_cli.py new file mode 100644 index 0000000000..c977df7201 --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/list_cli.py @@ -0,0 +1,19 @@ +import click + + +@click.command() +@click.option("--save", type=click.File(mode="w", lazy=True), help="save list output to specified file in json format") +def list(save): + "list all components" + from fate.components.core import list_components + + if save: + import json + + json.dump(list_components(), save) + else: + print(list_components()) + + +if __name__ == "__main__": + list() diff --git a/python/fate/components/entrypoint/cli/component/task_schema_cli.py b/python/fate/components/entrypoint/cli/component/task_schema_cli.py new file mode 100644 index 0000000000..1c15e70f00 --- /dev/null +++ b/python/fate/components/entrypoint/cli/component/task_schema_cli.py @@ -0,0 +1,17 @@ +import click + + +@click.command() +@click.option("--save", type=click.File(mode="w", lazy=True), help="save desc output to specified file in yaml format") +def task_schema(save): + "generate component_desc task config json schema" + from fate.components.core.spec.task import TaskConfigSpec + + if save: + save.write(TaskConfigSpec.schema_json()) + else: + print(TaskConfigSpec.schema_json()) + + +if __name__ == "__main__": + task_schema() diff --git a/python/fate/components/entrypoint/cli/test/__init__.py b/python/fate/components/entrypoint/cli/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/components/entrypoint/cli/test/__main__.py b/python/fate/components/entrypoint/cli/test/__main__.py new file mode 100644 index 0000000000..9de67136a6 --- /dev/null +++ b/python/fate/components/entrypoint/cli/test/__main__.py @@ -0,0 +1,8 @@ +import click +from fate.components.entrypoint.cli.test import execute + +test = click.Group(name="test") +test.add_command(execute.execute) + +if __name__ == "__main__": + test(prog_name="python -m fate.components.entrypoint.cli.test") diff --git a/python/fate/components/entrypoint/cli/test/execute.py b/python/fate/components/entrypoint/cli/test/execute.py new file mode 100644 index 0000000000..5a4c402297 --- /dev/null +++ 
b/python/fate/components/entrypoint/cli/test/execute.py @@ -0,0 +1,29 @@ +import os +import sys + +import click + + +@click.command() +@click.option("--config-path", type=click.Path(exists=True), required=True) +@click.option("--data-path", type=click.Path(exists=True), required=True) +@click.option("--properties", "-p", multiple=True, help="properties config") +def execute(config_path, data_path, properties): + """ + execute component from existing config file and data path, for debug purpose + Args: + config_path: + data_path: + + Returns: + + """ + os.environ["STANDALONE_DATA_PATH"] = str(data_path) + sys.argv = [__name__, "--config", f"{config_path}", "--debug"] + [f"--properties={p}" for p in properties] + from fate.components.entrypoint.cli.component.execute_cli import execute + + execute() + + +if __name__ == "__main__": + execute() diff --git a/python/fate/components/entrypoint/component.py b/python/fate/components/entrypoint/component.py deleted file mode 100644 index 53feb7fc29..0000000000 --- a/python/fate/components/entrypoint/component.py +++ /dev/null @@ -1,295 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import signal -import time -import traceback -from typing import Any, Dict - -from fate.arch.context import Context -from fate.components import params -from fate.components.cpn import ComponentApplyError, _Component -from fate.components.loader.artifact import load_artifact -from fate.components.loader.component import load_component -from fate.components.loader.computing import load_computing -from fate.components.loader.device import load_device -from fate.components.loader.federation import load_federation -from fate.components.loader.metric import load_metrics_handler -from fate.components.loader.mlmd import MLMD, load_mlmd -from fate.components.loader.model import ( - load_input_model_wrapper, - load_output_model_wrapper, -) -from fate.components.loader.other import load_role, load_stage -from fate.components.loader.output import OutputPool, load_pool -from fate.components.spec.task import TaskConfigSpec - -logger = logging.getLogger(__name__) - - -def execute_component(config: TaskConfigSpec): - party_task_id = config.party_task_id - mlmd = load_mlmd(config.conf.mlmd, party_task_id) - device = load_device(config.conf.device) - role = load_role(config.role) - stage = load_stage(config.stage) - metrics_handler = load_metrics_handler() - output_pool = load_pool(config.conf.output) - computing = load_computing(config.conf.computing) - federation = load_federation(config.conf.federation, computing) - ctx = Context( - context_name=party_task_id, - device=device, - computing=computing, - federation=federation, - metrics_handler=metrics_handler, - ) - logger.debug(f"component={config.component}, context={ctx}") - - # registe signal to handle sigterm - def gracefully_stop(signum, frame): - logger.debug(f"gracefully stop: signum={signum}") - try: - ctx.destroy() - except: - 
logger.debug(f"context destroy failed, skip") - finally: - import os - - os._exit(0) - - signal.signal(signal.SIGTERM, gracefully_stop) - - try: - logger.debug("running...") - mlmd.execution_status.log_excution_start() - component = load_component(config.component) - try: - if not stage.is_default: - # use sub component to handle stage - for stage_component in component.stage_components: - if stage_component.name == stage.name: - component = stage_component - break - else: - raise ValueError(f"stage `{stage.name}` for component `{component.name}` not supported") - - # load model wrapper - output_model_wrapper = load_output_model_wrapper( - config.task_id, config.party_task_id, component, config.role, config.party_id, config.conf.federation - ) - input_model_wrapper = load_input_model_wrapper() - # parse and validate parameters - input_parameters = parse_input_parameters(mlmd, component, config.inputs.parameters) - # parse and validate inputs - input_data_artifacts = parse_input_data(component, stage, role, config.inputs.artifacts) - input_model_artifacts = parse_input_model(component, stage, role, config.inputs.artifacts) - input_metric_artifacts = parse_input_metric(component, stage, role, config.inputs.artifacts) - # log output artifacts - for name, artifact in input_data_artifacts.items(): - if artifact is not None: - mlmd.io.log_input_artifact(name, artifact) - for name, artifact in input_metric_artifacts.items(): - if artifact is not None: - mlmd.io.log_input_artifact(name, artifact) - - # wrap model artifact - input_model_artifacts = { - key: value if value is None else input_model_wrapper.wrap(value, mlmd.io) - for key, value in input_model_artifacts.items() - } - - # fill in outputs - output_data_artifacts = parse_output_data(component, stage, role, output_pool) - output_model_artifacts = parse_output_model(component, stage, role, output_pool) - output_metric_artifacts = parse_output_metric(component, stage, role, output_pool) - - # wrap model artifact - output_model_artifacts = { - key: value if value is None else output_model_wrapper.wrap(value, mlmd.io) - for key, value in output_model_artifacts.items() - } - - # get execute key-word arguments - execute_kwargs = {} - execute_kwargs.update(input_parameters) - execute_kwargs.update(input_data_artifacts) - execute_kwargs.update(input_model_artifacts) - execute_kwargs.update(input_metric_artifacts) - execute_kwargs.update(output_data_artifacts) - execute_kwargs.update(output_model_artifacts) - execute_kwargs.update(output_metric_artifacts) - - # execute - component.execute(ctx, role, **execute_kwargs) - - # log output artifacts - for name, artifact in output_data_artifacts.items(): - if artifact is not None: - mlmd.io.log_output_data(name, artifact) - for name, artifact in output_metric_artifacts.items(): - if artifact is not None: - mlmd.io.log_output_metric(name, artifact) - - except Exception as e: - tb = traceback.format_exc() - mlmd.execution_status.log_excution_exception(dict(exception=str(e.args), traceback=tb)) - raise e - else: - mlmd.execution_status.log_excution_end() - except Exception as e: - logger.error(e, exc_info=True) - raise e - else: - logger.debug("done without error, waiting signal to terminate") - while not mlmd.execution_status.safe_terminate(): - time.sleep(0.5) - logger.debug("terminating, bye~") - finally: - - # protect process from `sigterm` when context destroying - def drop_sigterm(signum, frame): - logger.warning( - "component is cleaning, will stop in few seconds. 
Terminate now may cause some process not stop properly, please wait." - ) - - signal.signal(signal.SIGTERM, drop_sigterm) - logger.debug("stop and cleaning...") - ctx.destroy() - logger.debug("stop and clean finished") - - -def parse_input_parameters(mlmd: MLMD, cpn: _Component, input_parameters: Dict[str, Any]) -> dict: - execute_parameters = {} - name_parameter_mapping = {parameter.name: parameter for parameter in cpn.parameters} - for arg in cpn.func_args[2:]: - if parameter := name_parameter_mapping.get(arg): - parameter_apply = input_parameters.get(arg) - if parameter_apply is None: - if not parameter.optional: - raise ComponentApplyError(f"parameter `{arg}` required, declare: `{parameter}`") - else: - execute_parameters[parameter.name] = parameter.default - mlmd.io.log_input_parameter(parameter.name, parameter.default) - else: - try: - value = params.parse(parameter.type, parameter_apply) - except Exception as e: - raise ComponentApplyError(f"apply value `{parameter_apply}` to parameter `{arg}` failed:\n{e}") - execute_parameters[parameter.name] = value - mlmd.io.log_input_parameter(parameter.name, parameter_apply) - return execute_parameters - - -def parse_input_data(cpn: _Component, stage, role, input_artifacts) -> dict: - - execute_input_data = {} - for arg in cpn.func_args[2:]: - if arti := cpn.artifacts.inputs.data_artifact.get(arg): - execute_input_data[arg] = None - if arti.is_active_for(stage, role): - artifact_apply = input_artifacts.get(arg) - if artifact_apply is not None: - # try apply - try: - execute_input_data[arg] = load_artifact(artifact_apply, arti.type) - except Exception as e: - raise ComponentApplyError( - f"artifact `{arg}` with applying config `{artifact_apply}` can't apply to `{arti}`" - ) from e - continue - else: - if not arti.optional: - raise ComponentApplyError(f"artifact `{arg}` required, declare: `{arti}`") - return execute_input_data - - -def parse_input_model(cpn: _Component, stage, role, input_artifacts) -> dict: - - execute_input_model = {} - for arg in cpn.func_args[2:]: - if arti := cpn.artifacts.inputs.model_artifact.get(arg): - execute_input_model[arg] = None - if arti.is_active_for(stage, role): - artifact_apply = input_artifacts.get(arg) - if artifact_apply is not None: - # try apply - try: - execute_input_model[arg] = load_artifact(artifact_apply, arti.type) - except Exception as e: - raise ComponentApplyError( - f"artifact `{arg}` with applying config `{artifact_apply}` can't apply to `{arti}`" - ) from e - continue - else: - if not arti.optional: - raise ComponentApplyError(f"artifact `{arg}` required, declare: `{arti}`") - return execute_input_model - - -def parse_input_metric(cpn: _Component, stage, role, input_artifacts) -> dict: - - execute_input_metric = {} - for arg in cpn.func_args[2:]: - if arti := cpn.artifacts.inputs.metric_artifact.get(arg): - execute_input_metric[arg] = None - if arti.is_active_for(stage, role): - artifact_apply = input_artifacts.get(arg) - if artifact_apply is not None: - # try apply - try: - execute_input_metric[arg] = load_artifact(artifact_apply, arti.type) - except Exception as e: - raise ComponentApplyError( - f"artifact `{arg}` with applying config `{artifact_apply}` can't apply to `{arti}`" - ) from e - continue - else: - if not arti.optional: - raise ComponentApplyError(f"artifact `{arg}` required, declare: `{arti}`") - return execute_input_metric - - -def parse_output_data(cpn: _Component, stage, role, output_pool: OutputPool) -> dict: - - execute_output_data = {} - for arg in cpn.func_args[2:]: - if 
arti := cpn.artifacts.outputs.data_artifact.get(arg): - execute_output_data[arg] = None - if arti.is_active_for(stage, role): - execute_output_data[arg] = output_pool.create_data_artifact(arti.name) - return execute_output_data - - -def parse_output_model(cpn: _Component, stage, role, output_pool: OutputPool) -> dict: - - execute_output_model = {} - for arg in cpn.func_args[2:]: - if arti := cpn.artifacts.outputs.model_artifact.get(arg): - execute_output_model[arg] = None - if arti.is_active_for(stage, role): - execute_output_model[arg] = output_pool.create_model_artifact(arti.name) - return execute_output_model - - -def parse_output_metric(cpn: _Component, stage, role, output_pool: OutputPool) -> dict: - - execute_output_metrics = {} - for arg in cpn.func_args[2:]: - if arti := cpn.artifacts.outputs.metric_artifact.get(arg): - execute_output_metrics[arg] = None - if arti.is_active_for(stage, role): - execute_output_metrics[arg] = output_pool.create_metric_artifact(arti.name) - return execute_output_metrics diff --git a/python/fate/components/entrypoint/component_cli.py b/python/fate/components/entrypoint/component_cli.py deleted file mode 100644 index 8eed50aaf3..0000000000 --- a/python/fate/components/entrypoint/component_cli.py +++ /dev/null @@ -1,195 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
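The per-command entrypoint modules added above replace the single `component` click group removed below; a minimal sketch of driving the new execute command in-process with click's test runner (the config file name here is illustrative, not taken from this patch):

from click.testing import CliRunner

from fate.components.entrypoint.cli.component.execute_cli import execute

# invoke the new execute command without spawning a separate process;
# "task_conf.yaml" is a placeholder for an actual task config file
result = CliRunner().invoke(execute, ["--config", "task_conf.yaml", "--debug"])
print(result.exit_code, result.output)
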
-import click - - -@click.group() -def component(): - """ - Manipulate components: execute, list, generate describe file - """ - - -@component.command() -@click.option("--process-tag", required=True, help="unique id to identify this execution process") -@click.option("--config", required=False, type=click.File(), help="config path") -@click.option("--config-entrypoint", required=False, help="enctypoint to get config") -@click.option("--properties", "-p", multiple=True, help="properties config") -@click.option("--env-prefix", "-e", type=str, default="runtime.component.", help="prefix for env config") -@click.option("--env-name", required=False, type=str, help="env name for config") -def execute(process_tag, config, config_entrypoint, properties, env_prefix, env_name): - "execute component" - import logging - - from fate.components.spec.task import TaskConfigSpec - - # parse properties - properties_items = {} - properties_items.update(load_properties(properties)) - properties_items.update(load_properties_from_env(env_prefix)) - - # parse config - configs = {} - load_config_from_env(configs, env_name) - load_config_from_entrypoint(configs, config_entrypoint) - load_config_from_file(configs, config) - load_config_from_properties(configs, properties_items) - - task_config = TaskConfigSpec.parse_obj(configs) - - # install logger - task_config.conf.logger.install() - logger = logging.getLogger(__name__) - logger.debug("logger installed") - logger.debug(f"task config: {task_config}") - - from fate.components.entrypoint.component import execute_component - - execute_component(task_config) - - -def load_properties(properties) -> dict: - properties_dict = {} - for property_item in properties: - k, v = property_item.split("=") - k = k.strip() - v = v.strip() - properties_dict[k] = v - return properties_dict - - -def load_properties_from_env(env_filter_prefix): - import os - - properties_dict = {} - if env_filter_prefix: - env_prefix_size = len(env_filter_prefix) - for k, v in os.environ.items(): - if k.startswith(env_filter_prefix): - property_key = k[env_prefix_size:] - if property_key: - properties_dict[property_key] = v - return properties_dict - - -def load_config_from_properties(configs, properties_dict): - for k, v in properties_dict.items(): - lens_and_setter = configs, None - - def _setter(d, k): - def _set(v): - d[k] = v - - return _set - - for s in k.split("."): - lens, _ = lens_and_setter - if not s.endswith("]"): - print("in", lens) - if lens.get(s) is None: - lens[s] = {} - lens_and_setter = lens[s], _setter(lens, s) - else: - name, index = s.rstrip("]").split("[") - index = int(index) - if lens.get(name) is None: - lens[name] = [] - lens = lens[name] - if (short_size := index + 1 - len(lens)) > 0: - lens.extend([None] * short_size) - lens[index] = {} - lens_and_setter = lens[index], _setter(lens, index) - _, setter = lens_and_setter - if setter is not None: - setter(v) - - -def load_config_from_file(configs, config_file): - from ruamel import yaml - - if config_file is not None: - configs.update(yaml.safe_load(config_file)) - return configs - - -def load_config_from_entrypoint(configs, config_entrypoint): - import requests - - if config_entrypoint is not None: - try: - resp = requests.get(config_entrypoint).json() - configs.update(resp["config"]) - except: - pass - return configs - - -def load_config_from_env(configs, env_name): - import os - from ruamel import yaml - - if env_name is not None and os.environ.get(env_name): - configs.update(yaml.safe_load(os.environ[env_name])) - return 
configs - - -@component.command() -@click.option("--name", required=True, help="name of component") -@click.option("--save", type=click.File(mode="w", lazy=True), help="save desc output to specified file in yaml format") -def desc(name, save): - "generate component describe config" - from fate.components.loader.component import load_component - - cpn = load_component(name) - if save: - cpn.dump_yaml(save) - else: - print(cpn.dump_yaml()) - - -@component.command() -@click.option("--save", type=click.File(mode="w", lazy=True), help="save desc output to specified file in yaml format") -def task_schema(save): - "generate component task config json schema" - from fate.components.spec.task import TaskConfigSpec - - if save: - save.write(TaskConfigSpec.schema_json()) - else: - print(TaskConfigSpec.schema_json()) - - -@component.command() -@click.option("--save", type=click.File(mode="w", lazy=True), help="save list output to specified file in json format") -def list(save): - "list all components" - from fate.components.loader.component import list_components - - if save: - import json - - json.dump(list_components(), save) - else: - print(list_components()) - - -@component.command() -@click.option("--db", required=True, type=str, help="mlmd db") -@click.option("--taskid", required=True, type=str, help="taskid") -def set_mlmd_finish(db, taskid): - from fate.arch.context._mlmd import MachineLearningMetadata - - mlmd = MachineLearningMetadata(metadata={"filename_uri": db}) - mlmd.set_task_safe_terminate_flag(taskid) diff --git a/python/fate/components/entrypoint/utils.py b/python/fate/components/entrypoint/utils.py new file mode 100644 index 0000000000..9a926dd329 --- /dev/null +++ b/python/fate/components/entrypoint/utils.py @@ -0,0 +1,83 @@ +def load_properties(properties) -> dict: + properties_dict = {} + for property_item in properties: + k, v = property_item.split("=") + k = k.strip() + v = v.strip() + properties_dict[k] = v + return properties_dict + + +def load_properties_from_env(env_filter_prefix): + import os + + properties_dict = {} + if env_filter_prefix: + env_prefix_size = len(env_filter_prefix) + for k, v in os.environ.items(): + if k.startswith(env_filter_prefix): + property_key = k[env_prefix_size:] + if property_key: + properties_dict[property_key] = v + return properties_dict + + +def load_config_from_properties(configs, properties_dict): + for k, v in properties_dict.items(): + lens_and_setter = configs, None + + def _setter(d, k): + def _set(v): + d[k] = v + + return _set + + for s in k.split("."): + lens, _ = lens_and_setter + if not s.endswith("]"): + if lens.get(s) is None: + lens[s] = {} + lens_and_setter = lens[s], _setter(lens, s) + else: + name, index = s.rstrip("]").split("[") + index = int(index) + if lens.get(name) is None: + lens[name] = [] + lens = lens[name] + if (short_size := index + 1 - len(lens)) > 0: + lens.extend([None] * short_size) + lens[index] = {} + lens_and_setter = lens[index], _setter(lens, index) + _, setter = lens_and_setter + if setter is not None: + setter(v) + + +def load_config_from_file(configs, config_file): + from ruamel import yaml + + if config_file is not None: + configs.update(yaml.safe_load(config_file)) + return configs + + +def load_config_from_entrypoint(configs, config_entrypoint): + import requests + + if config_entrypoint is not None: + try: + resp = requests.get(config_entrypoint).json() + configs.update(resp["config"]) + except: + pass + return configs + + +def load_config_from_env(configs, env_name): + import os + + from ruamel 
import yaml + + if env_name is not None and os.environ.get(env_name): + configs.update(yaml.safe_load(os.environ[env_name])) + return configs diff --git a/python/fate/components/loader/artifact.py b/python/fate/components/loader/artifact.py deleted file mode 100644 index c96504ecc1..0000000000 --- a/python/fate/components/loader/artifact.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -def load_artifact(data, artifact_type): - from fate.components import ( - Artifact, - Artifacts, - DatasetArtifact, - DatasetArtifacts, - MetricArtifact, - ModelArtifact, - ModelArtifacts, - ) - - if hasattr(artifact_type, "__origin__"): - artifact_type = artifact_type.__origin__ - if isinstance(data, list): - if artifact_type.__origin__ == DatasetArtifacts: - return DatasetArtifacts([DatasetArtifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) - if artifact_type == ModelArtifacts: - return ModelArtifacts([ModelArtifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) - return Artifacts([Artifact(name=d.name, uri=d.uri, metadata=d.metadata) for d in data]) - else: - if artifact_type == DatasetArtifact: - return DatasetArtifact(name=data.name, uri=data.uri, metadata=data.metadata) - if artifact_type == ModelArtifact: - return ModelArtifact(name=data.name, uri=data.uri, metadata=data.metadata) - if artifact_type == MetricArtifact: - return MetricArtifact(name=data.name, uri=data.uri, metadata=data.metadata) - return Artifact(name=data.name, uri=data.uri, metadata=data.metadata) diff --git a/python/fate/components/loader/component.py b/python/fate/components/loader/component.py deleted file mode 100644 index 0ec43bc606..0000000000 --- a/python/fate/components/loader/component.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
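For reference, the dotted and indexed property keys handled by `load_config_from_properties` in the new entrypoint/utils.py above expand into a nested config roughly as follows (a minimal sketch; the keys and values are illustrative, and values stay strings):

from fate.components.entrypoint.utils import load_config_from_properties, load_properties

# "a.b=1" style pairs become a flat dict first, then dotted / indexed keys
# are expanded into nested dicts and lists
properties = load_properties(["conf.device.type=CPU", "parties.local[0].partyid=9999"])
configs = {}
load_config_from_properties(configs, properties)
# configs == {"conf": {"device": {"type": "CPU"}},
#             "parties": {"local": [{"partyid": "9999"}]}}
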
-import logging - -logger = logging.getLogger(__name__) - - -def load_component(cpn_name: str): - from fate.components.components import BUILDIN_COMPONENTS - from fate.components.cpn import _Component - - # from buildin - for cpn in BUILDIN_COMPONENTS: - if cpn.name == cpn_name: - return cpn - - # from entrypoint - import pkg_resources - - for cpn_ep in pkg_resources.iter_entry_points(group="fate.ext.component"): - try: - candidate_cpn: _Component = cpn_ep.load() - candidate_cpn_name = candidate_cpn.name - except Exception as e: - logger.warning( - f"register cpn from entrypoint(named={cpn_ep.name}, module={cpn_ep.module_name}) failed: {e}" - ) - continue - if candidate_cpn_name == cpn_name: - return candidate_cpn - raise RuntimeError(f"could not find registerd cpn named `{cpn_name}`") - - -def list_components(): - import pkg_resources - from fate.components.components import BUILDIN_COMPONENTS - - buildin_components = [c.name for c in BUILDIN_COMPONENTS] - third_parties_components = [] - - for cpn_ep in pkg_resources.iter_entry_points(group="fate.ext.component"): - try: - candidate_cpn = cpn_ep.load() - candidate_cpn_name = candidate_cpn.name - third_parties_components.append([candidate_cpn_name]) - except Exception as e: - logger.warning( - f"register cpn from entrypoint(named={cpn_ep.name}, module={cpn_ep.module_name}) failed: {e}" - ) - continue - return dict(buildin=buildin_components, thirdparty=third_parties_components) diff --git a/python/fate/components/loader/metric.py b/python/fate/components/loader/metric.py deleted file mode 100644 index 067b03ca02..0000000000 --- a/python/fate/components/loader/metric.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Union - -from fate.interface import InCompleteMetrics, Metrics, MetricsHandler - - -def load_metrics_handler(): - return ComponentMetricsHandler() - - -class ComponentMetricsHandler(MetricsHandler): - """ - this implement use ctx.writer(artifact).write_metric() as metric output sink - """ - - def __init__(self) -> None: - self._metric_handlers = {} - - def register_metrics(self, **kwargs): - for name, handler in kwargs.items(): - self._metric_handlers[name] = handler - - def log_metrics(self, metrics: Union[Metrics, InCompleteMetrics]): - if metrics.name not in self._metric_handlers: - raise ValueError(f"metric named `{metrics.name}` not registered") - handler = self._metric_handlers[metrics.name] - handler.write_metric(metrics) diff --git a/python/fate/components/loader/mlmd/__init__.py b/python/fate/components/loader/mlmd/__init__.py deleted file mode 100644 index 73e0f0abe5..0000000000 --- a/python/fate/components/loader/mlmd/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from fate.components.spec.mlmd import ( - CustomMLMDSpec, - FlowMLMDSpec, - NoopMLMDSpec, - PipelineMLMDSpec, -) - -from .protocol import MLMD - -logger = logging.getLogger(__name__) - - -def load_mlmd(mlmd, taskid) -> MLMD: - # from buildin - if isinstance(mlmd, PipelineMLMDSpec): - from .pipeline import PipelineMLMD - - return PipelineMLMD(mlmd, taskid) - - if isinstance(mlmd, FlowMLMDSpec): - from .flow import FlowMLMD - - return FlowMLMD(mlmd, taskid) - - if isinstance(mlmd, NoopMLMDSpec): - from .noop import NoopMLMD - - return NoopMLMD(mlmd, taskid) - # from entrypoint - if isinstance(mlmd, CustomMLMDSpec): - import pkg_resources - - for mlmd_ep in pkg_resources.iter_entry_points(group="fate.ext.mlmd"): - try: - mlmd_register = mlmd_ep.load() - mlmd_registered_name = mlmd_register.registered_name() - except Exception as e: - logger.warning( - f"register cpn from entrypoint(named={mlmd_ep.name}, module={mlmd_ep.module_name}) failed: {e}" - ) - continue - if mlmd_registered_name == mlmd.name: - return mlmd_register - raise RuntimeError(f"could not find registerd mlmd named `{mlmd.name}`") - - raise ValueError(f"unknown mlmd spec: `{mlmd}`") diff --git a/python/fate/components/loader/mlmd/flow.py b/python/fate/components/loader/mlmd/flow.py deleted file mode 100644 index 7408d6802e..0000000000 --- a/python/fate/components/loader/mlmd/flow.py +++ /dev/null @@ -1,162 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from typing import Optional - -from fate.components.loader.mlmd import MLMD -from fate.components.loader.mlmd.protocol import IOManagerProtocol -from fate.components.spec.mlmd import FlowMLMDSpec -from pydantic import BaseModel - - -class ExecutionStatus: - class StateData(BaseModel): - execution_id: str - status: Optional[str] - error: Optional[str] - - class Status: - WAITING = 'waiting' - READY = 'ready' - RUNNING = "running" - CANCELED = "canceled" - TIMEOUT = "timeout" - FAILED = "failed" - PASS = "pass" - SUCCESS = "success" - - class EndStatus: - CANCELED = "canceled" - TIMEOUT = "timeout" - FAILED = "failed" - PASS = "pass" - SUCCESS = "success" - - @classmethod - def status_list(cls): - return [cls.__dict__[k] for k in cls.__dict__.keys() if - not callable(getattr(cls, k)) and not k.startswith("__")] - - def __init__(self, mlmd: FlowMLMDSpec, taskid) -> None: - self._mlmd = mlmd - self._taskid = taskid - - def log_excution_start(self): - return self._log_state(self.Status.RUNNING) - - def log_excution_end(self): - return self._log_state(self.Status.SUCCESS) - - def log_excution_exception(self, message: dict): - return self._log_state(self.Status.FAILED, message) - - def _log_state(self, state, message=None): - error = "" - if message: - error = message.get("exception") - import requests - - logging.info(self._mlmd.metadata.statu_uri) - data = self.StateData(execution_id=self._taskid, status=state, error=error).dict() - logging.debug(f"request flow uri: {self._mlmd.metadata.statu_uri}") - response = requests.post(self._mlmd.metadata.statu_uri, json=data) - logging.debug(f"response: {response.text}") - - def _get_state(self): - import requests - logging.info(self._mlmd.metadata.statu_uri) - data = self.StateData(execution_id=self._taskid).dict() - logging.debug(f"wzh test request flow uri: {self._mlmd.metadata.statu_uri}") - response = requests.get(self._mlmd.metadata.statu_uri, params=data) - logging.debug(f"response: {response.text}") - status = False - try: - task_status = response.json().get("data").get("status") - if task_status in ExecutionStatus.EndStatus.status_list(): - status = True - except Exception as e: - logging.exception(e) - status = True - return status - - def safe_terminate(self): - return self._get_state() - - -class IOManager(IOManagerProtocol): - def __init__(self, mlmd, task_id): - self.mlmd = mlmd - self.task_id = task_id - - def log_output_artifact(self, key, value): - if value is None: - return - from fate.components import DatasetArtifact, MetricArtifact, ModelArtifact - - if isinstance(value, DatasetArtifact): - self.log_output_data(key, value) - elif isinstance(value, ModelArtifact): - self.log_output_model(key, value) - elif isinstance(value, MetricArtifact): - self.log_output_metric(key, value) - else: - raise RuntimeError(f"not supported input artifact `name={key}, value={value}`") - - def log_output_data(self, key, value): - import requests - - logging.debug(f"request flow uri: {self.mlmd.metadata.tracking_uri}") - response = requests.post( - self.mlmd.metadata.tracking_uri, - json={ - "output_key": value.name, - "meta_data": value.metadata, - "execution_id": self.task_id, - "uri": value.uri, - "type": "data", - }, - ) - logging.debug(f"response: {response.text}") - - def log_output_model(self, key, value, metadata={}): - import requests - - data = { - "output_key": value.name, - "meta_data": value.metadata, - "execution_id": self.task_id, - "uri": value.uri, - "type": "model", - } - logging.debug(f"request flow uri: 
{self.mlmd.metadata.tracking_uri}, data: {data}") - response = requests.post( - self.mlmd.metadata.tracking_uri, - json=data, - ) - logging.debug(response.text) - logging.debug(value) - - def log_output_metric(self, key, value): - logging.debug(value) - - def safe_terminate(self): - pass - - -class FlowMLMD(MLMD): - def __init__(self, mlmd: FlowMLMDSpec, taskid) -> None: - self._taskid = taskid - self.execution_status = ExecutionStatus(mlmd, taskid) - self.io = IOManager(mlmd=mlmd, task_id=taskid) diff --git a/python/fate/components/loader/mlmd/noop.py b/python/fate/components/loader/mlmd/noop.py deleted file mode 100644 index c80dcf8b58..0000000000 --- a/python/fate/components/loader/mlmd/noop.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate.components.spec.mlmd import NoopMLMDSpec - -from .protocol import MLMD -from .protocol import ExecutionStatus as ExecutionStatusProtocol -from .protocol import IOManagerProtocol - - -class NoopMLMD(MLMD): - def __init__(self, mlmd: NoopMLMDSpec, taskid) -> None: - self._taskid = taskid - self.execution_status = ExecutionStatus() - self.io = IOManager() - - -class IOManager(IOManagerProtocol): - def __init__(self) -> None: - ... - - def log_input_artifact(self, key, value): - print(f"log input artifact: {key}, {value}") - - def log_output_artifact(self, key, value): - print(f"log output artifact: {key}, {value}") - - def log_input_parameter(self, key, value): - print(f"log input parameter: {key}, {value}") - - def log_input_data(self, key, value): - print(f"log input data: {key}, {value}") - - def log_input_model(self, key, value): - print(f"log input model: {key}, {value}") - - def log_input_metric(self, key, value): - print(f"log input metric: {key}, {value}") - - def log_output_data(self, key, value): - print(f"log output data: {key}, {value}") - - def log_output_model(self, key, value, metadata={}): - print(f"log output model: {key}, {value}, {metadata}") - - def log_output_metric(self, key, value): - print(f"log output metric: {key}, {value}") - - -class ExecutionStatus(ExecutionStatusProtocol): - def __init__(self) -> None: - ... - - def log_excution_start(self): - print(f"running") - - def log_excution_end(self): - print(f"end") - - def log_excution_exception(self, message: dict): - print(f"exception: {message}") - - def safe_terminate(self): - return True diff --git a/python/fate/components/loader/mlmd/pipeline.py b/python/fate/components/loader/mlmd/pipeline.py deleted file mode 100644 index 982e2aee58..0000000000 --- a/python/fate/components/loader/mlmd/pipeline.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from fate.arch.context._mlmd import MachineLearningMetadata -from fate.components.spec.mlmd import PipelineMLMDSpec - -from .protocol import MLMD -from .protocol import ExecutionStatus as ExecutionStatusProtocol -from .protocol import IOManagerProtocol - - -class PipelineMLMD(MLMD): - def __init__(self, mlmd: PipelineMLMDSpec, taskid) -> None: - self._mlmd = MachineLearningMetadata(metadata=dict(filename_uri=mlmd.metadata.db)) - self._taskid = taskid - self.execution_status = ExecutionStatus(self._mlmd, self._taskid) - self.io = IOManager(self._mlmd, self._taskid) - - -class IOManager(IOManagerProtocol): - def __init__(self, mlmd: MachineLearningMetadata, taskid) -> None: - self._mlmd = mlmd - self._taskid = taskid - - def log_input_artifact(self, key, value): - if value is None: - return - from fate.components import DatasetArtifact, MetricArtifact, ModelArtifact - - if isinstance(value, DatasetArtifact): - self.log_input_data(key, value) - elif isinstance(value, ModelArtifact): - self.log_input_model(key, value) - elif isinstance(value, MetricArtifact): - self.log_input_metric(key, value) - else: - raise RuntimeError(f"not supported input artifact `name={key}, value={value}`") - - def log_output_artifact(self, key, value): - if value is None: - return - from fate.components import DatasetArtifact, MetricArtifact, ModelArtifact - - if isinstance(value, DatasetArtifact): - self.log_output_data(key, value) - elif isinstance(value, ModelArtifact): - self.log_output_model(key, value) - elif isinstance(value, MetricArtifact): - self.log_output_metric(key, value) - else: - raise RuntimeError(f"not supported input artifact `name={key}, value={value}`") - - def log_input_parameter(self, key, value): - artifact_id = self._mlmd.add_parameter(name=key, value=value) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_input_data(self, key, value): - artifact_id = self._mlmd.add_data_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=True - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_input_model(self, key, value): - artifact_id = self._mlmd.add_model_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=True - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_input_metric(self, key, value): - artifact_id = self._mlmd.add_metric_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=True - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_input_event(execution_id=execution_id, artifact_id=artifact_id) - 
self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_output_data(self, key, value): - artifact_id = self._mlmd.add_data_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=False - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_output_model(self, key, value, metadata={}): - artifact_id = self._mlmd.add_model_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=False - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - def log_output_metric(self, key, value): - artifact_id = self._mlmd.add_metric_artifact( - name=value.name, uri=value.uri, metadata=value.metadata, is_input=False - ) - execution_id = self._mlmd.get_or_create_task(self._taskid).id - self._mlmd.record_output_event(execution_id=execution_id, artifact_id=artifact_id) - self._mlmd.put_artifact_to_task_context(self._taskid, artifact_id) - - -class ExecutionStatus(ExecutionStatusProtocol): - def __init__(self, mlmd: MachineLearningMetadata, taskid) -> None: - self._mlmd = mlmd - self._taskid = taskid - - def log_excution_start(self): - return self._log_state("running") - - def log_excution_end(self): - return self._log_state("finish") - - def log_excution_exception(self, message: dict): - import json - - self._log_state("exception", json.dumps(message)) - - def _log_state(self, state, message=None): - self._mlmd.update_task_state(self._taskid, state, message) - - def safe_terminate(self): - return self._mlmd.get_task_safe_terminate_flag(self._taskid) diff --git a/python/fate/components/loader/mlmd/protocol.py b/python/fate/components/loader/mlmd/protocol.py deleted file mode 100644 index 13f90cf549..0000000000 --- a/python/fate/components/loader/mlmd/protocol.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Protocol - - -class ExecutionStatus(Protocol): - def log_excution_start(self): - ... - - def log_excution_end(self): - ... - - def log_excution_exception(self, message: dict): - ... - - def safe_terminate(self): - ... - - -class IOManagerProtocol: - def log_input_parameter(self, key, value): - ... - - def log_input_artifact(self, key, value): - ... - - def log_output_artifact(self, key, value): - ... - - def log_output_data(self, key, value): - ... - - def log_output_model(self, key, value, metadata={}): - ... - - def log_output_metric(self, key, value): - ... 
- - -class MLMD(Protocol): - execution_status: ExecutionStatus - io: IOManagerProtocol diff --git a/python/fate/components/loader/model.py b/python/fate/components/loader/model.py deleted file mode 100644 index 2b398451b2..0000000000 --- a/python/fate/components/loader/model.py +++ /dev/null @@ -1,292 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import logging -import tarfile -import tempfile -from datetime import datetime - -from ruamel import yaml - -from fate.components.spec.model import ( - MLModelComponentSpec, - MLModelFederatedSpec, - MLModelModelSpec, - MLModelPartiesSpec, - MLModelPartySpec, - MLModelSpec, -) - - -def load_output_model_wrapper(task_id, party_task_id, cpn, role, partyid, federation): - return ComponentModelWriterWrapper(cpn, federation, task_id, party_task_id, role, partyid) - - -def load_input_model_wrapper(): - return ComponentModelLoaderWrapper() - - -_MODEL_META_NAME = "FMLModel.yaml" - - -class ComponentModelWriterWrapper: - def __init__(self, cpn, federation, task_id, party_task_id, role, party_id) -> None: - self.task_id = task_id - self.party_task_id = party_task_id - self.role = role - self.party_id = party_id - self.cpn_spec = MLModelComponentSpec(name=cpn.name, provider=cpn.provider, version=cpn.version, metadata={}) - guest = [] - host = [] - arbiter = [] - for party in federation.metadata.parties.parties: - if party.role == "guest": - guest.append(party.partyid) - if party.role == "host": - host.append(party.partyid) - if party.role == "arbiter": - arbiter.append(party.partyid) - self.parties_spec = MLModelPartiesSpec(guest=guest, host=host, arbiter=arbiter) - - def wrap(self, artifact, io_mlmd): - return ComponentModelWriter(self, artifact, io_mlmd) - - -class ComponentModelLoaderWrapper: - def wrap(self, artifact, io_mlmd): - return ComponentModelLoader(artifact, io_mlmd) - - -class ModelTarWriteHandler: - def __init__(self, tar) -> None: - self.tar = tar - - def add_model(self, name, model, file_format): - if file_format == "json": - self.add_json_model(name, model) - else: - raise NotImplementedError(f"file_format={file_format} not support") - - def add_json_model(self, name, model): - with tempfile.NamedTemporaryFile("w") as f: - json.dump(model, f) - f.flush() - self.tar.add(f.name, name) - - def add_meta(self, meta): - with tempfile.NamedTemporaryFile("w") as f: - yaml.safe_dump(meta, f) - f.flush() - self.tar.add(f.name, _MODEL_META_NAME) - - -class FileModelTarWriteHandler(ModelTarWriteHandler): - def __init__(self, uri) -> None: - super().__init__(tarfile.open(uri.path, "w")) - - def close(self): - self.tar.close() - - def mlmd_send(self, mlmd, artifact, metadata): - mlmd.log_output_model(artifact.name, artifact, metadata=metadata) - - -class HttpModelTarWriteTarHandler(ModelTarWriteHandler): - def __init__(self, uri) -> None: - self.uri = uri - import io - - self.memory_file = io.BytesIO() - super().__init__(tarfile.open(fileobj=self.memory_file, mode="w")) - - 
def close(self): - self.tar.close() - - def mlmd_send(self, mlmd, artifact, metadata): - import requests - - logging.info(f"mlmd send uri: {self.uri.to_string()}") - self.memory_file.seek(0) - response = requests.post(url=self.uri.to_string(), files={"file": self.memory_file}) - logging.info(f"response: {response.text}") - mlmd.log_output_model(artifact.name, artifact, metadata=metadata) - - -class ComponentModelWriter: - def __init__(self, info: ComponentModelWriterWrapper, artifact, mlmd) -> None: - self.info = info - self.models = [] - - from fate.arch.unify import URI - - self.artifact = artifact - self.uri = URI.from_string(artifact.uri).to_schema() - self.mlmd = mlmd - - self._tar = None - - def __enter__(self): - from fate.arch.unify import FileURI, HttpsURI, HttpURI - - if isinstance(self.uri, FileURI): - self._tar = FileModelTarWriteHandler(self.uri) - elif isinstance(self.uri, (HttpURI, HttpsURI)): - self._tar = HttpModelTarWriteTarHandler(self.uri) - else: - raise NotImplementedError(f"model writer not support uri: {self.uri}") - return self - - def __exit__(self, type, value, trace): - self._write_meta() - self._get_tar().mlmd_send(self.mlmd, self.artifact, self._get_meta().dict()) - self._get_tar().close() - - def _get_tar(self): - if self._tar is None: - raise ValueError(f"should open first") - return self._tar - - def _get_meta(self): - return MLModelSpec( - federated=MLModelFederatedSpec( - task_id=self.info.task_id, parties=self.info.parties_spec, component=self.info.cpn_spec - ), - party=MLModelPartySpec( - party_task_id=self.info.party_task_id, - role=self.info.role, - partyid=self.info.party_id, - models=self.models, - ), - ) - - def _write_meta(self): - self._get_tar().add_meta(self._get_meta().dict()) - - def write_model(self, name, model, metadata, file_format="json", created_time=None): - if created_time is None: - created_time = datetime.now() - self._get_tar().add_model(name, model, file_format=file_format) - self.models.append( - MLModelModelSpec(name=name, created_time=created_time, file_format=file_format, metadata=metadata) - ) - - -class ModelTarReadHandler: - def __init__(self, tar) -> None: - self.tar = tar - self.meta = None - - def add_model(self, name, model): - with tempfile.NamedTemporaryFile("w") as f: - json.dump(model, f) - f.flush() - self.tar.add(f.name, name) - - def add_meta(self, meta): - with tempfile.NamedTemporaryFile("w") as f: - yaml.safe_dump(meta, f) - f.flush() - self.tar.add(f.name, _MODEL_META_NAME) - - def get_meta(self): - if self.meta is None: - with tempfile.TemporaryDirectory() as d: - path = f"{d}/{_MODEL_META_NAME}" - self.tar.extract(_MODEL_META_NAME, d) - with open(path, "r") as f: - meta = yaml.safe_load(f) - - self.meta = MLModelSpec.parse_obj(meta) - return self.meta - - def read_model(self, **kwargs): - # return first for now, TODO: extend this - model_info = self.get_meta().party.models[0] - model_name = model_info.name - file_format = model_info.file_format - if file_format == "json": - return self.read_json_model(model_name) - else: - raise NotImplementedError(f"file_format={file_format} not supported") - - def read_json_model(self, model_name): - with tempfile.TemporaryDirectory() as d: - path = f"{d}/{model_name}" - self.tar.extract(model_name, d) - with open(path, "r") as f: - return json.load(f) - - -class FileModelTarReadHandler(ModelTarReadHandler): - def __init__(self, uri) -> None: - super().__init__(tarfile.open(uri.path, "r")) - - def close(self): - self.tar.close() - - -class 
HttpModelTarReadTarHandler(ModelTarReadHandler): - def __init__(self, uri) -> None: - import io - from contextlib import closing - - import requests - - memory_file = io.BytesIO() - logging.debug(f"read model from: {uri.to_string()}") - with closing(requests.get(url=uri.to_string(), stream=True)) as response: - for chunk in response.iter_content(1024): - if chunk: - memory_file.write(chunk) - memory_file.seek(0) - tar = tarfile.open(fileobj=memory_file, mode="r") - logging.debug(f"read model success") - super().__init__(tar) - - def close(self): - self.tar.close() - - -class ComponentModelLoader: - def __init__(self, artifact, mlmd) -> None: - self.artifact = artifact - from fate.arch.unify import URI - - self.uri = URI.from_string(artifact.uri).to_schema() - self.mlmd = mlmd - self._tar = None - self._meta = None - - def __enter__(self): - from fate.arch.unify import FileURI, HttpsURI, HttpURI - - if isinstance(self.uri, FileURI): - self._tar = FileModelTarReadHandler(self.uri) - elif isinstance(self.uri, (HttpURI, HttpsURI)): - self._tar = HttpModelTarReadTarHandler(self.uri) - else: - raise NotImplementedError(f"model writer not support uri: {self.uri}") - return self - - def __exit__(self, type, value, trace): - self._get_tar().close() - - def _get_tar(self): - if self._tar is None: - raise ValueError(f"should open first") - return self._tar - - def read_model(self, **kwargs): - return self._get_tar().read_model(**kwargs) diff --git a/python/fate/components/loader/output.py b/python/fate/components/loader/output.py deleted file mode 100644 index 4c5c99660d..0000000000 --- a/python/fate/components/loader/output.py +++ /dev/null @@ -1,122 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import uuid - - -class OutputPool: - def __init__(self, data, model, metric) -> None: - self.data = data - self.model = model - self.metric = metric - - def create_data_artifact(self, name: str): - return self.data.create_artifact(name) - - def create_model_artifact(self, name: str): - return self.model.create_artifact(name) - - def create_metric_artifact(self, name: str): - return self.metric.create_artifact(name) - - -def load_pool(output_pool_conf): - data = _load_data_pool(output_pool_conf.data) - model = _load_model_pool(output_pool_conf.model) - metric = _load_metric_pool(output_pool_conf.metric) - return OutputPool(data, model, metric) - - -def _load_data_pool(data_pool): - from fate.arch.unify import URI - from fate.components.spec.output import DirectoryDataPool - - if isinstance(data_pool, DirectoryDataPool): - return DataPool( - base_uri=URI.from_string(data_pool.metadata.uri).to_schema(), - format=data_pool.metadata.format, - name_template=data_pool.metadata.name_template, - ) - raise RuntimeError(f"load data pool failed: {data_pool}") - - -def _load_model_pool(model_pool): - from fate.arch.unify import URI - from fate.components.spec.output import DirectoryModelPool - - if isinstance(model_pool, DirectoryModelPool): - return ModelPool( - base_uri=URI.from_string(model_pool.metadata.uri).to_schema(), - format=model_pool.metadata.format, - name_template=model_pool.metadata.name_template, - ) - raise RuntimeError(f"load data pool failed: {model_pool}") - - -def _load_metric_pool(metric_pool): - from fate.arch.unify import URI - from fate.components.spec.output import DirectoryMetricPool - - if isinstance(metric_pool, DirectoryMetricPool): - return MetricPool( - base_uri=URI.from_string(metric_pool.metadata.uri).to_schema(), - format=metric_pool.metadata.format, - name_template=metric_pool.metadata.name_template, - ) - raise RuntimeError(f"load data pool failed: {metric_pool}") - - -class DataPool: - def __init__(self, base_uri, format, name_template) -> None: - self.format = format - self.base_uri = base_uri - self.name_template = name_template - - def create_artifact(self, name): - from fate.components import DatasetArtifact - - file_name = self.name_template.format(name=name, uuid=uuid.uuid1()) - uri = self.base_uri.create_file(file_name) - metadata = dict(format=self.format) - return DatasetArtifact(name=name, uri=uri.to_string(), metadata=metadata) - - -class ModelPool: - def __init__(self, base_uri, format, name_template) -> None: - self.format = format - self.base_uri = base_uri - self.name_template = name_template - - def create_artifact(self, name): - from fate.components import ModelArtifact - - file_name = self.name_template.format(name=name, uuid=uuid.uuid1()) - uri = self.base_uri.create_file(file_name) - metadata = dict(format=self.format) - return ModelArtifact(name=name, uri=uri.to_string(), metadata=metadata) - - -class MetricPool: - def __init__(self, base_uri, format, name_template) -> None: - self.format = format - self.base_uri = base_uri - self.name_template = name_template - - def create_artifact(self, name): - from fate.components import MetricArtifact - - file_name = self.name_template.format(name=name, uuid=uuid.uuid1()) - uri = self.base_uri.create_file(file_name) - metadata = dict(format=self.format) - return MetricArtifact(name=name, uri=uri.to_string(), metadata=metadata) diff --git a/python/fate/components/params/__init__.py b/python/fate/components/params/__init__.py deleted file mode 100644 index 7e5bc9d0fb..0000000000 --- 
a/python/fate/components/params/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pydantic - - -class Parameter: - def parse(self, obj): - raise NotImplementedError() - - def dict(self): - raise NotImplementedError() - - -class ConInt(Parameter): - def __init__(self, gt: int = None, ge: int = None, lt: int = None, le: int = None) -> None: - self.gt = gt - self.ge = ge - self.lt = lt - self.le = le - - def parse(self, obj): - return pydantic.parse_obj_as(pydantic.conint(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj) - - def dict(self): - meta = {} - if self.gt is not None: - meta["gt"] = self.gt - if self.ge is not None: - meta["ge"] = self.ge - if self.lt is not None: - meta["lt"] = self.lt - if self.le is not None: - meta["le"] = self.le - return meta - - -class ConFloat(Parameter): - def __init__(self, gt: float = None, ge: float = None, lt: float = None, le: float = None) -> None: - self.gt = gt - self.ge = ge - self.lt = lt - self.le = le - - def parse(self, obj): - return pydantic.parse_obj_as(pydantic.confloat(gt=self.gt, ge=self.ge, lt=self.lt, le=self.le), obj) - - def dict(self): - meta = {} - if self.gt is not None: - meta["gt"] = self.gt - if self.ge is not None: - meta["ge"] = self.ge - if self.lt is not None: - meta["lt"] = self.lt - if self.le is not None: - meta["le"] = self.le - return meta - - -def parse(parameter_type, obj): - if isinstance(parameter_type, Parameter): - return parameter_type.parse(obj) - else: - return pydantic.parse_obj_as(parameter_type, obj) diff --git a/python/fate/components/spec/component.py b/python/fate/components/spec/component.py deleted file mode 100644 index ff8a8ef677..0000000000 --- a/python/fate/components/spec/component.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Dict, List, Optional - -from fate.components import T_LABEL, T_ROLE, T_STAGE -from pydantic import BaseModel - - -class ParameterSpec(BaseModel): - type: str - default: Any - optional: bool - description: str = "" - type_meta: dict = {} - - -class ArtifactSpec(BaseModel): - type: str - optional: bool - stages: Optional[List[T_STAGE]] - roles: List[T_ROLE] - description: str = "" - - -class InputDefinitionsSpec(BaseModel): - parameters: Dict[str, ParameterSpec] - artifacts: Dict[str, ArtifactSpec] - - -class OutputDefinitionsSpec(BaseModel): - artifacts: Dict[str, ArtifactSpec] - - -class ComponentSpec(BaseModel): - name: str - description: str - provider: str - version: str - labels: List[T_LABEL] - roles: List[T_ROLE] - input_definitions: InputDefinitionsSpec - output_definitions: OutputDefinitionsSpec - - -class ComponentSpecV1(BaseModel): - component: ComponentSpec - schema_version: str = "v1" diff --git a/python/fate/components/spec/logger.py b/python/fate/components/spec/logger.py deleted file mode 100644 index 7a45fb7e63..0000000000 --- a/python/fate/components/spec/logger.py +++ /dev/null @@ -1,221 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import logging.config -import pathlib -from typing import Literal - -import pydantic - - -class PipelineLogger(pydantic.BaseModel): - class PipelineLoggerMetadata(pydantic.BaseModel): - basepath: pydantic.DirectoryPath - level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - debug_mode: bool = False - - @pydantic.validator("basepath", pre=True) - def create_basepath(cls, value): - pathlib.Path(value).mkdir(parents=True, exist_ok=True) - return value - - type: Literal["pipeline"] - metadata: PipelineLoggerMetadata - - def install(self): - self.metadata.basepath.mkdir(parents=True, exist_ok=True) - levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - formatters = {"brief": {"format": "'%(asctime)s %(levelname)-8s %(name)s:%(lineno)s %(message)s'"}} - handlers = {} - filters = {} - - def add_file_handler( - name, - filename, - level, - formater="brief", - filters=[], - ): - handlers[name] = { - "class": "logging.FileHandler", - "level": level, - "formatter": formater, - "filters": filters, - "filename": filename - } - - # add root logger - root_handlers = [] - root_base_path = self.metadata.basepath.joinpath("root") - root_base_path.mkdir(parents=True, exist_ok=True) - for level in levels: - handler_name = f"root_{level.lower()}" - add_file_handler( - name=handler_name, - filename=root_base_path.joinpath(level), - level=level, - ) - root_handlers.append(handler_name) - - # add console logger - if self.metadata.debug_mode: - handler_name = f"root_console_{self.metadata.level.lower()}" - handlers[handler_name] = { - # "class": "logging.StreamHandler", - "class": "rich.logging.RichHandler", - # "formatter": "brief", - "level": self.metadata.level, - "filters": [], - # "stream": "ext://sys.stdout", - } - 
root_handlers.append(handler_name) - - # add component logger - component_handlers = [] - component_base_path = self.metadata.basepath.joinpath("component") - component_base_path.mkdir(parents=True, exist_ok=True) - filters["components"] = {"name": "fate.components"} - filters["ml"] = {"name": "fate.ml"} - for level in levels: - handler_name = f"component_{level.lower()}" - add_file_handler( - name=handler_name, - filename=component_base_path.joinpath(level), - level=level, - ) - component_handlers.append(handler_name) - component_loggers = { - "fate.components": dict( - handlers=component_handlers, - filters=["components"], - level=self.metadata.level, - ), - "fate.ml": dict( - handlers=component_handlers, - filters=["ml"], - level=self.metadata.level, - ), - } - - logging.config.dictConfig( - dict( - version=1, - formatters=formatters, - handlers=handlers, - filters=filters, - loggers=component_loggers, - root=dict(handlers=root_handlers, level=self.metadata.level), - disable_existing_loggers=False, - ) - ) - - -class FlowLogger(pydantic.BaseModel): - class FlowLoggerMetadata(pydantic.BaseModel): - basepath: pydantic.DirectoryPath - level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - - @pydantic.validator("basepath", pre=True) - def create_basepath(cls, value): - pathlib.Path(value).mkdir(parents=True, exist_ok=True) - return value - - type: Literal["flow"] - metadata: FlowLoggerMetadata - - def install(self): - self.metadata.basepath.mkdir(parents=True, exist_ok=True) - levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - formatters = {"brief": {"format": "'%(asctime)s %(levelname)-8s %(name)s:%(lineno)s %(message)s'"}} - handlers = {} - filters = {} - - def add_file_handler( - name, - filename, - level, - formater="brief", - filters=[] - ): - handlers[name] = { - "class": "logging.FileHandler", - "level": level, - "formatter": formater, - "filters": filters, - "filename": filename - } - - # add root logger - root_handlers = [] - root_base_path = self.metadata.basepath.joinpath("root") - root_base_path.mkdir(parents=True, exist_ok=True) - for level in levels: - handler_name = f"root_{level.lower()}" - add_file_handler( - name=handler_name, - filename=root_base_path.joinpath(level), - level=level, - ) - root_handlers.append(handler_name) - - # add component logger - component_handlers = [] - component_base_path = self.metadata.basepath.joinpath("component") - component_base_path.mkdir(parents=True, exist_ok=True) - filters["components"] = {"name": "fate.components"} - filters["ml"] = {"name": "fate.ml"} - for level in levels: - handler_name = f"component_{level.lower()}" - add_file_handler( - name=handler_name, - filename=component_base_path.joinpath(level), - level=level, - ) - component_handlers.append(handler_name) - component_loggers = { - "fate.components": dict( - handlers=component_handlers, - filters=["components"], - level=self.metadata.level, - ), - "fate.ml": dict( - handlers=component_handlers, - filters=["ml"], - level=self.metadata.level, - ), - } - - logging.config.dictConfig( - dict( - version=1, - formatters=formatters, - handlers=handlers, - filters=filters, - loggers=component_loggers, - root=dict(handlers=root_handlers, level=self.metadata.level), - disable_existing_loggers=False, - ) - ) - - -class CustomLogger(pydantic.BaseModel): - class CustomLoggerMetadata(pydantic.BaseModel): - config_dict: dict - - type: Literal["custom"] - metadata: CustomLoggerMetadata - - def install(self): - logging.config.dictConfig(self.metadata.config_dict) diff 
--git a/python/fate/components/spec/mlmd.py b/python/fate/components/spec/mlmd.py deleted file mode 100644 index da4a09e69e..0000000000 --- a/python/fate/components/spec/mlmd.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Literal - -from pydantic import BaseModel - - -class PipelineMLMDSpec(BaseModel): - class PipelineMLMDMetaData(BaseModel): - db: str - - type: Literal["pipeline"] - metadata: PipelineMLMDMetaData - - -class FlowMLMDSpec(BaseModel): - class FlowMLMDMetaData(BaseModel): - statu_uri: str - tracking_uri: str - - type: Literal["flow"] - metadata: FlowMLMDMetaData - - -class NoopMLMDSpec(BaseModel): - type: Literal["noop"] - - -class CustomMLMDSpec(BaseModel): - class CustomMLMDMetaData(BaseModel): - entrypoint: str - - type: Literal["custom"] - name: str - metadata: CustomMLMDMetaData diff --git a/python/fate/components/spec/output.py b/python/fate/components/spec/output.py deleted file mode 100644 index ad06a339d4..0000000000 --- a/python/fate/components/spec/output.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
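# Illustrative, self-contained sketch (assumes pydantic v1, as imported in these spec
# modules): the removed MLMD and output-pool specs all share one pattern -- a Literal "type"
# field discriminates which member of a Union a raw config dict is parsed into. The
# _FlowSpec/_NoopSpec names below are stand-ins, not classes from this repository.
from typing import Literal, Union

import pydantic

class _FlowSpec(pydantic.BaseModel):
    type: Literal["flow"]
    metadata: dict

class _NoopSpec(pydantic.BaseModel):
    type: Literal["noop"]

parsed = pydantic.parse_obj_as(Union[_FlowSpec, _NoopSpec], {"type": "noop"})
assert isinstance(parsed, _NoopSpec)  # the "type" value selects the matching spec class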
-from typing import Literal, Union - -import pydantic - - -class DirectoryDataPool(pydantic.BaseModel): - class DirectoryDataPoolMetadata(pydantic.BaseModel): - uri: str - format: str = "csv" - name_template: str = "{name}" # `name` and `uuid` allowed in template - - type: Literal["directory"] - metadata: DirectoryDataPoolMetadata - - -class CustomDataPool(pydantic.BaseModel): - type: Literal["custom"] - metadata: dict - - -class DirectoryModelPool(pydantic.BaseModel): - class DirectoryDataPoolMetadata(pydantic.BaseModel): - uri: str - format: str = "json" - name_template: str = "{name}" # `name` and `uuid` allowed in template - - type: Literal["directory"] - metadata: DirectoryDataPoolMetadata - - -class CustomModelPool(pydantic.BaseModel): - type: Literal["custom"] - metadata: dict - - -class DirectoryMetricPool(pydantic.BaseModel): - class DirectoryDataPoolMetadata(pydantic.BaseModel): - uri: str - format: str = "json" - name_template: str = "{name}" # `name` and `uuid` allowed in template - - type: Literal["directory"] - metadata: DirectoryDataPoolMetadata - - -class CustomMetricPool(pydantic.BaseModel): - type: Literal["custom"] - metadata: dict - - -class OutputPoolConf(pydantic.BaseModel): - data: Union[DirectoryDataPool, CustomDataPool] - model: Union[DirectoryModelPool, CustomModelPool] - metric: Union[DirectoryMetricPool, CustomMetricPool] diff --git a/python/fate/interface/__init__.py b/python/fate/interface/__init__.py deleted file mode 100644 index dfa48998c1..0000000000 --- a/python/fate/interface/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from ._cipher import CipherKit, PHECipher -from ._computing import ComputingEngine -from ._consts import T_ARBITER, T_GUEST, T_HOST, T_ROLE -from ._context import Context -from ._data_io import Dataframe -from ._federation import FederationEngine, FederationWrapper -from ._gc import GarbageCollector -from ._metric import InCompleteMetrics, Metric, Metrics, MetricsHandler, MetricsWrap -from ._party import Parties, Party, PartyMeta - -__all__ = [ - "Context", - "Dataframe", - "MetricsHandler", - "MetricsWrap", - "Metrics", - "InCompleteMetrics", - "Metric", - "Party", - "Parties", - "PartyMeta", - "FederationWrapper", - "ComputingEngine", - "CipherKit", - "PHECipher", - "FederationEngine", - "GarbageCollector", - "T_GUEST", - "T_HOST", - "T_ARBITER", - "T_ROLE", -] diff --git a/python/fate/interface/_cipher.py b/python/fate/interface/_cipher.py deleted file mode 100644 index 8920d5c9ab..0000000000 --- a/python/fate/interface/_cipher.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from enum import Enum -from typing import Protocol, Tuple - -from ._tensor import Tensor - - -class PHEKind(Enum): - AUTO = "auto" - PAILLIER = "Paillier" - RUST_PAILLIER = "rust_paillier" - INTEL_PAILLIER = "intel_paillier" - - -class PHECipher(Protocol): - def keygen(self, kind: PHEKind = PHEKind.AUTO, options={}) -> Tuple["PHEEncryptor", "PHEDecryptor"]: - ... - - -class CipherKit(Protocol): - phe: PHECipher - - -class PHEEncryptor(Protocol): - def encrypt(self, tensor) -> Tensor: - ... - - -class PHEDecryptor(Protocol): - def decrypt(self, tensor) -> Tensor: - ... diff --git a/python/fate/interface/_consts.py b/python/fate/interface/_consts.py deleted file mode 100644 index 443728a291..0000000000 --- a/python/fate/interface/_consts.py +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Literal - -T_GUEST = Literal["guest"] -T_HOST = Literal["host"] -T_ARBITER = Literal["arbiter"] -T_ROLE = Literal[T_GUEST, T_HOST, T_ARBITER] diff --git a/python/fate/interface/_context.py b/python/fate/interface/_context.py deleted file mode 100644 index ca5d158fef..0000000000 --- a/python/fate/interface/_context.py +++ /dev/null @@ -1,48 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from contextlib import contextmanager -from typing import Iterable, Iterator, Protocol, Tuple, TypeVar - -from ._cipher import CipherKit -from ._computing import ComputingEngine -from ._federation import FederationEngine -from ._metric import MetricsWrap -from ._party import Parties, Party - -T = TypeVar("T") - - -class Context(Protocol): - metrics: MetricsWrap - guest: Party - hosts: Parties - arbiter: Party - parties: Parties - cipher: CipherKit - computing: ComputingEngine - federation: FederationEngine - - @contextmanager - def sub_ctx(self, namespace) -> Iterator["Context"]: - ... - - def range(self, end) -> Iterator[Tuple[int, "Context"]]: - ... - - def iter(self, iterable: Iterable[T]) -> Iterator[Tuple["Context", T]]: - ... 
- - def destroy(self): - ... diff --git a/python/fate/interface/_metric.py b/python/fate/interface/_metric.py deleted file mode 100644 index 673ca4648a..0000000000 --- a/python/fate/interface/_metric.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Protocol, Tuple, Union - - -class Metric(Protocol): - type: str - - def dict(self) -> dict: - ... - - -class Metrics(Protocol): - name: str - type: str - - def dict(self) -> dict: - ... - - -class InCompleteMetrics(Protocol): - name: str - type: str - - def dict(self) -> dict: - ... - - def merge(self, metrics: "InCompleteMetrics"): - ... - - @classmethod - def from_dict(cls, d) -> "InCompleteMetrics": - ... - - -class MetricsHandler(Protocol): - def log_metrics(self, metrics: Union[Metrics, InCompleteMetrics]): - ... - - -class MetricsWrap(Protocol): - def into_group(self, group_name: str, group_id: str) -> "MetricsWrap": - ... - - def log_metrics(self, metrics: Metrics): - ... - - def log_meta(self, meta): - ... - - def log_metric(self, name: str, metric: Metric, step=None, timestamp=None): - ... - - def log_scalar(self, name: str, metric: float, step=None, timestamp=None): - ... - - def log_loss(self, name: str, loss: float, step, timestamp=None): - ... - - def log_accuracy(self, name: str, accuracy: float, step, timestamp=None): - ... - - def log_roc(self, name: str, data: List[Tuple[float, float]]): - ... diff --git a/python/fate/interface/_party.py b/python/fate/interface/_party.py deleted file mode 100644 index 33026c3016..0000000000 --- a/python/fate/interface/_party.py +++ /dev/null @@ -1,74 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Optional, Protocol, Tuple, TypeVar, overload - -from ._consts import T_ROLE - -T = TypeVar("T") - - -class _KeyedParty(Protocol): - def put(self, value): - ... - - def get(self) -> Any: - ... - - -class Party(Protocol): - def get(self, name: str) -> Any: - ... - - @overload - def put(self, name: str, value): - ... - - @overload - def put(self, **kwargs): - ... - - def __call__(self, key: str) -> _KeyedParty: - ... - - -class Parties(Protocol): - def get(self, name: str) -> List: - ... - - @overload - def put(self, name: str, value): - ... - - @overload - def put(self, **kwargs): - ... - - def __getitem__(self, key: int) -> Party: - ... 
- - def get_neighbor(self, shift: int, module: bool = False) -> Party: - ... - - def get_neighbors(self) -> "Parties": - ... - - def get_local_index(self) -> Optional[int]: - ... - - def __call__(self, key: str) -> _KeyedParty: - ... - - -PartyMeta = Tuple[T_ROLE, str] diff --git a/python/fate/ml/abc/module.py b/python/fate/ml/abc/module.py index 9d3a853742..17f30a8236 100644 --- a/python/fate/ml/abc/module.py +++ b/python/fate/ml/abc/module.py @@ -14,7 +14,8 @@ # limitations under the License. from typing import Optional, Union -from fate.interface import Context, Dataframe +from fate.arch import Context +from fate.arch.dataframe import DataFrame class Model: @@ -27,19 +28,18 @@ class Module: def fit( self, ctx: Context, - train_data: Dataframe, - validate_data: Optional[Dataframe] = None, + train_data: DataFrame, + validate_data: Optional[DataFrame] = None, ) -> None: ... - def transform(self, ctx: Context, transform_data: Dataframe) -> Dataframe: + def transform(self, ctx: Context, transform_data: DataFrame) -> DataFrame: ... - def predict(self, ctx: Context, predict_data: Dataframe) -> Dataframe: + def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: ... - @classmethod - def from_model(cls, model: Union[dict, Model]) -> "Module": + def from_model(cls, model: Union[dict, Model]): ... def get_model(self) -> Union[dict, Model]: diff --git a/python/fate/ml/aggregator/__init__.py b/python/fate/ml/aggregator/__init__.py new file mode 100644 index 0000000000..4909859034 --- /dev/null +++ b/python/fate/ml/aggregator/__init__.py @@ -0,0 +1,17 @@ +from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient, PlainTextAggregatorServer +from fate.ml.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer +import enum + + +class AggregatorType(enum.Enum): + PLAINTEXT = 'plaintext' + SECURE_AGGREGATE = 'secure_aggregate' + + +aggregator_map = { + AggregatorType.PLAINTEXT.value: (PlainTextAggregatorClient, PlainTextAggregatorServer), + AggregatorType.SECURE_AGGREGATE.value: (SecureAggregatorClient, SecureAggregatorServer) +} + + +__all__ = ['PlainTextAggregatorClient', 'PlainTextAggregatorServer', 'SecureAggregatorClient', 'SecureAggregatorServer'] diff --git a/python/fate/ml/aggregator/base.py b/python/fate/ml/aggregator/base.py new file mode 100644 index 0000000000..44afc1c7b5 --- /dev/null +++ b/python/fate/ml/aggregator/base.py @@ -0,0 +1,243 @@ +import logging +from typing import Optional + +import numpy as np +import torch as t +from fate.arch import Context +from fate.arch.protocol.secure_aggregation._secure_aggregation import ( + SecureAggregatorClient as sa_client, +) +from fate.arch.protocol.secure_aggregation._secure_aggregation import ( + SecureAggregatorServer as sa_server, +) + +logger = logging.getLogger(__name__) + + +AGGREGATE_TYPE = ["mean", "sum", "weighted_mean"] +TORCH_TENSOR_PRECISION = ["float32", "float64"] + + +class AutoSuffix(object): + + """ + A auto suffix that will auto increase count + """ + + def __init__(self, suffix_str=""): + self._count = 0 + self.suffix_str = suffix_str + + def __call__(self): + concat_suffix = self.suffix_str + "_" + str(self._count) + self._count += 1 + return concat_suffix + + +class Aggregator: + def __init__(self, ctx: Context, aggregator_name: Optional[str] = None): + + if aggregator_name is not None: + agg_name = "_" + aggregator_name + else: + agg_name = "" + self.suffix = { + "local_loss": AutoSuffix("local_loss" + agg_name), + "agg_loss": AutoSuffix("agg_loss" + agg_name), + 
"local_model": AutoSuffix("local_model" + agg_name), + "agg_model": AutoSuffix("agg_model" + agg_name), + "converge_status": AutoSuffix("converge_status" + agg_name), + "local_weight": AutoSuffix("local_weight" + agg_name), + "computed_weight": AutoSuffix("agg_weight" + agg_name), + } + + def model_aggregation(self, *args, **kwargs): + raise NotImplementedError("model_aggregation should be implemented in subclass") + + def loss_aggregation(self, *args, **kwargs): + raise NotImplementedError("loss_aggregation should be implemented in subclass") + + +class BaseAggregatorClient(Aggregator): + def __init__( + self, + ctx: Context, + aggregator_name: str = None, + aggregate_type="mean", + sample_num=1, + is_mock=True, + require_grad=True, + float_p="float64", + ) -> None: + + super().__init__(ctx, aggregator_name) + self._weight = 1.0 + self.aggregator_name = "default" if aggregator_name is None else aggregator_name + self.require_grad = require_grad + + assert float_p in TORCH_TENSOR_PRECISION, "float_p should be one of {}".format(TORCH_TENSOR_PRECISION) + self.float_p = float_p + + if sample_num <= 0 and not isinstance(sample_num, int): + raise ValueError("sample_num should be int greater than 0") + + logger.info("computing weights") + if aggregate_type not in AGGREGATE_TYPE: + raise ValueError("aggregate_type should be one of {}".format(AGGREGATE_TYPE)) + elif aggregate_type == "mean": + ctx.arbiter.put(self.suffix["local_weight"](), 1.0) + self._weight = ctx.arbiter.get(self.suffix["computed_weight"]()) + elif aggregate_type == "sum": + ctx.arbiter.put(self.suffix["local_weight"](), sample_num) + self._weight = 1.0 + elif aggregate_type == "weighted_mean": + if sample_num <= 0 or sample_num is None: + raise ValueError("sample_num should be int greater than 0") + ctx.arbiter.put(self.suffix["local_weight"](), sample_num) + self._weight = ctx.arbiter.get(self.suffix["computed_weight"]()) + + logger.info("aggregate weight is {}".format(self._weight)) + + self.model_aggregator = sa_client(prefix=self.aggregator_name + "_model", is_mock=is_mock) + self.model_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + self.loss_aggregator = sa_client(prefix=self.aggregator_name + "_loss", is_mock=is_mock) + self.loss_aggregator.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + + def _convert_type(self, data, dtype="float32"): + + if isinstance(data, t.Tensor): + if dtype == "float32": + data = data.float() + elif dtype == "float64": + data = data.double() + else: + raise ValueError("Invalid dtype. Choose either 'float32' or 'float64'") + + numpy_array = data.detach().cpu().numpy() + + elif isinstance(data, np.ndarray): + if dtype == "float32": + numpy_array = data.astype(np.float32) + elif dtype == "float64": + numpy_array = data.astype(np.float64) + else: + raise ValueError("Invalid dtype. Choose either 'float32' or 'float64'") + else: + raise ValueError("Invalid data type. 
Only numpy ndarray and PyTorch tensor are supported.") + + return numpy_array + + def _process_model(self, model): + + to_agg = None + if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): + to_agg = self._convert_type(model, self.float_p) + return [to_agg] + + if isinstance(model, t.nn.Module): + parameters = list(model.parameters()) + if self.require_grad: + agg_list = [ + self._convert_type(p.cpu().detach().numpy(), self.float_p) for p in parameters if p.requires_grad + ] + else: + agg_list = [self._convert_type(p.cpu().detach().numpy(), self.float_p) for p in parameters] + + elif isinstance(model, list): + to_agg = [] + for p in model: + to_agg.append(self._convert_type(p, self.float_p)) + agg_list = to_agg + + return agg_list + + def _recover_model(self, model, agg_model): + + if isinstance(model, np.ndarray) or isinstance(model, t.Tensor): + return agg_model + elif isinstance(model, t.nn.Module): + if self.require_grad: + for agg_p, p in zip(agg_model, [p for p in model.parameters() if p.requires_grad]): + p.data.copy_(t.Tensor(agg_p)) + else: + for agg_p, p in zip(agg_model, model.parameters()): + p.data.copy_(t.Tensor(agg_p)) + return model + else: + return agg_model + + """ + User API + """ + + def model_aggregation(self, ctx, model): + + to_send = self._process_model(model) + agg_model = self.model_aggregator.secure_aggregate(ctx, to_send, self._weight) + return self._recover_model(model, agg_model) + + def loss_aggregation(self, ctx, loss): + if isinstance(loss, t.Tensor): + loss = loss.detach.cpu().numpy() + else: + loss = np.array(loss) + loss = [loss] + agg_loss = self.loss_aggregator.secure_aggregate(ctx, loss, self._weight) + return agg_loss + + +class BaseAggregatorServer(Aggregator): + def __init__(self, ctx: Context, aggregator_name: str = None, is_mock=True) -> None: + + super().__init__(ctx, aggregator_name) + + weight_list = self._collect(ctx, self.suffix["local_weight"]()) + weight_sum = sum(weight_list) + ret_weight = [] + for w in weight_list: + ret_weight.append(w / weight_sum) + + ret_suffix = self.suffix["computed_weight"]() + for idx, w in enumerate(ret_weight): + self._broadcast(ctx, w, ret_suffix, idx) + + self.aggregator_name = "default" if aggregator_name is None else aggregator_name + self.model_aggregator = sa_server( + prefix=self.aggregator_name + "_model", is_mock=is_mock, ranks=[ctx.guest.rank, *ctx.hosts.ranks] + ) + self.loss_aggregator = sa_server( + prefix=self.aggregator_name + "_loss", is_mock=is_mock, ranks=[ctx.guest.rank, *ctx.hosts.ranks] + ) + + def _check_party_id(self, party_id): + # party idx >= -1, int + if not isinstance(party_id, int): + raise ValueError("party_id should be int") + if party_id < -1: + raise ValueError("party_id should be greater than -1") + + def _collect(self, ctx, suffix): + guest_item = [ctx.guest.get(suffix)] + host_item = ctx.hosts.get(suffix) + combine_list = guest_item + host_item + return combine_list + + def _broadcast(self, ctx, data, suffix, party_idx=-1): + self._check_party_id(party_idx) + if party_idx == -1: + ctx.guest.put(suffix, data) + ctx.hosts.put(suffix, data) + elif party_idx == 0: + ctx.guest.put(suffix, data) + else: + ctx.hosts[party_idx - 1].put(suffix, data) + + """ + User API + """ + + def model_aggregation(self, ctx, ranks=None): + self.model_aggregator.secure_aggregate(ctx, ranks=ranks) + + def loss_aggregation(self, ctx, ranks=None): + self.loss_aggregator.secure_aggregate(ctx, ranks=ranks) diff --git a/python/fate/ml/aggregator/plaintext_aggregator.py 
b/python/fate/ml/aggregator/plaintext_aggregator.py new file mode 100644 index 0000000000..81fc85d5d2 --- /dev/null +++ b/python/fate/ml/aggregator/plaintext_aggregator.py @@ -0,0 +1,14 @@ +from fate.arch import Context +from fate.ml.aggregator.base import BaseAggregatorClient, BaseAggregatorServer + + +class PlainTextAggregatorClient(BaseAggregatorClient): + + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: + super().__init__(ctx, aggregator_name, aggregate_type, sample_num, is_mock=True) + + +class PlainTextAggregatorServer(BaseAggregatorServer): + + def __init__(self, ctx: Context, aggregator_name: str = None) -> None: + super().__init__(ctx, aggregator_name, is_mock=True) \ No newline at end of file diff --git a/python/fate/ml/aggregator/secure_aggregator.py b/python/fate/ml/aggregator/secure_aggregator.py new file mode 100644 index 0000000000..0ec0ad37dd --- /dev/null +++ b/python/fate/ml/aggregator/secure_aggregator.py @@ -0,0 +1,14 @@ +from fate.arch import Context +from fate.ml.aggregator.base import BaseAggregatorClient, BaseAggregatorServer + + +class SecureAggregatorClient(BaseAggregatorClient): + + def __init__(self, ctx: Context, aggregator_name: str = None, aggregate_type='mean', sample_num=1) -> None: + super().__init__(ctx, aggregator_name, aggregate_type, sample_num, is_mock=False) + + +class SecureAggregatorServer(BaseAggregatorServer): + + def __init__(self, ctx: Context, aggregator_name: str = None) -> None: + super().__init__(ctx, aggregator_name, is_mock=False) diff --git a/python/fate/ml/aggregator/test/test_aggregator.py b/python/fate/ml/aggregator/test/test_aggregator.py new file mode 100644 index 0000000000..ba74c743c6 --- /dev/null +++ b/python/fate/ml/aggregator/test/test_aggregator.py @@ -0,0 +1,73 @@ +import sys +import torch as t + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + +if __name__ == "__main__": + + epoch = 10 + + if sys.argv[1] == "guest": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(guest) + client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch): + client.model_aggregation(iter_ctx, model) + + elif sys.argv[1] == "host": + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorClient + + ctx = create_ctx(host) + client = PlainTextAggregatorClient(ctx, sample_num=100, aggregate_type='weighted_mean') + model = t.nn.Sequential( + t.nn.Linear(10, 10), + t.nn.ReLU(), + t.nn.Linear(10, 1), + t.nn.Sigmoid() + ) + + for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch): + client.model_aggregation(iter_ctx, model) + + else: 
+ + from fate.ml.aggregator.plaintext_aggregator import PlainTextAggregatorServer + ctx = create_ctx(arbiter) + server = PlainTextAggregatorServer(ctx) + + for i, iter_ctx in ctx.on_iterations.ctxs_range(epoch): + server.model_aggregation(iter_ctx) + diff --git a/python/fate/ml/aggregator/test/test_fate_utils.py b/python/fate/ml/aggregator/test/test_fate_utils.py new file mode 100644 index 0000000000..51ddd34311 --- /dev/null +++ b/python/fate/ml/aggregator/test/test_fate_utils.py @@ -0,0 +1,52 @@ +import sys + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + +if __name__ == "__main__": + if sys.argv[1] == "guest": + from fate.arch.protocol import SecureAggregatorClient + import numpy as np + + ctx = create_ctx(guest) + client = SecureAggregatorClient(is_mock=True) + client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + print('ranks are {}'.format([ctx.guest.rank, *ctx.hosts.ranks])) + print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) + elif sys.argv[1] == "host": + from fate.arch.protocol import SecureAggregatorClient + import numpy as np + + ctx = create_ctx(host) + client = SecureAggregatorClient(is_mock=True) + client.dh_exchange(ctx, [ctx.guest.rank, *ctx.hosts.ranks]) + print(client.secure_aggregate(ctx, [np.zeros((3, 4)), np.ones((2, 3))])) + else: + from fate.arch.protocol import SecureAggregatorServer + + ctx = create_ctx(arbiter) + server = SecureAggregatorServer([ctx.guest.rank, *ctx.hosts.ranks], is_mock=True) + server.secure_aggregate(ctx) \ No newline at end of file diff --git a/python/fate/ml/ensemble/__init__.py b/python/fate/ml/ensemble/__init__.py new file mode 100644 index 0000000000..51ee710519 --- /dev/null +++ b/python/fate/ml/ensemble/__init__.py @@ -0,0 +1,5 @@ +from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest +from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost +from fate.ml.ensemble.learner.decision_tree.tree_core.loss import BINARY_BCE, MULTI_CE, REGRESSION_L2 + +__all__ = ["HeteroSecureBoostGuest", "HeteroSecureBoostHost", "BINARY_BCE", "MULTI_CE", "REGRESSION_L2"] diff --git a/python/fate/arch/tensor/storage/__init__.py b/python/fate/ml/ensemble/algo/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/__init__.py rename to python/fate/ml/ensemble/algo/__init__.py diff --git a/python/fate/arch/tensor/storage/distributed/__init__.py b/python/fate/ml/ensemble/algo/secureboost/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/distributed/__init__.py rename to python/fate/ml/ensemble/algo/secureboost/__init__.py diff --git a/python/fate/arch/tensor/storage/local/__init__.py b/python/fate/ml/ensemble/algo/secureboost/common/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/local/__init__.py 
rename to python/fate/ml/ensemble/algo/secureboost/common/__init__.py diff --git a/python/fate/ml/ensemble/algo/secureboost/common/predict.py b/python/fate/ml/ensemble/algo/secureboost/common/predict.py new file mode 100644 index 0000000000..bffc451ddd --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/common/predict.py @@ -0,0 +1,195 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +import numpy as np +from typing import List +from fate.arch import Context +from fate.arch.dataframe import DataFrame +import copy +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, _make_decision, Node +import functools +from logging import getLogger + +logger = getLogger(__name__) + + +def get_dtype(max_int): + if max_int < (2**8) / 2: + return np.int8 + elif max_int < (2**16) / 2: + return np.int16 + else: + return np.int64 + + +def all_reach_leaf(pos: np.array): + if isinstance(pos, list): + pos = np.array(pos) + return np.all(pos < 0) + + +def not_finished(pos: np.array): + if isinstance(pos, list): + pos = np.array(pos) + return not np.all(pos < 0) + + +def generate_pos_array(tree_num, max_node_num): + dtype = get_dtype(max_node_num) + # return list as a column of dataframe + return [np.zeros(tree_num, dtype=dtype)] + + +def go_deep(s: pd.Series, tree: List[Node], sitename, cur_node_id, tree_idx=None): + node: Node = tree[cur_node_id] + while True: + if node.is_leaf: + return -(node.nid + 1) + elif node.sitename != sitename: + return node.nid + else: + fid = node.fid + split_val = node.bid + sample_feat_val = s[fid] + is_left = _make_decision(sample_feat_val, split_val) + if is_left: + node = tree[node.l] + else: + node = tree[node.r] + + +def traverse_tree(s: pd.Series, trees: List[List[Node]], sitename: str): + sample_pos = s["sample_pos"] + new_sample_pos = np.copy(sample_pos) # deepcopy to avoid inplace modification, for spark + + tree_idx = 0 + for node_pos, tree in zip(sample_pos, trees): + if node_pos < 0: # sample already reaches leaf node in this tree + tree_idx += 1 + continue + + cur_node_id = node_pos + end_node_id = go_deep(s, tree, sitename, cur_node_id, tree_idx=tree_idx) + new_sample_pos[tree_idx] = end_node_id + tree_idx += 1 + + return [new_sample_pos] + + +def _merge_pos_arr(s: pd.Series): + arr_1 = s["sample_pos"] + arr_2 = s["host_sample_pos"] + arr_1 = np.array(arr_1) + arr_2 = np.array(arr_2) + assert len(arr_1) == len(arr_2) + merge_rs = np.copy(arr_1) + on_leaf = arr_2 < 0 + updated = on_leaf | (arr_2 > arr_1) + merge_rs[updated] = arr_2[updated] + return [merge_rs] + + +def _merge_pos(guest_pos: DataFrame, host_pos: List[DataFrame]): + for host_df in host_pos: + # assert alignment + indexer = guest_pos.get_indexer(target="sample_id") + host_df = host_df.loc(indexer=indexer, preserve_order=True) + stack_df = DataFrame.hstack([guest_pos, host_df]) + guest_pos["sample_pos"] = stack_df.apply_row(_merge_pos_arr) + + return guest_pos + + +def 
predict_leaf_guest(ctx: Context, trees: List[DecisionTree], data: DataFrame): + predict_data = data + tree_list = [tree.get_nodes() for tree in trees] + max_node_num = max([len(tree) for tree in tree_list]) + map_func = functools.partial(generate_pos_array, tree_num=len(trees), max_node_num=max_node_num) + + sample_pos = data.create_frame() + sample_pos["sample_pos"] = data.apply_row(lambda x: map_func()) + result_sample_pos = sample_pos.empty_frame() + + sitename = ctx.local.name + + # start loop here + comm_round = 0 + + while True: + sub_ctx = ctx.sub_ctx("predict_round").indexed_ctx(comm_round) + + if comm_round: + predict_data = predict_data.loc(indexer=sample_pos.get_indexer(target="sample_id"), preserve_order=True) + sample_with_pos = DataFrame.hstack([predict_data, sample_pos]) + logger.info("predict round {} has {} samples to predict".format(comm_round, len(sample_with_pos))) + map_func = functools.partial(traverse_tree, trees=tree_list, sitename=sitename) + new_pos = sample_with_pos.create_frame() + new_pos["sample_pos"] = sample_with_pos.apply_row(map_func) + done_sample_idx = new_pos.apply_row( + lambda x: all_reach_leaf(x["sample_pos"]) + ) # samples that reach leaf node in all trees + # not_finished_sample_idx = ~done_sample_idx + not_finished_sample_idx = new_pos.apply_row( + lambda x: not_finished(x["sample_pos"]) + ) # samples that not reach leaf node in all trees + + done_sample = new_pos.iloc(done_sample_idx) + result_sample_pos = DataFrame.vstack([result_sample_pos, done_sample]) + if len(result_sample_pos) == len(data): + sub_ctx.hosts.put("need_stop", True) + break + + sub_ctx.hosts.put("need_stop", False) + pending_samples = new_pos.iloc(not_finished_sample_idx) + + # send not-finished samples to host + sub_ctx.hosts.put("pending_samples", (pending_samples)) + # get result from host and merge + updated_pos = sub_ctx.hosts.get("updated_pos") + sample_pos = _merge_pos(pending_samples, updated_pos) + comm_round += 1 + + logger.info("predict done") + + assert len(result_sample_pos) == len(data), "result sample pos length not equal to data length, {} vs {}".format( + len(result_sample_pos), len(data) + ) + return result_sample_pos + + +def predict_leaf_host(ctx: Context, trees: List[DecisionTree], data: DataFrame): + tree_list = [tree.get_nodes() for tree in trees] + sitename = ctx.local.name + map_func = functools.partial(traverse_tree, trees=tree_list, sitename=sitename) + + # help guest to traverse tree + comm_round = 0 + + # start loop here + while True: + sub_ctx = ctx.sub_ctx("predict_round").indexed_ctx(comm_round) + need_stop = sub_ctx.guest.get("need_stop") + if need_stop: + break + pending_samples = sub_ctx.guest.get("pending_samples") + logger.info("got {} pending samples".format(len(pending_samples))) + sample_features = data.loc(pending_samples.get_indexer("sample_id"), preserve_order=True) + sample_with_pos = DataFrame.hstack([sample_features, pending_samples]) + new_pos = sample_with_pos.create_frame() + new_pos["host_sample_pos"] = sample_with_pos.apply_row(map_func) + sub_ctx.guest.put("updated_pos", (new_pos)) + comm_round += 1 + + logger.info("predict done") diff --git a/python/fate/arch/tensor/storage/local/device/cpu/__init__.py b/python/fate/ml/ensemble/algo/secureboost/hetero/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/local/device/cpu/__init__.py rename to python/fate/ml/ensemble/algo/secureboost/hetero/__init__.py diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/_base.py 
b/python/fate/ml/ensemble/algo/secureboost/hetero/_base.py new file mode 100644 index 0000000000..e5a1173741 --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/_base.py @@ -0,0 +1,77 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List +import functools +from fate.arch.dataframe import DataFrame +from fate.ml.abc.module import HeteroModule, Model +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import FeatureImportance, Node +from typing import Dict +import numpy as np + + +class HeteroBoostingTree(HeteroModule): + def __init__(self) -> None: + super().__init__() + self._global_feature_importance = {} + self._trees = [] + self._saved_tree = [] + + def _update_feature_importance(self, fi_dict: Dict[int, FeatureImportance]): + for fid, fi in fi_dict.items(): + if fid not in self._global_feature_importance: + self._global_feature_importance[fid] = fi + else: + self._global_feature_importance[fid] = self._global_feature_importance[fid] + fi + + def _sum_leaf_weights(self, leaf_pos: DataFrame, trees, learing_rate: float, loss_func): + def _compute_score(leaf_pos: np.array, trees: List[List[Node]], learning_rate: float): + score = 0 + leaf_pos = leaf_pos["sample_pos"] + for node_idx, tree in zip(leaf_pos, trees): + recovered_idx = -(node_idx + 1) + score += tree[recovered_idx].weight * learning_rate + return score + + tree_list = [tree.get_nodes() for tree in trees] + apply_func = functools.partial(_compute_score, trees=tree_list, learning_rate=learing_rate) + predict_score = leaf_pos.create_frame() + predict_score["score"] = leaf_pos.apply_row(apply_func) + return loss_func.predict(predict_score) + + def get_trees(self): + return self._trees + + def get_feature_importance(self): + return self._global_feature_importance + + def print_forest(self): + idx = 0 + for tree in self._trees: + print("tree {}: ".format(idx)) + idx += 1 + tree.print_tree() + print() + + def _get_hyper_param(self) -> dict: + pass + + def get_model(self) -> dict: + import copy + + hyper_param = self._get_hyper_param() + result = {} + result["hyper_param"] = hyper_param + result["trees"] = copy.deepcopy(self._saved_tree) + return result diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py new file mode 100644 index 0000000000..273f5b144b --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/guest.py @@ -0,0 +1,325 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from typing import Optional +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.ml.ensemble.algo.secureboost.hetero._base import HeteroBoostingTree +from fate.ml.ensemble.learner.decision_tree.hetero.guest import HeteroDecisionTreeGuest +from fate.ml.ensemble.utils.binning import binning +from fate.ml.ensemble.learner.decision_tree.tree_core.loss import OBJECTIVE, get_task_info, MULTI_CE +from fate.ml.ensemble.algo.secureboost.common.predict import predict_leaf_guest +from fate.ml.utils.predict_tools import compute_predict_details, PREDICT_SCORE, LABEL, BINARY, MULTI, REGRESSION +import logging + + +logger = logging.getLogger(__name__) + + +class HeteroSecureBoostGuest(HeteroBoostingTree): + def __init__( + self, + num_trees=3, + learning_rate=0.3, + max_depth=3, + objective="binary:bce", + num_class=3, + max_bin=32, + l2=0.1, + l1=0, + min_impurity_split=1e-2, + min_sample_split=2, + min_leaf_node=1, + min_child_weight=1, + gh_pack=True, + split_info_pack=True, + hist_sub=True, + ): + super().__init__() + self.num_trees = num_trees + self.learning_rate = learning_rate + self.max_depth = max_depth + self.objective = objective + self.max_bin = max_bin + + # regularization + self.l2 = l2 + self.l1 = l1 + self.min_impurity_split = min_impurity_split + self.min_sample_split = min_sample_split + self.min_leaf_node = min_leaf_node + self.min_child_weight = min_child_weight + + # running var + self.num_class = num_class + self._accumulate_scores = None + self._tree_dim = 1 # tree dimension, if is multilcass task, tree dim > 1 + self._loss_func = None + self._train_predict = None + self._hist_sub = hist_sub + + # encryption + self._encrypt_kit = None + self._gh_pack = gh_pack + self._split_info_pack = split_info_pack + + # reg score + self._init_score = None + + # model loaded + self._model_loaded = False + + def _prepare_parameter(self): + self._tree_dim = self.num_class if self.objective == "multiclass:ce" else 1 + + def _get_loss_func(self, objective: str) -> Optional[object]: + # to lowercase + objective = objective.lower() + if objective == MULTI_CE: + raise ValueError( + "multi:ce objective is not supported in the beta version, will be added in the next version" + ) + assert ( + objective in OBJECTIVE + ), f"objective {objective} not found, supported objective: {list(OBJECTIVE.keys())}" + obj_class = OBJECTIVE[objective] + loss_func = obj_class() + return loss_func + + def _compute_gh(self, data: DataFrame, scores: DataFrame, loss_func): + label = data.label + predict = loss_func.predict(scores) + gh = data.create_frame() + loss_func.compute_grad(gh, label, predict) + loss_func.compute_hess(gh, label, predict) + return gh + + def _check_encrypt_kit(self, ctx: Context): + if self._encrypt_kit is None: + # make sure cipher is initialized + kit = ctx.cipher.phe.setup() + self._encrypt_kit = kit + + if not self._encrypt_kit.can_support_negative_number: + self._gh_pack = True + logger.info("current encrypt method cannot support neg num, gh pack is forced to be True") + if not self._encrypt_kit.can_support_squeeze: + self._split_info_pack = False + logger.info("current encrypt method cannot support compress, split info pack is forced to be False") + if not self._encrypt_kit.can_support_pack: + self._gh_pack = False + self._split_info_pack = False + logger.info("current encrypt method cannot support pack, gh pack is forced to be False") + 
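# Note (inferred from the flag names; not spelled out in this changeset): gh_pack packs a
# sample's gradient and hessian into a single plaintext so both are encrypted in one
# operation, while split_info_pack squeezes several encrypted split statistics into fewer
# ciphertexts. The capability checks above switch off whichever optimization the
# negotiated PHE scheme cannot support.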
return kit + + def get_train_predict(self): + return self._train_predict + + def get_tree(self, idx): + return self._trees[idx] + + def _init_sample_scores(self, ctx: Context, label, train_data: DataFrame): + task_type = self.objective.split(":")[0] + pred_ctx = ctx.sub_ctx("warmstart_predict") + if self._model_loaded: + logger.info("prepare warmstarting score") + self._accumulate_scores = self.predict(pred_ctx, train_data, ret_std_format=False) + self._accumulate_scores = self._accumulate_scores.loc( + train_data.get_indexer(target="sample_id"), preserve_order=True + ) + else: + if task_type == REGRESSION: + self._accumulate_scores, avg_score = self._loss_func.initialize(label) + if self._init_score is None: + self._init_score = avg_score + elif task_type == MULTI: + self._accumulate_scores = self._loss_func.initialize(label, self.num_class) + else: + self._accumulate_scores = self._loss_func.initialize(label) + + def _check_label(self, label: DataFrame): + label_df = label.as_pd_df()[label.schema.label_name] + if self.objective == "multi:ce": + if self.num_class is None or self.num_class <= 2: + raise ValueError( + f"num_class should be set and greater than 2 for multi:ce objective, but got {self.num_class}" + ) + label_set = set(np.unique(label_df)) + if len(label_set) > self.num_class: + raise ValueError( + f"num_class should be greater than or equal to the number of unique label in provided train data, but got {self.num_class} and {len(label_set)}" + ) + if max(label_set) - 1 > self.num_class: + raise ValueError( + f"the max label index in the provided train data should be less than or equal to num_class - 1, but got index {max(label_set)} which is > {self.num_class}" + ) + + elif self.objective == "binary:bce": + label_set = set(np.unique(label_df)) + assert len(label_set) == 2, f"binary classification task should have 2 unique label, but got {label_set}" + assert ( + 0 in label_set and 1 in label_set + ), f"binary classification task should have label 0 and 1, but got {label_set}" + self.num_class = 2 + else: + self.num_class = None + + def get_task_info(self): + task_type = get_task_info(self.objective) + if task_type == BINARY: + classes = [0, 1] + elif task_type == REGRESSION: + classes = None + return task_type, classes + + def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None: + """ + Train model with train data and validate data. + + Parameters + ---------- + ctx: Context + FATE Context object + train_data: DataFrame + Train data used to fit model. + validate_data: DataFrame, optional + Validate data used to evaluate model performance during training process. 
+ """ + + # data binning + bin_info = binning(train_data, max_bin=self.max_bin) + bin_data: DataFrame = train_data.bucketize(boundaries=bin_info) + self._check_label(bin_data.label) + + # init loss func & scores + self._loss_func = self._get_loss_func(self.objective) + label = bin_data.label + self._init_sample_scores(ctx, label, train_data) + + # init encryption kit + self._encrypt_kit = self._check_encrypt_kit(ctx) + + # start tree fittingf + for tree_idx, tree_ctx in ctx.on_iterations.ctxs_range(len(self._trees), len(self._trees) + self.num_trees): + # compute gh of current iter + logger.info("start to fit a guest tree") + gh = self._compute_gh(bin_data, self._accumulate_scores, self._loss_func) + tree = HeteroDecisionTreeGuest( + max_depth=self.max_depth, + l2=self.l2, + l1=self.l1, + min_impurity_split=self.min_impurity_split, + min_sample_split=self.min_sample_split, + min_leaf_node=self.min_leaf_node, + min_child_weight=self.min_child_weight, + objective=self.objective, + gh_pack=self._gh_pack, + split_info_pack=self._split_info_pack, + hist_sub=self._hist_sub, + ) + tree.set_encrypt_kit(self._encrypt_kit) + tree.booster_fit(tree_ctx, bin_data, gh, bin_info) + # accumulate scores of cur boosting round + scores = tree.get_sample_predict_weights() + assert len(scores) == len( + self._accumulate_scores + ), f"tree predict scores length {len(scores)} not equal to accumulate scores length {len(self._accumulate_scores)}." + scores = scores.loc(self._accumulate_scores.get_indexer(target="sample_id"), preserve_order=True) + self._accumulate_scores = self._accumulate_scores + scores * self.learning_rate + self._trees.append(tree) + self._saved_tree.append(tree.get_model()) + self._update_feature_importance(tree.get_feature_importance()) + logger.info("fitting guest decision tree {} done".format(tree_idx)) + + # compute train predict using cache scores + train_predict: DataFrame = self._loss_func.predict(self._accumulate_scores) + train_predict = train_predict.loc(train_data.get_indexer(target="sample_id"), preserve_order=True) + train_predict.label = train_data.label + task_type, classes = self.get_task_info() + train_predict.rename(columns={"score": PREDICT_SCORE}) + self._train_predict = compute_predict_details(train_predict, task_type, classes) + + def predict(self, ctx: Context, predict_data: DataFrame, predict_leaf=False, ret_std_format=True) -> DataFrame: + """ + predict function + + Parameters + ---------- + ctx: Context + FATE Context object + predict_data: DataFrame + Data used to predict. + predict_leaf: bool, optional + Whether to predict and return leaf index. + ret_std_format: bool, optional + Whether to return result in a FATE standard format which contains more details. 
+ """ + + task_type, classes = self.get_task_info() + leaf_pos = predict_leaf_guest(ctx, self._trees, predict_data) + if predict_leaf: + return leaf_pos + result = self._sum_leaf_weights(leaf_pos, self._trees, self.learning_rate, self._loss_func) + + if task_type == REGRESSION: + logger.debug("regression task, add init score") + result = result + self._init_score + + if ret_std_format: + # align table + result: DataFrame = result.loc(predict_data.get_indexer(target="sample_id"), preserve_order=True) + ret_frame = result.create_frame() + if predict_data.schema.label_name is not None: + ret_frame.label = predict_data.label + ret_frame[PREDICT_SCORE] = result["score"] + + return compute_predict_details(ret_frame, task_type, classes) + else: + return result + + def _get_hyper_param(self) -> dict: + return { + "num_trees": self.num_trees, + "learning_rate": self.learning_rate, + "max_depth": self.max_depth, + "objective": self.objective, + "max_bin": self.max_bin, + "l2": self.l2, + "num_class": self.num_class, + } + + def get_model(self) -> dict: + ret_dict = super().get_model() + ret_dict["init_score"] = self._init_score + return ret_dict + + def from_model(self, model: dict): + trees = model["trees"] + self._saved_tree = trees + self._trees = [HeteroDecisionTreeGuest.from_model(tree) for tree in trees] + hyper_parameter = model["hyper_param"] + + # these parameter are related to predict + self.learning_rate = hyper_parameter["learning_rate"] + self.num_class = hyper_parameter["num_class"] + self.objective = hyper_parameter["objective"] + self._init_score = float(model["init_score"]) if model["init_score"] is not None else None + # initialize + self._prepare_parameter() + self._loss_func = self._get_loss_func(self.objective) + # for warmstart + self._model_loaded = True + + return self diff --git a/python/fate/ml/ensemble/algo/secureboost/hetero/host.py b/python/fate/ml/ensemble/algo/secureboost/hetero/host.py new file mode 100644 index 0000000000..e1102f9837 --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/hetero/host.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random +import os +from fate.ml.ensemble.learner.decision_tree.hetero.host import HeteroDecisionTreeHost +from fate.ml.ensemble.algo.secureboost.hetero._base import HeteroBoostingTree +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.ml.ensemble.utils.binning import binning +from fate.ml.ensemble.algo.secureboost.common.predict import predict_leaf_host +import logging + +logger = logging.getLogger(__name__) + + +class HeteroSecureBoostHost(HeteroBoostingTree): + def __init__(self, num_trees=3, learning_rate=0.3, max_depth=3, max_bin=32, hist_sub=True) -> None: + super().__init__() + self.num_trees = num_trees + self.learning_rate = learning_rate + self.max_depth = max_depth + self.max_bin = max_bin + self._model_loaded = False + self._hist_sub = hist_sub + + def get_tree(self, idx): + return self._trees[idx] + + def _get_seeds(self, ctx: Context): + if ctx.cipher.allow_custom_random_seed: + seed = ctx.cipher.get_custom_random_seed() + random_state = random.Random(seed) + yield random_state.getrandbits(64) + while True: + yield random_state.getrandbits(64) + else: + while True: + random_seed = os.urandom(8) + yield int.from_bytes(random_seed, byteorder="big") + + def fit(self, ctx: Context, train_data: DataFrame, validate_data: DataFrame = None) -> None: + # data binning + bin_info = binning(train_data, max_bin=self.max_bin) + bin_data: DataFrame = train_data.bucketize(boundaries=bin_info) + logger.info("data binning done") + # predict to help guest to get the warmstart scores + if self._model_loaded: + pred_ctx = ctx.sub_ctx("warmstart_predict") + self.predict(pred_ctx, train_data) + + random_seeds = self._get_seeds(ctx) + global_random_seed = next(random_seeds) + for tree_idx, tree_ctx in ctx.on_iterations.ctxs_range(len(self._trees), len(self._trees) + self.num_trees): + logger.info("start to fit a host tree") + tree = HeteroDecisionTreeHost( + max_depth=self.max_depth, + hist_sub=self._hist_sub, + global_random_seed=global_random_seed, + random_seed=next(random_seeds), + ) + tree.booster_fit(tree_ctx, bin_data, bin_info) + self._trees.append(tree) + self._saved_tree.append(tree.get_model()) + self._update_feature_importance(tree.get_feature_importance()) + logger.info("fitting host decision tree {} done".format(tree_idx)) + + def predict(self, ctx: Context, predict_data: DataFrame) -> None: + predict_leaf_host(ctx, self._trees, predict_data) + + def _get_hyper_param(self) -> dict: + return { + "num_trees": self.num_trees, + "learning_rate": self.learning_rate, + "max_depth": self.max_depth, + "max_bin": self.max_bin, + } + + def from_model(self, model: dict): + trees = model["trees"] + self._saved_tree = trees + self._trees = [HeteroDecisionTreeHost.from_model(tree) for tree in trees] + self._model_loaded = True + return self diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py new file mode 100644 index 0000000000..e3035b77ec --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py @@ -0,0 +1,69 @@ +import pandas as pd +from fate.arch.dataframe import PandasReader +import sys +from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest +from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost +from datetime import datetime + + +def get_current_datetime_str(): + return datetime.now().strftime("%Y-%m-%d-%H-%M") + +guest = ("guest", "10000") +host = ("host", 
"9999") +name = get_current_datetime_str() + + +def create_ctx(local, context_name): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + # prepare log + logger = logging.getLogger() + logger.setLevel(logging.INFO) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + # init fate context + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, context_name, local, [guest, host]) + ) + + +if __name__ == "__main__": + + party = sys.argv[1] + max_depth = 3 + num_tree = 1 + + if party == "guest": + + ctx = create_ctx(guest, get_current_datetime_str()) + df = pd.read_csv("./../../../../../../../examples/data/breast_hetero_guest.csv") + df["sample_id"] = [i for i in range(len(df))] + + reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="float32") + + data_guest = reader.to_frame(ctx, df) + + trees = HeteroSecureBoostGuest(num_tree, max_depth=max_depth) + trees.fit(ctx, data_guest) + pred = trees.get_train_predict().as_pd_df() + + elif party == "host": + + ctx = create_ctx(host, get_current_datetime_str()) + df_host = pd.read_csv("./../../../../../../../examples/data/breast_hetero_host.csv") + df_host["sample_id"] = [i for i in range(len(df_host))] + + reader_host = PandasReader(sample_id_name="sample_id", match_id_name="id", dtype="float32") + + data_host = reader_host.to_frame(ctx, df_host) + + trees = HeteroSecureBoostHost(num_tree, max_depth=max_depth) + trees.fit(ctx, data_host) \ No newline at end of file diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py new file mode 100644 index 0000000000..9ee7992754 --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py @@ -0,0 +1,100 @@ +import pandas as pd +from fate.arch.dataframe import PandasReader, DataFrame +from fate.arch import Context +import sys +from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest +from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost +from datetime import datetime + + +def get_current_datetime_str(): + return datetime.now().strftime("%Y-%m-%d-%H-%M") + + +arbiter = ("arbiter", "10000") +guest = ("guest", "10000") +host = ("host", "9999") +name = get_current_datetime_str() + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]) + ) + + +if __name__ == "__main__": + party = sys.argv[1] + max_depth = 3 + num_tree = 1 + from sklearn.metrics import roc_auc_score as auc + + if party == "guest": + ctx = create_ctx(guest) + df = 
pd.read_csv("./../../../../../../../examples/data/vehicle_scale_hetero_guest.csv") + df["sample_id"] = [i for i in range(len(df))] + + reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="float32") + + data_guest = reader.to_frame(ctx, df) + + trees = HeteroSecureBoostGuest( + num_tree, max_depth=max_depth, l2=0.5, min_impurity_split=200, num_class=4, objective="multi:ce" + ) + gh = trees.fit(ctx, data_guest) + # pred = trees.get_train_predict().as_pd_df() + # pred['sample_id'] = pred.sample_id.astype(int) + # df = pd.merge(df, pred, on='sample_id') + + # load tree + # tree_dict = pickle.load(open('guest_tree.pkl', 'rb')) + # trees.from_model(tree_dict) + # pred_ = trees.predict(ctx, data_guest).as_pd_df() + # print(auc(df.y, df.score)) + # print(auc(pred_.label, pred_.predict_score)) + # pred_.sample_id = pred_.sample_id.astype(int) + # merge_df = pd.merge(pred, pred_, on='sample_id') + + # print('fitting again, warm start') + # # fit again + # new_tree = HeteroSecureBoostGuest(1, max_depth=3) + # new_tree.from_model(trees.get_model()) + # new_tree.fit(ctx, data_guest) + + elif party == "host": + ctx = create_ctx(host) + + df_host = pd.read_csv("./../../../../../../../examples/data/vehicle_scale_hetero_host.csv") + df_host["sample_id"] = [i for i in range(len(df_host))] + + reader_host = PandasReader(sample_id_name="sample_id", match_id_name="id", dtype="float32") + + data_host = reader_host.to_frame(ctx, df_host) + + trees = HeteroSecureBoostHost(num_tree, max_depth=max_depth) + trees.fit(ctx, data_host) + # load tree + # tree_dict = pickle.load(open('host_tree.pkl', 'rb')) + # trees.from_model(tree_dict) + # trees.predict(ctx, data_host) + + # fit again + new_tree = HeteroSecureBoostHost(1, max_depth=3) + new_tree.from_model(trees.get_model()) + new_tree.fit(ctx, data_host) diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py new file mode 100644 index 0000000000..d8b28ebafb --- /dev/null +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py @@ -0,0 +1,88 @@ +import pandas as pd +from fate.arch.dataframe import PandasReader, DataFrame +from fate.arch import Context +import sys +from fate.ml.ensemble.algo.secureboost.hetero.guest import HeteroSecureBoostGuest +from fate.ml.ensemble.algo.secureboost.hetero.host import HeteroSecureBoostHost +from datetime import datetime + + +def get_current_datetime_str(): + return datetime.now().strftime("%Y-%m-%d-%H-%M") + + +arbiter = ("arbiter", "10000") +guest = ("guest", "10000") +host = ("host", "9999") +name = get_current_datetime_str() + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]) + ) + + +if __name__ == "__main__": + party = sys.argv[1] + max_depth = 3 + num_tree = 20 + from sklearn.metrics import roc_auc_score as auc + + if party == "guest": + ctx = 
create_ctx(guest) + df = pd.read_csv("./../../../../../../../examples/data/student_hetero_guest.csv") + df["sample_id"] = [i for i in range(len(df))] + + reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="float32") + + data_guest = reader.to_frame(ctx, df) + + trees = HeteroSecureBoostGuest(num_tree, max_depth=max_depth, objective="regression:l2") + trees.fit(ctx, data_guest) + pred = trees.get_train_predict().as_pd_df() + pred["sample_id"] = pred.sample_id.astype(int) + df = pd.merge(df, pred, on="sample_id") + + # load tree + # tree_dict = pickle.load(open('guest_tree.pkl', 'rb')) + # trees.from_model(tree_dict) + pred_ = trees.predict(ctx, data_guest).as_pd_df() + pred_.sample_id = pred_.sample_id.astype(int) + merge_df = pd.merge(pred, pred_, on="sample_id") + from sklearn.metrics import mean_squared_error + + print(mean_squared_error(pred.predict_score, pred.label)) + + elif party == "host": + ctx = create_ctx(host) + + df_host = pd.read_csv("./../../../../../../../examples/data/student_hetero_host.csv") + df_host["sample_id"] = [i for i in range(len(df_host))] + + reader_host = PandasReader(sample_id_name="sample_id", match_id_name="id", dtype="float32") + + data_host = reader_host.to_frame(ctx, df_host) + + trees = HeteroSecureBoostHost(num_tree, max_depth=max_depth) + trees.fit(ctx, data_host) + # load tree + # tree_dict = pickle.load(open('host_tree.pkl', 'rb')) + # trees.from_model(tree_dict) + trees.predict(ctx, data_host) diff --git a/python/fate/components/loader/__init__.py b/python/fate/ml/ensemble/learner/__init__.py similarity index 100% rename from python/fate/components/loader/__init__.py rename to python/fate/ml/ensemble/learner/__init__.py diff --git a/python/fate/ml/ensemble/learner/decision_tree/__init__.py b/python/fate/ml/ensemble/learner/decision_tree/__init__.py new file mode 100644 index 0000000000..b41822b468 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/__init__.py @@ -0,0 +1,3 @@ +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import DecisionTree, Node, FeatureImportance + +__all__ = ["DecisionTree", "Node", "FeatureImportance"] diff --git a/python/fate/components/spec/__init__.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/__init__.py similarity index 100% rename from python/fate/components/spec/__init__.py rename to python/fate/ml/ensemble/learner/decision_tree/hetero/__init__.py diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py new file mode 100644 index 0000000000..bac4fffc2a --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/guest.py @@ -0,0 +1,402 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import ( + DecisionTree, + Node, + _update_sample_pos_on_local_nodes, + _merge_sample_pos, +) +from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder +from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter +from fate.ml.ensemble.learner.decision_tree.tree_core.loss import get_task_info +from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from typing import List +import functools +import logging +import pandas as pd +import torch as t +import numpy as np +import math + + +logger = logging.getLogger(__name__) + +FIX_POINT_PRECISION = 52 + + +class HeteroDecisionTreeGuest(DecisionTree): + def __init__( + self, + max_depth=3, + valid_features=None, + use_missing=False, + zero_as_missing=False, + goss=False, + l1=0.1, + l2=0, + min_impurity_split=1e-2, + min_sample_split=2, + min_leaf_node=1, + min_child_weight=1, + objective=None, + gh_pack=True, + split_info_pack=True, + hist_sub=True, + ): + super().__init__( + max_depth, use_missing=use_missing, zero_as_missing=zero_as_missing, valid_features=valid_features + ) + self.host_sitenames = None + self._tree_node_num = 0 + self.hist_builder = None + self.splitter = None + + # regularization + self.l1 = l1 + self.l2 = l2 + self.min_impurity_split = min_impurity_split + self.min_sample_split = min_sample_split + self.min_leaf_node = min_leaf_node + self.min_child_weight = min_child_weight + + # goss + self.goss = goss + + # other + self._valid_features = valid_features + self._hist_sub = hist_sub + + # homographic encryption + self._encrypt_kit = None + self._sk = None + self._pk = None + self._coder = None + self._evaluator = None + self._encryptor = None + self._decryptor = None + + # for g, h packing + self._en_key_length = None + self._gh_pack = gh_pack + self._split_info_pack = split_info_pack + self._g_offset = 0 + self._g_abs_max = 0 + self._h_abs_max = 0 + self._objective = objective + if gh_pack: + if objective is None: + raise ValueError("objective must be specified when gh_pack is True") + self._pack_info = {} + + def set_encrypt_kit(self, kit): + self._encrypt_kit = kit + self._en_key_length = kit.key_size + self._sk, self._pk, self._coder, self._evaluator, self._encryptor = ( + kit.sk, + kit.pk, + kit.coder, + kit.evaluator, + kit.get_tensor_encryptor(), + ) + self._decryptor = kit.get_tensor_decryptor() + logger.info("encrypt kit setup through setter") + + def _init_encrypt_kit(self, ctx: Context): + kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024}) + self._en_key_length = kit.key_size + self._sk, self._pk, self._coder, self._evaluator, self._encryptor = ( + kit.sk, + kit.pk, + kit.coder, + kit.evaluator, + kit.get_tensor_encryptor(), + ) + self._decryptor = kit.get_tensor_decryptor() + logger.info("encrypt kit is not setup, auto initializing") + + def _get_column_max_bin(self, result_dict): + bin_len = {} + + for column, values in result_dict.items(): + bin_num = len(values) + bin_len[column] = bin_num + + max_max_value = max(bin_len.values()) + + return bin_len, max_max_value + + def _update_sample_pos( + self, ctx: Context, cur_layer_nodes: List[Node], sample_pos: DataFrame, data: DataFrame, node_map: dict + ): + sitename = ctx.local.name + data_with_pos = DataFrame.hstack([data, sample_pos]) + map_func = functools.partial( + _update_sample_pos_on_local_nodes, 
cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename + ) + updated_sample_pos = data_with_pos.apply_row(map_func, columns=["g_on_local", "g_node_idx"]) + + # synchronize sample pos + host_update_sample_pos = ctx.hosts.get("updated_data") + + merge_func = functools.partial(_merge_sample_pos) + for host_data in host_update_sample_pos: + updated_sample_pos = DataFrame.hstack([updated_sample_pos, host_data]).apply_row( + merge_func, columns=["g_on_local", "g_node_idx"] + ) + + new_sample_pos = updated_sample_pos.create_frame(columns=["g_node_idx"]) + new_sample_pos.rename(columns={"g_node_idx": "node_idx"}) + ctx.hosts.put("new_sample_pos", new_sample_pos) + self.sample_pos = new_sample_pos + + return new_sample_pos + + def _g_h_process(self, grad_and_hess: DataFrame): + en_grad_hess = grad_and_hess.create_frame() + + def make_long_tensor(s: pd.Series, coder, pk, offset, shift_bit, precision, encryptor, pack_num=2): + pack_tensor = t.Tensor(s.values) + pack_tensor[0] = pack_tensor[0] + offset + pack_vec = coder.pack_floats(pack_tensor, shift_bit, pack_num, precision) + en = pk.encrypt_encoded(pack_vec, obfuscate=True) + ret = encryptor.lift(en, (len(en), 1), pack_tensor.dtype, pack_tensor.device) + return ret + + def compute_offset_bit(sample_num, g_max, h_max): + g_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * g_max) + 1) # add 1 more bit for safety + h_bit = int(math.log2(2**FIX_POINT_PRECISION * sample_num * h_max) + 1) + return max(g_bit, h_bit) + + if self._gh_pack: + task_type = get_task_info(self._objective) + + if task_type == BINARY or task_type == MULTI: + self._g_offset = 1 + self._g_abs_max = 2 + self._h_abs_max = 1 + + elif task_type == REGRESSION: + self._g_offset = abs(float(grad_and_hess["g"].min()["g"])) + self._g_abs_max = abs(float(grad_and_hess["g"].max()["g"])) + self._g_offset + self._h_abs_max = 2 + + pack_num = 2 + shift_bit = compute_offset_bit(len(grad_and_hess), self._g_abs_max, self._h_abs_max) + total_pack_num = (self._en_key_length - 2) // (shift_bit * pack_num) # -2 in case overflow + partial_func = functools.partial( + make_long_tensor, + coder=self._coder, + offset=self._g_offset, + pk=self._pk, + shift_bit=shift_bit, + pack_num=2, + precision=FIX_POINT_PRECISION, + encryptor=self._encryptor, + ) + en_grad_hess["gh"] = grad_and_hess.apply_row(partial_func) + + # record pack info + self._pack_info["g_offset"] = self._g_offset + self._pack_info["shift_bit"] = shift_bit + self._pack_info["precision"] = FIX_POINT_PRECISION + self._pack_info["pack_num"] = pack_num + self._pack_info["total_pack_num"] = total_pack_num + self._pack_info["split_point_shift_bit"] = shift_bit * pack_num + logger.info("gh are packed") + else: + en_grad_hess["g"] = self._encryptor.encrypt_tensor(grad_and_hess["g"].as_tensor()) + en_grad_hess["h"] = self._encryptor.encrypt_tensor(grad_and_hess["h"].as_tensor()) + logger.info("not using gh pack") + + return en_grad_hess + + def _send_gh(self, ctx: Context, grad_and_hess: DataFrame): + # encrypt g & h + en_grad_hess = self._g_h_process(grad_and_hess) + ctx.hosts.put("en_gh", en_grad_hess) + ctx.hosts.put("en_kit", [self._pk, self._evaluator]) + + def _mask_node(self, ctx: Context, nodes: List[Node]): + new_nodes = [] + for n in nodes: + new_nodes.append( + Node( + nid=n.nid, + is_leaf=n.is_leaf, + l=n.l, + r=n.r, + is_left_node=n.is_left_node, + split_id=n.split_id, + sitename=n.sitename, + sibling_nodeid=n.sibling_nodeid, + parent_nodeid=n.parent_nodeid, + sample_num=n.sample_num, + ) + ) + return new_nodes + + def 
_check_assign_result(self, sample_pos: DataFrame, cur_layer_node: List[Node]): + # debugging function + sample_pos_df = sample_pos.as_pd_df() + sample_pos_count = sample_pos_df.groupby("node_idx").count().to_dict()["sample_id"] + for node in cur_layer_node: + nid = node.nid + sample_count_0 = node.sample_num + sample_count_1 = sample_pos_count[nid] + if sample_count_0 != sample_count_1: + parent_nid = node.parent_nodeid + for i in self._nodes: + if i.nid == parent_nid: + logger.info("parent node {}".format(i)) + raise ValueError( + "node {} sample count not match, {} vs {}, node details {}".format( + nid, sample_count_0, sample_count_1, node + ) + ) + + def _sync_nodes(self, ctx: Context, cur_layer_nodes: List[Node], next_layer_nodes: List[Node]): + mask_cur_layer = self._mask_node(ctx, cur_layer_nodes) + mask_next_layer = self._mask_node(ctx, next_layer_nodes) + ctx.hosts.put("sync_nodes", [mask_cur_layer, mask_next_layer]) + + def booster_fit(self, ctx: Context, bin_train_data: DataFrame, grad_and_hess: DataFrame, binning_dict: dict): + logger.info + # Initialization + train_df = bin_train_data + sample_pos = self._init_sample_pos(train_df) + self._sample_on_leaves = sample_pos.empty_frame() + root_node = self._initialize_root_node(ctx, train_df, grad_and_hess) + + # initialize homographic encryption + if self._encrypt_kit is None: + self._init_encrypt_kit(ctx) + # Send Encrypted Grad and Hess + self._send_gh(ctx, grad_and_hess) + + # send pack info + send_pack_info = ( + { + "total_pack_num": self._pack_info["total_pack_num"], + "split_point_shift_bit": self._pack_info["split_point_shift_bit"], + "split_info_pack": self._split_info_pack, + } + if self._gh_pack + else {} + ) + ctx.hosts.put("pack_info", send_pack_info) + + # init histogram builder + self.hist_builder = SBTHistogramBuilder(bin_train_data, binning_dict, None, None, hist_sub=self._hist_sub) + + # init splitter + self.splitter = FedSBTSplitter( + bin_train_data, + binning_dict, + l2=self.l2, + l1=self.l1, + min_sample_split=self.min_sample_split, + min_impurity_split=self.min_impurity_split, + min_child_weight=self.min_child_weight, + min_leaf_node=self.min_leaf_node, + ) + + # Prepare for training + node_map = {} + cur_layer_node = [root_node] + grad_and_hess["cnt"] = 1 + + for cur_depth, sub_ctx in ctx.on_iterations.ctxs_range(self.max_depth): + if len(cur_layer_node) == 0: + logger.info("no nodes to split, stop training") + break + + assert len(sample_pos) == len(train_df), "sample pos len not match train data len, {} vs {}".format( + len(sample_pos), len(train_df) + ) + + # debug checking code + # self._check_assign_result(sample_pos, cur_layer_node) + # initialize node map + node_map = {n.nid: idx for idx, n in enumerate(cur_layer_node)} + # compute histogram + hist_inst, statistic_result = self.hist_builder.compute_hist( + sub_ctx, cur_layer_node, train_df, grad_and_hess, sample_pos, node_map + ) + # compute best splits + split_info = self.splitter.split( + sub_ctx, + statistic_result, + cur_layer_node, + node_map, + self._sk, + self._coder, + self._gh_pack, + self._pack_info, + ) + # update tree with best splits + next_layer_nodes = self._update_tree(sub_ctx, cur_layer_node, split_info, train_df) + # update feature importance + self._update_feature_importance(sub_ctx, split_info, train_df) + # sync nodes + self._sync_nodes(sub_ctx, cur_layer_node, next_layer_nodes) + # update sample positions + sample_pos = self._update_sample_pos(sub_ctx, cur_layer_node, sample_pos, train_df, node_map) + # if sample reaches leaf 
nodes, drop them + sample_on_leaves = self._get_samples_on_leaves(sample_pos) + train_df, sample_pos, grad_and_hess = self._drop_samples_on_leaves(sample_pos, train_df, grad_and_hess) + self._sample_on_leaves = DataFrame.vstack([self._sample_on_leaves, sample_on_leaves]) + # next layer nodes + cur_layer_node = next_layer_nodes + logger.info( + "layer {} done: next layer will split {} nodes, active samples num {}".format( + cur_depth, len(cur_layer_node), len(sample_pos) + ) + ) + self.next_layer_node = next_layer_nodes + + # handle final leaves + if len(cur_layer_node) != 0: + for node in cur_layer_node: + node.is_leaf = True + node.sitename = ctx.local.name # leaf always on guest + self._nodes.append(node) + self._sample_on_leaves = DataFrame.vstack([self._sample_on_leaves, sample_pos]) + + # when training is done, all samples must be on leaves + assert len(self._sample_on_leaves) == len(bin_train_data), "sample on leaves num not match, {} vs {}".format( + len(self._sample_on_leaves), len(bin_train_data) + ) + # convert sample pos to weights + self._sample_weights = self._convert_sample_pos_to_weight(self._sample_on_leaves, self._nodes) + # convert bid to split value + self._nodes = self._convert_bin_idx_to_split_val(ctx, self._nodes, binning_dict, bin_train_data.schema) + + def get_hyper_param(self): + param = { + "max_depth": self.max_depth, + "valid_features": self._valid_features, + "l1": self.l1, + "l2": self.l2, + "use_missing": self.use_missing, + "zero_as_missing": self.zero_as_missing, + "objective": self._objective, + } + return param + + @staticmethod + def from_model(model_dict): + return HeteroDecisionTreeGuest._from_model(model_dict, HeteroDecisionTreeGuest) diff --git a/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py b/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py new file mode 100644 index 0000000000..cbc64bc402 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/hetero/host.py @@ -0,0 +1,227 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
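Illustrative note (not part of the patch): in the guest booster_fit above, sample positions use a signed encoding; a sample that has reached a leaf stores -(leaf_nid + 1) (the +1 avoids colliding with root node id 0), so active samples are exactly those with node_idx >= 0. See _update_sample_pos and _get_samples_on_leaves in tree_core/decision_tree.py further below. Recovering the leaf id is then just:

    def leaf_nid_from_pos(node_idx: int) -> int:
        # assumes node_idx < 0, i.e. the sample already sits on a leaf
        return -(node_idx + 1)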
+from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import ( + DecisionTree, + Node, + _update_sample_pos_on_local_nodes, + FeatureImportance, +) +from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogramBuilder, DistributedHistogram +from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from typing import List +import functools +import logging + + +logger = logging.getLogger(__name__) + + +class HeteroDecisionTreeHost(DecisionTree): + def __init__( + self, + max_depth=3, + valid_features=None, + use_missing=False, + zero_as_missing=False, + random_seed=None, + global_random_seed=None, + hist_sub=True, + ): + super().__init__( + max_depth, use_missing=use_missing, zero_as_missing=zero_as_missing, valid_features=valid_features + ) + self._tree_node_num = 0 + self.hist_builder = None + self.splitter = None + self._valid_features = valid_features + self._random_seed = random_seed + self._global_random_seed = global_random_seed + self._pk = None + self._evaluator = None + self._gh_pack = True + self._pack_info = None + self._hist_sub = hist_sub + + def _convert_split_id( + self, + ctx: Context, + cur_layer_nodes: List[Node], + node_map: dict, + hist_builder: SBTHistogramBuilder, + statistic_histogram: DistributedHistogram, + splitter: FedSBTSplitter, + data: DataFrame, + ): + sitename = ctx.local.party[0] + "_" + ctx.local.party[1] + to_recover = {} + for idx, n in enumerate(cur_layer_nodes): + if (not n.is_leaf) and n.sitename == sitename: + node_id = n.nid + split_id = n.split_id + to_recover[node_id] = split_id + + if len(to_recover) != 0: + if self._random_seed is None: + for node_id, split_id in to_recover.items(): + node = cur_layer_nodes[node_map[node_id]] + fid, bid = splitter.get_bucket(split_id) + node.fid = self._fid_to_feature_name(int(fid), data) + node.bid = int(bid) + else: + recover_rs = hist_builder.recover_feature_bins(statistic_histogram, to_recover, node_map) + for node_id, split_tuple in recover_rs.items(): + node = cur_layer_nodes[node_map[node_id]] + fid, bid = split_tuple + node.fid = self._fid_to_feature_name(int(fid), data) + node.bid = int(bid) + + def _update_host_feature_importance(self, ctx: Context, nodes: List[Node], train_df: DataFrame): + sitename = ctx.local.party[0] + "_" + ctx.local.party[1] + for n in nodes: + if sitename == n.sitename: + feat_name = n.fid + if feat_name not in self._feature_importance: + self._feature_importance[feat_name] = FeatureImportance() + else: + self._feature_importance[feat_name] = self._feature_importance[feat_name] + FeatureImportance() + + def _update_sample_pos( + self, ctx, cur_layer_nodes: List[Node], sample_pos: DataFrame, data: DataFrame, node_map: dict + ): + sitename = ctx.local.party[0] + "_" + ctx.local.party[1] + data_with_pos = DataFrame.hstack([data, sample_pos]) + map_func = functools.partial( + _update_sample_pos_on_local_nodes, cur_layer_node=cur_layer_nodes, node_map=node_map, sitename=sitename + ) + update_sample_pos = data_with_pos.apply_row(map_func, columns=["h_on_local", "h_node_idx"]) + + ctx.guest.put("updated_data", update_sample_pos) + new_sample_pos = ctx.guest.get("new_sample_pos") + + return new_sample_pos + + def _get_gh(self, ctx: Context): + grad_and_hess: DataFrame = ctx.guest.get("en_gh") + if len(grad_and_hess.columns) == 1: + gh_pack = True + elif len(grad_and_hess.columns) == 2: + gh_pack = False + else: + raise ValueError("error columns, got 
{}".format(len(grad_and_hess.columns))) + return grad_and_hess, gh_pack + + def _sync_nodes(self, ctx: Context): + nodes = ctx.guest.get("sync_nodes") + cur_layer_nodes, next_layer_nodes = nodes + return cur_layer_nodes, next_layer_nodes + + def booster_fit(self, ctx: Context, bin_train_data: DataFrame, binning_dict: dict): + train_df = bin_train_data + feat_max_bin, max_bin = self._get_column_max_bin(binning_dict) + sample_pos = self._init_sample_pos(train_df) + + # Get Encrypted Grad And Hess + ret = self._get_gh(ctx) + en_grad_and_hess: DataFrame = ret[0] + self._gh_pack = ret[1] + self._pk, self._evaluator = ctx.guest.get("en_kit") + self._pack_info = ctx.guest.get("pack_info") + split_info_pack = self._pack_info.get("split_info_pack", False) + root_node = self._initialize_root_node(ctx, train_df) + + # init histogram builder + self.hist_builder = SBTHistogramBuilder( + bin_train_data, + binning_dict, + random_seed=self._random_seed, + global_random_seed=self._global_random_seed, + hist_sub=self._hist_sub, + ) + # splitter + self.splitter = FedSBTSplitter(bin_train_data, binning_dict) + + node_map = {} + cur_layer_node = [root_node] + en_grad_and_hess["cnt"] = 1 + for cur_depth, sub_ctx in ctx.on_iterations.ctxs_range(self.max_depth): + if len(cur_layer_node) == 0: + logger.info("no nodes to split, stop training") + break + + node_map = {n.nid: idx for idx, n in enumerate(cur_layer_node)} + # compute histogram with encrypted grad and hess + hist_inst, statistic_histogram = self.hist_builder.compute_hist( + sub_ctx, + cur_layer_node, + train_df, + en_grad_and_hess, + sample_pos, + node_map, + pk=self._pk, + evaluator=self._evaluator, + gh_pack=self._gh_pack, + ) + + if split_info_pack: + logger.debug("packing split info") + statistic_histogram.i_squeeze( + {"gh": (self._pack_info["total_pack_num"], self._pack_info["split_point_shift_bit"])} + ) + + self.splitter.split(sub_ctx, statistic_histogram, cur_layer_node, node_map) + cur_layer_node, next_layer_nodes = self._sync_nodes(sub_ctx) + self._convert_split_id( + sub_ctx, cur_layer_node, node_map, self.hist_builder, statistic_histogram, self.splitter, train_df + ) + self._update_host_feature_importance(sub_ctx, cur_layer_node, train_df) + logger.info( + "cur layer node num: {}, next layer node num: {}".format(len(cur_layer_node), len(next_layer_nodes)) + ) + sample_pos = self._update_sample_pos(sub_ctx, cur_layer_node, sample_pos, train_df, node_map) + train_df, sample_pos, en_grad_and_hess = self._drop_samples_on_leaves( + sample_pos, train_df, en_grad_and_hess + ) + self._nodes += cur_layer_node + cur_layer_node = next_layer_nodes + logger.info( + "layer {} done: next layer will split {} nodes, active samples num {}".format( + cur_depth, len(cur_layer_node), len(sample_pos) + ) + ) + + # sync complete tree + if len(cur_layer_node) != 0: + for node in cur_layer_node: + node.is_leaf = True + node.sitename = ctx.guest.party[0] + "_" + ctx.guest.party[1] + self._nodes.append(node) + + # convert bid to split value + self._nodes = self._convert_bin_idx_to_split_val(ctx, self._nodes, binning_dict, bin_train_data.schema) + + def get_hyper_param(self): + param = { + "max_depth": self.max_depth, + "valid_features": self._valid_features, + "use_missing": self.use_missing, + "zero_as_missing": self.zero_as_missing, + } + return param + + @staticmethod + def from_model(model_dict): + return HeteroDecisionTreeHost._from_model(model_dict, HeteroDecisionTreeHost) diff --git a/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py 
b/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py new file mode 100644 index 0000000000..32f5bacde3 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py @@ -0,0 +1,101 @@ +import pandas as pd +from fate.ml.ensemble.learner.decision_tree.hetero.guest import HeteroDecisionTreeGuest +from fate.ml.ensemble.learner.decision_tree.hetero.host import HeteroDecisionTreeHost +from fate.arch.dataframe import PandasReader, DataFrame +import numpy as np +from fate.ml.ensemble.learner.decision_tree.tree_core.loss import BCELoss +from fate.arch import Context +import sys + + +arbiter = ("arbiter", "10000") +guest = ("guest", "10000") +host = ("host", "9999") +name = "fed55" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]) + ) + + +if __name__ == "__main__": + party = sys.argv[1] + max_depth = 5 + if party == "guest": + ctx = create_ctx(guest) + + df = pd.read_csv("./../../../../../../../examples/data/breast_hetero_guest.csv") + df["sample_id"] = [i for i in range(len(df))] + + reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="float32") + + data_guest = reader.to_frame(ctx, df) + + from fate.ml.ensemble.utils.binning import binning + + bin_info = binning(data_guest, max_bin=32) + bin_data = data_guest.bucketize(boundaries=bin_info) + + loss_bce = BCELoss() + label = data_guest.label + init_score = loss_bce.initialize(label) + predict = loss_bce.predict(init_score) + empty_gh = data_guest.create_frame() + loss_bce.compute_grad(empty_gh, label, predict) + loss_bce.compute_hess(empty_gh, label, predict) + + kit = ctx.cipher.phe.setup(options={"kind": "paillier", "key_length": 1024}) + sk, pk, coder, evaluator, encryptor = kit.sk, kit.pk, kit.coder, kit.evaluator, kit.get_tensor_encryptor() + # from fate.ml.ensemble.learner.decision_tree.tree_core.hist import SBTHistogram + # from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import Node + # from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import FedSBTSplitter + # sample_pos = bin_data.create_frame() + # sample_pos['sample_pos'] = bin_data.apply_row(lambda x: 0) + + # hist = SBTHistogram(bin_data, bin_info) + # g_sum, h_sum, cnt_sum = float(empty_gh['g'].sum()), float(empty_gh['h'].sum()), len(empty_gh) + # root_node = [Node(nid=0, grad=g_sum, hess=h_sum, sample_num=cnt_sum)] + # node_map = {0: 0} + # stat_obj = hist.compute_hist(root_node, bin_data, empty_gh, sample_pos, node_map) + # computed_hist = stat_obj.decrypt({}, {}) + # splitter = FedSBTSplitter(data_guest, bin_info) + # rs = splitter._find_guest_best_splits(computed_hist, '123', root_node, node_map) + tree = HeteroDecisionTreeGuest(max_depth, objective="binary:bce", gh_pack=True) + tree.set_encrypt_kit(kit) + ret = tree.booster_fit(ctx, bin_data, empty_gh, bin_info) + + elif party == "host": + ctx = create_ctx(host) + + df_host = 
pd.read_csv("./../../../../../../../examples/data/breast_hetero_host.csv") + df_host["sample_id"] = [i for i in range(len(df_host))] + + reader_host = PandasReader(sample_id_name="sample_id", match_id_name="id", dtype="float32") + + data_host = reader_host.to_frame(ctx, df_host) + + from fate.ml.ensemble.utils.binning import binning + + bin_info = binning(data_host, max_bin=32) + bin_data = data_host.bucketize(boundaries=bin_info) + + tree = HeteroDecisionTreeHost(max_depth, random_seed=114) + ret = tree.booster_fit(ctx, bin_data, bin_info) diff --git a/python/fate/ml/lr/__init__.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/__init__.py similarity index 100% rename from python/fate/ml/lr/__init__.py rename to python/fate/ml/ensemble/learner/decision_tree/tree_core/__init__.py diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py new file mode 100644 index 0000000000..57263525dd --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/decision_tree.py @@ -0,0 +1,514 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# ============================================================================= +# DecisionTree Base Class +# ============================================================================= +import numpy as np +import pandas as pd +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.ml.ensemble.learner.decision_tree.tree_core.splitter import SplitInfo +from typing import List +import logging + + +FLOAT_ZERO = 1e-8 +LEAF_IDX = -1 + + +logger = logging.getLogger(__name__) + + +class FeatureImportance(object): + def __init__(self, gain=0): + self.gain = gain + self.split = 1 + + def add_gain(self, val): + self.gain += val + + def add_split(self, val): + self.split += val + + def __repr__(self): + return "gain: {}, split {}".format(self.gain, self.split) + + def __add__(self, other): + new_importance = FeatureImportance(gain=self.gain + other.gain) + new_importance.split = self.split + other.split + return new_importance + + def to_dict(self): + return {"gain": self.gain, "split": self.split} + + @staticmethod + def from_dict(dict_): + importance = FeatureImportance() + importance.gain = dict_["gain"] + importance.split = dict_["split"] + return importance + + +class Node(object): + + """ + Parameters: + ----------- + nid : int, optional + ID of the node. + sitename : str, optional + Name of the site that the node belongs to. + fid : int, optional + ID of the feature that the node splits on. + bid : float or int, optional + Feature value that the node splits on. + weight : float, optional + Weight of the node. + is_leaf : bool, optional + Boolean indicating whether the node is a leaf node. + grad : float, optional + Gradient value of the node. + hess : float, optional + Hessian value of the node. + l : int, optional + ID of the left child node. 
+ r : int, optional + ID of the right child node. + missing_dir : int, optional + Direction for missing values (1 for left, -1 for right). + sample_num : int, optional + Number of samples in the node. + is_left_node : bool, optional + Boolean indicating whether the node is a left child node. + sibling_nodeid : int, optional + ID of the sibling node. + """ + + def __init__( + self, + nid=None, + sitename=None, + fid=None, + bid=None, + weight=0, + is_leaf=False, + grad=None, + hess=None, + l=-1, + r=-1, + missing_dir=1, + sample_num=0, + is_left_node=False, + sibling_nodeid=None, + parent_nodeid=None, + split_id=None, + ): + self.nid = nid + self.sitename = sitename + self.fid = fid + self.bid = bid + self.weight = weight + self.is_leaf = is_leaf + self.grad = grad + self.hess = hess + self.l = l + self.r = r + self.missing_dir = missing_dir + self.sample_num = sample_num + self.is_left_node = is_left_node + self.sibling_nodeid = sibling_nodeid + self.parent_nodeid = parent_nodeid + self.split_id = split_id + + def to_dict(self): + return { + "nid": self.nid, + "sitename": self.sitename, + "fid": self.fid, + "bid": self.bid, + "weight": self.weight, + "is_leaf": self.is_leaf, + "grad": self.grad, + "hess": self.hess, + "l": self.l, + "r": self.r, + "missing_dir": self.missing_dir, + "sample_num": self.sample_num, + "is_left_node": self.is_left_node, + "sibling_nodeid": self.sibling_nodeid, + "parent_nodeid": self.parent_nodeid, + "split_id": self.split_id, + } + + def __repr__(self): + """ + Returns a string representation of the node. + """ + return "(node_id {}: fid {} bid {} left {}, right {}, pid {}, is_leaf {}, sample_count {}, g {}, h {}, weight {}, sitename {})".format( + self.nid, + self.fid, + self.bid, + self.l, + self.r, + self.parent_nodeid, + self.is_leaf, + self.sample_num, + self.grad, + self.hess, + self.weight, + self.sitename, + ) + + +def _make_decision(feat_val, bid, missing_dir=None, use_missing=None, zero_as_missing=None, zero_val=0): + # no missing val + left, right = True, False + direction = left if feat_val <= bid + FLOAT_ZERO else right + return direction + + +def _update_sample_pos(s: pd.Series, cur_layer_node: List[Node], node_map: dict, sitename=None): + node_id = s.iloc[-1] + node = cur_layer_node[node_map[node_id]] + if node.is_leaf: + return -(node.nid + 1) # use negative index to represent leaves, + 1 to avoid root node 0 + feat_val = s[node.fid] + bid = node.bid + dir_ = _make_decision(feat_val, bid) + if dir_: # go left + ret_node = node.l + else: + ret_node = node.r + return ret_node + + +def _get_sample_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_map: dict, sitename): + node_id = s.iloc[-1] + node = cur_layer_node[node_map[node_id]] + on_local_node = node.sitename == sitename + return on_local_node + + +def _update_sample_pos_on_local_nodes(s: pd.Series, cur_layer_node: List[Node], node_map: dict, sitename): + on_local_node = _get_sample_on_local_nodes(s, cur_layer_node, node_map, sitename) + if not on_local_node: + return False, -1 + else: + return True, _update_sample_pos(s, cur_layer_node, node_map, sitename) + + +def _merge_sample_pos(s: pd.Series): + if s["g_on_local"]: + return s["g_on_local"], s["g_node_idx"] + else: + return s["h_on_local"], s["h_node_idx"] + + +def _convert_sample_pos_to_score(s: pd.Series, tree_nodes: List[Node]): + node_idx = s.iloc[0] + if node_idx < 0: + node_idx = -(node_idx + 1) + target_node = tree_nodes[node_idx] + if not target_node.is_leaf: + raise ValueError("this sample is not on a leaf node") + 
return target_node.weight + + +class DecisionTree(object): + def __init__(self, max_depth=3, use_missing=False, zero_as_missing=False, valid_features=None): + """ + Initialize a DecisionTree instance. + + Parameters: + ----------- + max_depth : int + The maximum depth of the tree. + use_missing : bool, optional + Whether or not to use missing values (default is False). + zero_as_missing : bool, optional + Whether to treat zero as a missing value (default is False). + valid_features: list of boolean, optional + Valid features for training, default is None, which means all features are valid. + """ + self.max_depth = max_depth + self.use_missing = use_missing + self.zero_as_missing = zero_as_missing + + # runtime variables + self._nodes = [] + self._cur_layer_node = [] + self._cur_leaf_idx = -1 + self._feature_importance = {} + self._predict_weights = None + self._g_tensor, self._h_tensor = None, None + self._sample_pos = None + self._leaf_node_map = {} + self._valid_feature = valid_features + self._sample_on_leaves = None + self._sample_weights = None + + def _init_sample_pos(self, train_data: DataFrame): + sample_pos = train_data.create_frame() + sample_pos["node_idx"] = 0 # position of current sample + return sample_pos + + def _init_leaves_sample_table(self, sample_pos: DataFrame): + return sample_pos.empty_frame() + + def _get_leaf_node_map(self): + if len(self._nodes) >= len(self._leaf_node_map): + for n in self._nodes: + self._leaf_node_map[n.nid] = n.is_leaf + + def _convert_sample_pos_to_weight(self, sample_pos: DataFrame, tree_nodes: List[Node]): + import functools + + map_func = functools.partial(_convert_sample_pos_to_score, tree_nodes=tree_nodes) + sample_weight = sample_pos.apply_row(map_func, columns=["score"]) + return sample_weight + + def _convert_bin_idx_to_split_val(self, ctx: Context, tree_nodes: List[Node], binning_dict: dict, schema): + columns = schema.columns + sitename = ctx.local.name + for node in tree_nodes: + if node.sitename == sitename: + if not node.is_leaf: + feat_name = node.fid + split_val = binning_dict[feat_name][node.bid] + node.bid = split_val + else: + continue + + return tree_nodes + + def _initialize_root_node(self, ctx: Context, train_df: DataFrame, gh: DataFrame = None): + sitename = ctx.local.name + if gh is None: + sum_g, sum_h = 0, 0 + else: + sum_gh = gh.sum() + sum_g = float(sum_gh["g"]) + sum_h = float(sum_gh["h"]) + root_node = Node(nid=0, grad=sum_g, hess=sum_h, sitename=sitename, sample_num=len(train_df)) + + return root_node + + def _update_feature_importance(self, ctx: Context, split_info: List[SplitInfo], data: DataFrame): + sitename = ctx.local.name + for info in split_info: + if info is not None and info.sitename == sitename: + feat_name = self._fid_to_feature_name(info.best_fid, data) + if feat_name not in self._feature_importance: + self._feature_importance[feat_name] = FeatureImportance(gain=info.gain) + else: + self._feature_importance[feat_name] = self._feature_importance[feat_name] + FeatureImportance( + gain=info.gain + ) + + def _fid_to_feature_name(self, fid: int, dataframe: DataFrame): + if fid is None: + return None + return dataframe.schema.columns[fid] + + def _update_tree(self, ctx: Context, cur_layer_nodes: List[Node], split_info: List[SplitInfo], data: DataFrame): + assert len(cur_layer_nodes) == len( + split_info + ), "node num not match split info num, got {} node vs {} split info".format( + len(cur_layer_nodes), len(split_info) + ) + + next_layer_node = [] + + for idx in range(len(split_info)): + node: Node = 
cur_layer_nodes[idx] + + if split_info[idx] is None: + node.is_leaf = True + node.sitename = ctx.guest.name # leaf always belongs to guest + self._nodes.append(node) + logger.info("set node {} to leaf".format(node)) + continue + + sum_grad = node.grad + sum_hess = node.hess + sum_cnt = node.sample_num + + feat_name = self._fid_to_feature_name(split_info[idx].best_fid, data) + node.fid = feat_name + node.bid = split_info[idx].best_bid + node.missing_dir = split_info[idx].missing_dir + node.sitename = split_info[idx].sitename + node.split_id = split_info[idx].split_id # if not a local node, has split id + + p_id = node.nid + l_id, r_id = self._tree_node_num + 1, self._tree_node_num + 2 + self._tree_node_num += 2 + node.l, node.r = l_id, r_id + + l_g, l_h = split_info[idx].sum_grad, split_info[idx].sum_hess + l_cnt = split_info[idx].sample_count + + # logger.info("splitting node {}, split info is {}".format(node, split_info[idx])) + + # create new left node and new right node + left_node = Node( + nid=l_id, + grad=float(l_g), + hess=float(l_h), + weight=float(self.splitter.node_weight(l_g, l_h)), + parent_nodeid=p_id, + sibling_nodeid=r_id, + is_left_node=True, + sample_num=l_cnt, + ) + + # not gonna happen + assert sum_cnt > l_cnt, "sum cnt {} not greater than l cnt {}".format(sum_cnt, l_cnt) + + r_g = float(sum_grad - l_g) + r_h = float(sum_hess - l_h) + r_cnt = sum_cnt - l_cnt + + right_node = Node( + nid=r_id, + grad=r_g, + hess=r_h, + weight=float(self.splitter.node_weight(sum_grad - l_g, sum_hess - l_h)), + parent_nodeid=p_id, + sibling_nodeid=l_id, + sample_num=r_cnt, + is_left_node=False, + ) + + next_layer_node.append(left_node) + next_layer_node.append(right_node) + self._nodes.append(node) + + return next_layer_node + + def _drop_samples_on_leaves(self, new_sample_pos: DataFrame, data: DataFrame, grad_and_hess: DataFrame): + assert len(new_sample_pos) == len( + data + ), "sample pos num not match data num, got {} sample pos vs {} data".format(len(new_sample_pos), len(data)) + x = new_sample_pos >= 0 + pack_data = DataFrame.hstack([data, new_sample_pos, grad_and_hess]).iloc(x) + new_data = pack_data.create_frame(columns=data.schema.columns) + update_pos = pack_data.create_frame(columns=new_sample_pos.schema.columns) + grad_and_hess = pack_data.create_frame(columns=grad_and_hess.schema.columns) + """ + new_data = data.iloc(x) + update_pos = new_sample_pos.iloc(x) + grad_and_hess = grad_and_hess.iloc(x) + """ + logger.info( + "drop leaf samples, new sample count is {}, {} samples dropped".format( + len(new_sample_pos), len(data) - len(new_data) + ) + ) + return new_data, update_pos, grad_and_hess + + def _get_samples_on_leaves(self, sample_pos: DataFrame): + x = sample_pos < 0 + samples_on_leaves = sample_pos.iloc(x) + return samples_on_leaves + + def _get_column_max_bin(self, result_dict): + bin_len = {} + for column, values in result_dict.items(): + bin_num = len(values) + bin_len[column] = bin_num + max_max_value = max(bin_len.values()) + return bin_len, max_max_value + + def fit(self, ctx: Context, train_data: DataFrame, grad_and_hess: DataFrame, encryptor): + raise NotImplementedError("This method should be implemented by subclass") + + def get_feature_importance(self): + return self._feature_importance + + def get_sample_predict_weights(self): + return self._sample_weights + + def get_nodes(self): + return self._nodes + + def print_tree(self): + from anytree import Node as AnyNode, RenderTree + + nodes = self._nodes + anytree_nodes = {} + for node in nodes: + if not 
node.is_leaf: + anytree_nodes[node.nid] = AnyNode( + name=f"{node.nid}: fid {node.fid}, bid {node.bid}, sample num {node.sample_num}, on {node.sitename}" + ) + else: + anytree_nodes[node.nid] = AnyNode( + name=f"{node.nid}: weight {node.weight}, sample num {node.sample_num}, leaf" + ) + for node in nodes: + if node.l != -1: + anytree_nodes[node.l].parent = anytree_nodes[node.nid] + if node.r != -1: + anytree_nodes[node.r].parent = anytree_nodes[node.nid] + + for pre, _, node in RenderTree(anytree_nodes[0]): + print("%s%s" % (pre, node.name)) + + @staticmethod + def _recover_nodes(model_dict): + nodes = [] + for node_dict in model_dict["nodes"]: + node = Node(**node_dict) + nodes.append(node) + return nodes + + @staticmethod + def _recover_feature_importance(model_dict): + feature_importance = {} + for k, v in model_dict["feature_importance"].items(): + feature_importance[k] = FeatureImportance.from_dict(v) + return feature_importance + + @staticmethod + def _from_model(model_dict, tree_class): + nodes = DecisionTree._recover_nodes(model_dict) + feature_importance = DecisionTree._recover_feature_importance(model_dict) + param = model_dict["hyper_param"] + tree = tree_class(**param) + tree._nodes = nodes + tree._feature_importance = feature_importance + return tree + + def get_hyper_param(self): + param = {"max_depth": self.max_depth, "use_missing": self.use_missing, "zero_as_missing": self.zero_as_missing} + return param + + @staticmethod + def from_model(model_dict): + return DecisionTree._from_model(model_dict, DecisionTree) + + def get_model(self): + model_dict = {} + nodes = [n.to_dict() for n in self._nodes] + feat_importance = {k: v.to_dict() for k, v in self._feature_importance.items()} + param = self.get_hyper_param() + model_dict["nodes"] = nodes + model_dict["feature_importance"] = feat_importance + model_dict["hyper_param"] = param + + return model_dict diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py new file mode 100644 index 0000000000..c4df2a6170 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/hist.py @@ -0,0 +1,231 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
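Illustrative sketch (not part of the patch): the histogram builders added below accumulate, for every node and feature, per-bin sums of g, h and sample count and then take a cumulative sum over bins, so a candidate split at bin b can be scored directly from the prefix sums. A minimal single-node, single-feature NumPy version of that idea:

    import numpy as np

    def bin_histogram(bin_ids: np.ndarray, g: np.ndarray, h: np.ndarray, num_bins: int):
        # per-bin sums followed by a cumulative sum: entry b holds the statistics
        # of all samples whose bin index is <= b (the would-be left child)
        g_hist = np.bincount(bin_ids, weights=g, minlength=num_bins).cumsum()
        h_hist = np.bincount(bin_ids, weights=h, minlength=num_bins).cumsum()
        cnt_hist = np.bincount(bin_ids, minlength=num_bins).cumsum()
        return g_hist, h_hist, cnt_hist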
+import typing + +import torch +from typing import Dict +from fate.arch.histogram import HistogramBuilder, DistributedHistogram +from fate.ml.ensemble.learner.decision_tree.tree_core.decision_tree import Node +from typing import List +import numpy as np +from fate.arch.dataframe import DataFrame +from fate.arch import Context +import logging + + +logger = logging.getLogger(__name__) + + +HIST_TYPE = ["distributed", "sklearn"] + + +class SklearnHistBuilder(object): + def __init__(self, bin_data, bin_num, g, h) -> None: + from sklearn.ensemble._hist_gradient_boosting.grower import HistogramBuilder + + try: + hist_builder = HistogramBuilder(bin_data, bin_num, g, h, False) + except TypeError as e: + from sklearn.utils._openmp_helpers import _openmp_effective_n_threads + + n_threads = _openmp_effective_n_threads(None) + hist_builder = HistogramBuilder(bin_data, bin_num, g, h, False, n_threads) + + self.hist_builder = hist_builder + + def compute_hist( + self, nodes: List[Node], bin_train_data=None, gh=None, sample_pos: DataFrame = None, node_map={}, debug=False + ): + grouped = sample_pos.as_pd_df().groupby("node_idx")["sample_id"].apply(np.array).apply(np.uint32) + data_indices = [None for i in range(len(nodes))] + inverse_node_map = {v: k for k, v in node_map.items()} + for idx, node in enumerate(nodes): + data_indices[idx] = grouped[inverse_node_map[idx]] + + hists = [] + idx = 0 + for node in nodes: + hist = self.hist_builder.compute_histograms_brute(data_indices[idx]) + hist_arr = np.array(hist) + g = hist_arr["sum_gradients"].cumsum(axis=1) + h = hist_arr["sum_hessians"].cumsum(axis=1) + count = hist_arr["count"].cumsum(axis=1) + hists.append([g, h, count]) + idx += 1 + + if debug: + return hists, data_indices + else: + return hists + + +class SBTHistogramBuilder(object): + def __init__( + self, bin_train_data: DataFrame, bin_info: dict, random_seed=None, global_random_seed=None, hist_sub=True + ) -> None: + columns = bin_train_data.schema.columns + self.random_seed = random_seed + self.global_random_seed = global_random_seed + self.feat_bin_num = [len(bin_info[feat]) for feat in columns] + self._cache_parent_hist: typing.Optional[DistributedHistogram] = None + self._last_layer_node_map = None + self._hist_sub = hist_sub + + def _get_plain_text_schema(self, dtypes): + return { + "g": {"type": "plaintext", "stride": 1, "dtype": dtypes["g"]}, + "h": {"type": "plaintext", "stride": 1, "dtype": dtypes["h"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, + } + + def _get_enc_hist_schema(self, pk, evaluator, dtypes): + return { + "g": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["g"]}, + "h": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["h"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, + } + + def _get_pack_en_hist_schema(self, pk, evaluator, dtypes): + return { + "gh": {"type": "ciphertext", "stride": 1, "pk": pk, "evaluator": evaluator, "dtype": dtypes["gh"]}, + "cnt": {"type": "plaintext", "stride": 1, "dtype": dtypes["cnt"]}, + } + + def _prepare_hist_sub(self, nodes: List[Node], cur_layer_node_map: dict, parent_node_map: dict): + weak_nodes_ids = [] + mapping = [] + n_map = {n.nid: n for n in nodes} + new_node_map = {} + hist_pos = 0 + for n in nodes: + if n.nid == 0: + # root node + weak_nodes_ids.append(0) + # root node, just return + return set(weak_nodes_ids), None, mapping + + if n.is_left_node: + sib = n_map[n.sibling_nodeid] + + if sib.sample_num < 
n.sample_num: + weak_node = sib + else: + weak_node = n + + mapping_list = [] + parent_nid = weak_node.parent_nodeid + weak_nodes_ids.append(weak_node.nid) + mapping_list = ( + parent_node_map[parent_nid], + hist_pos, + cur_layer_node_map[weak_node.nid], + cur_layer_node_map[weak_node.sibling_nodeid], + ) + mapping.append(mapping_list) + new_node_map[weak_node.nid] = hist_pos + hist_pos += 1 + + else: + continue + return set(weak_nodes_ids), new_node_map, mapping + + def _get_samples_on_weak_nodes(self, sample_pos: DataFrame, weak_nodes: set): + # root node + if 0 in weak_nodes: + return sample_pos + is_on_weak = sample_pos.apply_row(lambda s: s["node_idx"] in weak_nodes) + weak_sample_pos = sample_pos.iloc(is_on_weak) + return weak_sample_pos + + def _is_first_layer(self, nodes): + if len(nodes) == 1 and nodes[0].nid == 0: + return True + else: + return False + + def compute_hist( + self, + ctx: Context, + nodes: List[Node], + bin_train_data: DataFrame, + gh: DataFrame, + sample_pos: DataFrame = None, + node_map={}, + pk=None, + evaluator=None, + gh_pack=False, + ): + node_num = len(nodes) + is_first_layer = self._is_first_layer(nodes) + need_hist_sub_process = (not is_first_layer) and self._hist_sub + + weak_nodes, new_node_map, mapping = None, None, None + if need_hist_sub_process: + weak_nodes, new_node_map, mapping = self._prepare_hist_sub(nodes, node_map, self._last_layer_node_map) + node_num = len(weak_nodes) + logger.debug("weak nodes {}, new_node_map {}, mapping {}".format(weak_nodes, new_node_map, mapping)) + + if ctx.is_on_guest: + schema = self._get_plain_text_schema(gh.dtypes) + elif ctx.is_on_host: + if pk is None or evaluator is None: + schema = self._get_plain_text_schema(gh.dtypes) + else: + if gh_pack: + schema = self._get_pack_en_hist_schema(pk, evaluator, gh.dtypes) + else: + schema = self._get_enc_hist_schema(pk, evaluator, gh.dtypes) + else: + raise ValueError("not support called on role: {}".format(ctx.local)) + + if need_hist_sub_process: + node_mapping = {node_map[k]: v for k, v in new_node_map.items()} + else: + node_mapping = None + + hist = HistogramBuilder( + num_node=node_num, + feature_bin_sizes=self.feat_bin_num, + value_schemas=schema, + global_seed=self.global_random_seed, + seed=self.random_seed, + node_mapping=node_mapping, + ) + + map_sample_pos = sample_pos.apply_row(lambda x: node_map[x["node_idx"]]) + stat_obj = bin_train_data.distributed_hist_stat(hist, map_sample_pos, gh) + + if need_hist_sub_process: + stat_obj = self._cache_parent_hist.compute_child(stat_obj, mapping) + + if self._hist_sub: + self._cache_parent_hist = stat_obj + self._last_layer_node_map = node_map + + stat_obj = stat_obj.shuffle_splits() + + return hist, stat_obj + + def recover_feature_bins( + self, statistic_histogram: DistributedHistogram, nid_split_id: Dict[int, int], node_map: dict + ) -> Dict[int, int]: + if self.random_seed is None: + return nid_split_id # randome seed has no shuffle, no need to recover + else: + reverse_node_map = {v: k for k, v in node_map.items()} + nid_split_id_ = {node_map[k]: v for k, v in nid_split_id.items()} + recover = statistic_histogram.recover_feature_bins(self.feat_bin_num, nid_split_id_) + recover_rs = {reverse_node_map[k]: v for k, v in recover.items()} + return recover_rs diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/loss.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/loss.py new file mode 100644 index 0000000000..fc5c59aaf5 --- /dev/null +++ 
b/python/fate/ml/ensemble/learner/decision_tree/tree_core/loss.py @@ -0,0 +1,138 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pandas as pd +import torch as t +from fate.arch.dataframe import DataFrame +from scipy.special import expit as sigmoid +from fate.ml.utils.predict_tools import BINARY, MULTI, REGRESSION + + +BINARY_BCE = "binary:bce" +MULTI_CE = "multi:ce" +REGRESSION_L2 = "regression:l2" + + +def apply_weight(loss: DataFrame, weight: DataFrame): + return loss["loss"] * weight["weight"] + + +class BCELoss(object): + @staticmethod + def initialize(label: DataFrame): + init_score = label.create_frame() + init_score["score"] = 0.0 + return init_score + + @staticmethod + def predict(score: DataFrame): + pred_rs = score.create_frame() + pred_rs["score"] = score.apply_row(lambda s: sigmoid(s)) + return pred_rs + + @staticmethod + def compute_loss(label: DataFrame, pred: DataFrame): + sample_num = len(label) + label_pred = DataFrame.hstack([label, pred]) + label_pred["loss"] = label_pred.apply_row( + lambda s: -(s[0] * np.log(s[1]) + (1 - s[0]) * np.log(1 - s[1])), with_label=True + ) + loss_rs = label_pred["loss"].fillna(1) + reduce_loss = loss_rs["loss"].sum() / sample_num + return reduce_loss + + @staticmethod + def compute_grad(gh: DataFrame, label: DataFrame, predict_score: DataFrame): + gh["g"] = predict_score - label + + @staticmethod + def compute_hess(gh: DataFrame, label: DataFrame, predict_score: DataFrame): + gh["h"] = predict_score * (1 - predict_score) + + +class CELoss(object): + @staticmethod + def initialize(label, class_num=3): + init_score = label.create_frame() + init_score["score"] = [0.0 for i in range(class_num)] + return init_score + + @staticmethod + def predict(score: DataFrame): + def softmax(s): + s = np.array(s["score"]).astype(np.float64) + ret = (np.exp(s) / np.exp(s).sum()).tolist() + return [ret] + + pred_rs = score.create_frame() + pred_rs["score"] = score.apply_row(lambda s: softmax(s)) + return pred_rs + + @staticmethod + def compute_loss(label: DataFrame, pred: DataFrame, weight: DataFrame): + loss_col = label.create_frame() + label_pred = label.hstack(pred) + sample_num = len(label) + loss_col["loss"] = label_pred.apply_row(lambda s: np.log(s[1:][int(s[0])])) + loss_col["loss"].fillna(1) + if weight: + loss_col["loss"] = apply_weight(loss_col, weight) + reduce_loss = loss_col["loss"].sum() / sample_num + return reduce_loss + + @staticmethod + def compute_grad(gh: DataFrame, label: DataFrame, score: DataFrame): + gh["g"] = score.apply_row(lambda s: [[i - 1 for i in s["score"]]]) + + @staticmethod + def compute_hess(gh: DataFrame, y, score): + gh["h"] = score.apply_row(lambda s: [[2 * i * (1 - i) for i in s["score"]]]) + + +class L2Loss(object): + @staticmethod + def initialize(label): + init_score = label.create_frame() + mean_score = float(label.mean()) + init_score["score"] = mean_score + return init_score, mean_score + + @staticmethod + def predict(score): + 
return score + + @staticmethod + def compute_loss(label: DataFrame, pred: DataFrame): + loss_col = label.create_frame() + sample_num = len(label) + loss_col["loss"] = (label - pred["score"]) ** 2 + reduce_loss = loss_col["loss"].sum() / sample_num + return reduce_loss + + @staticmethod + def compute_grad(gh: DataFrame, label, score): + gh["g"] = 2 * (score["score"] - label) + + @staticmethod + def compute_hess(gh: DataFrame, label, score): + gh["h"] = 2 + + +OBJECTIVE = {BINARY_BCE: BCELoss, MULTI_CE: CELoss, REGRESSION_L2: L2Loss} + + +def get_task_info(objective): + task_type = objective.split(":")[0] + return task_type diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py new file mode 100644 index 0000000000..317b73c666 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/splitter.py @@ -0,0 +1,617 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import numpy as np +import logging +from fate.arch.dataframe import DataFrame +from fate.arch import Context +from fate.arch.histogram import DistributedHistogram + + +logger = logging.getLogger(__name__) + + +class SplitInfo(object): + def __init__( + self, + best_fid=None, + best_bid=None, + sum_grad=0, + sum_hess=0, + gain=None, + missing_dir=1, + split_id=None, + sample_count=-1, + sitename=None, + ): + self.best_fid = best_fid + self.best_bid = best_bid + self.sum_grad = sum_grad + self.sum_hess = sum_hess + self.gain = gain + self.missing_dir = missing_dir + self.split_id = split_id + self.sample_count = sample_count + self.sitename = sitename + + def __str__(self): + return ( + "(fid {} bid {}, sum_grad {}, sum_hess {}, gain {}, missing dir {}, split_id {}, " + "sample_count {}, sitename {})\n".format( + self.best_fid, + self.best_bid, + self.sum_grad, + self.sum_hess, + self.gain, + self.missing_dir, + self.split_id, + self.sample_count, + self.sitename, + ) + ) + + def __repr__(self): + return self.__str__() + + +class Splitter(object): + def __init__(self) -> None: + pass + + +class SklearnSplitter(Splitter): + def __init__( + self, + feature_binning_dict, + min_impurity_split=1e-2, + min_sample_split=2, + min_leaf_node=1, + min_child_weight=1, + l1=0, + l2=0.1, + valid_features=None, + ) -> None: + super().__init__() + self.min_impurity_split = min_impurity_split + self.min_sample_split = min_sample_split + self.min_leaf_node = min_leaf_node + self.min_child_weight = min_child_weight + self.feature_binning_dict = feature_binning_dict + self.hist_mask = self.generate_mask(feature_binning_dict) + self.l1, self.l2 = l1, l2 + + def generate_mask(self, feature_dict): + split_counts = [len(split_point) for split_point in feature_dict.values()] + max_bin = max(split_counts) + mask = np.zeros((len(feature_dict), max_bin)).astype(np.bool8) + for i, bucket_count in enumerate(split_counts): + mask[i, :bucket_count] = True # valid split point + + return 
~mask + + def node_gain(self, g, h): + if isinstance(h, np.ndarray): + h[h == 0] = np.nan + score = g * g / (h + self.l2) + return score + + def node_weight(self, sum_grad, sum_hess): + weight = -(sum_grad / (sum_hess + self.l2)) + return weight + + def _compute_min_leaf_mask(self, l_cnt, r_cnt): + min_leaf_node_mask_l = l_cnt < self.min_leaf_node + min_leaf_node_mask_r = r_cnt < self.min_leaf_node + union_mask_0 = np.logical_or(min_leaf_node_mask_l, min_leaf_node_mask_r) + return union_mask_0 + + def _compute_gains(self, g, h, cnt, g_sum, h_sum, cnt_sum, hist_mask=None): + l_g, l_h, l_cnt = g, h, cnt + + if cnt_sum < self.min_sample_split: + return None + + r_g, r_h = g_sum - l_g, h_sum - l_h + r_cnt = cnt_sum - l_cnt + + # filter split + # leaf count + union_mask_0 = self._compute_min_leaf_mask(l_cnt, r_cnt) + # min child weight + min_child_weight_mask_l = l_h < self.min_child_weight + min_child_weight_mask_r = r_h < self.min_child_weight + union_mask_1 = np.logical_or(min_child_weight_mask_l, min_child_weight_mask_r) + if hist_mask is not None: + mask = np.logical_or(union_mask_0, hist_mask) + else: + mask = union_mask_0 + mask = np.logical_or(mask, union_mask_1) + + rs = self.node_gain(l_g, l_h) + self.node_gain(r_g, r_h) - self.node_gain(g_sum, h_sum) + + rs[np.isnan(rs)] = -np.inf + rs[rs < self.min_impurity_split] = -np.inf + rs[mask] = -np.inf + + return rs + + def _find_guest_best_splits(self, node_hist, sitename, ret_sum=False): + l_g, l_h, l_cnt = node_hist + cnt_sum = l_cnt[::, -1][0] + g_sum = l_g[::, -1][0] + h_sum = l_h[::, -1][0] + + rs = self._compute_gains(l_g, l_h, l_cnt, g_sum, h_sum, cnt_sum, hist_mask=self.hist_mask) + + # reduce + feat_best_split = rs.argmax(axis=1) + feat_best_gain = rs.max(axis=1) + + logger.debug("best gain {}".format(feat_best_gain)) + # best split + best_split_idx = feat_best_gain.argmax() + best_gain = feat_best_gain.max() + + if best_gain == -np.inf: + # can not split + logger.info("this node cannot be further split") + if ret_sum: + return None, g_sum, h_sum, cnt_sum + else: + return None + + feat_id = best_split_idx + bin_id = feat_best_split[best_split_idx] + + split_info = SplitInfo( + best_fid=feat_id, + best_bid=bin_id, + gain=best_gain, + sum_grad=l_g[feat_id][bin_id], + sum_hess=l_h[feat_id][bin_id], + sample_count=l_cnt[feat_id][bin_id], + sitename=sitename, + ) + + if ret_sum: + return split_info, g_sum, h_sum, cnt_sum + else: + return split_info + + def _split(self, ctx: Context, histogram: list, cur_layer_node): + splits = [] + logger.info("got {} hist".format(len(histogram))) + for node_hist in histogram: + split_info = self._find_guest_best_splits(node_hist, self.hist_mask, sitename=ctx.guest.name) + splits.append(split_info) + logger.info("split info is {}".format(split_info)) + assert len(splits) == len(cur_layer_node), "split info length {} != node length {}".format( + len(splits), len(cur_layer_node) + ) + return splits + + def split(self, ctx: Context, histogram: list, cur_layer_node): + return self._split(ctx, histogram, cur_layer_node) + + +class FedSklearnSplitter(SklearnSplitter): + def __init__( + self, + feature_binning_dict, + min_impurity_split=1e-2, + min_sample_split=2, + min_leaf_node=1, + min_child_weight=1, + l1=0, + l2=0, + valid_features=None, + random_seed=42, + ) -> None: + super().__init__( + feature_binning_dict, + min_impurity_split, + min_sample_split, + min_leaf_node, + min_child_weight, + l1, + l2, + valid_features, + ) + self.random_seed = random_seed + np.random.seed(self.random_seed) + + def 
_get_host_splits(self, ctx): + host_splits = ctx.hosts.get("host_splits") + return host_splits + + def _find_host_best_splits(self, split, g_sum, h_sum, cnt_sum, sitename): + g, h, cnt = split + rs = self._compute_gains(g, h, cnt, g_sum, h_sum, cnt_sum) + best_splits_id = rs.argmax() + best_gain = rs.max() + split_info = SplitInfo( + gain=best_gain, + split_id=best_splits_id, + sitename=sitename, + sum_grad=g[best_splits_id], + sum_hess=h[best_splits_id], + sample_count=cnt[best_splits_id], + ) + + return split_info + + def _merge_splits(self, guest_splits, host_splits_list): + splits = [] + for node_idx in range(len(guest_splits)): + best_gain = -np.inf + best_splitinfo = None + guest_splitinfo: SplitInfo = guest_splits[node_idx] + if guest_splitinfo is not None and guest_splitinfo.gain > best_gain: + best_gain = guest_splitinfo.gain + best_splitinfo = guest_splitinfo + + for host_idx in range(len(host_splits_list)): + host_splits = host_splits_list[host_idx] + host_splitinfo: SplitInfo = host_splits[node_idx] + if host_splitinfo is not None and host_splitinfo.gain > best_gain: + best_gain = host_splitinfo.gain + best_splitinfo = host_splitinfo + splits.append(best_splitinfo) + + return splits + + def _guest_split(self, ctx: Context, histogram, cur_layer_node): + sitename = ctx.guest.name + guest_best_splits = [] + gh_sum = [] + logger.info("got {} hist".format(len(histogram))) + for node_hist in histogram: + split_info, g_sum, h_sum, cnt_sum = self._find_guest_best_splits( + node_hist, ret_sum=True, sitename=sitename + ) + guest_best_splits.append(split_info) + gh_sum.append((g_sum, h_sum, cnt_sum)) + + assert len(guest_best_splits) == len(cur_layer_node), "split info length {} != node length {}".format( + len(guest_best_splits), len(cur_layer_node) + ) + + host_splits_list = self._get_host_splits(ctx) + all_host_splits = [] + for host_idx in range(len(host_splits_list)): + host_sitename = ctx.hosts[host_idx].name + host_splits = host_splits_list[host_idx] + assert len(host_splits) == len(cur_layer_node) + best_split = [] + for node_idx, node_splits in enumerate(host_splits): + g_sum, h_sum, cnt_sum = gh_sum[node_idx] + node_best = self._find_host_best_splits(node_splits, g_sum, h_sum, cnt_sum, host_sitename) + best_split.append(node_best) + all_host_splits.append(best_split) + + logger.info("guest split info is {}".format(guest_best_splits)) + logger.info("host split info is {}".format(all_host_splits)) + final_best_split = self._merge_splits(guest_best_splits, all_host_splits) + logger.info("final split info is {}".format(final_best_split)) + return host_splits[0] + + def _host_prepare(self, histogram): + to_send_hist = [] + pos_map = [] + # prepare host split points + for node_hist in histogram: + g, h, cnt = node_hist + shape = g.shape + pos_map_ = {} + g[self.hist_mask] = np.nan + h[self.hist_mask] = np.nan + # cnt is int, cannot use np.nan as mask + cnt[self.hist_mask] = 0 + g, h, cnt = g.flatten(), h.flatten(), cnt.flatten() + random_shuffle_idx = np.random.permutation(len(g)) + # random_shuffle_idx = np.array([i for i in range(len(g))]) + g = g[random_shuffle_idx] + h = h[random_shuffle_idx] + cnt = cnt[random_shuffle_idx] + to_send_hist.append([g, h, cnt]) + for split_idx, real_idx in enumerate(random_shuffle_idx): + pos_map_[split_idx] = (real_idx // shape[1], real_idx % shape[1]) + pos_map.append(pos_map_) + return to_send_hist, pos_map + + def _host_split(self, ctx, histogram, cur_layer_node): + to_send_hist, pos_map = self._host_prepare(histogram) + 
ctx.guest.put("host_splits", to_send_hist) + return pos_map + + def split(self, ctx: Context, histogram, cur_layer_node): + if ctx.is_on_guest: + return self._guest_split(ctx, histogram, cur_layer_node) + elif ctx.is_on_host: + return self._host_split(ctx, histogram, cur_layer_node) + else: + raise ValueError("illegal role {}".format(ctx.role)) + + +class FedSBTSplitter(object): + def __init__( + self, + bin_train_data: DataFrame, + bin_info: dict, + min_impurity_split=1e-2, + min_sample_split=2, + min_leaf_node=1, + min_child_weight=1, + l1=0, + l2=0.1, + ) -> None: + super().__init__() + self.min_impurity_split = min_impurity_split + self.min_sample_split = min_sample_split + self.min_leaf_node = min_leaf_node + self.min_child_weight = min_child_weight + self.bin_info = bin_info + self.l1, self.l2 = l1, l2 + columns = bin_train_data.schema.columns + self.feat_bin_num = [len(bin_info[feat]) for feat in columns] + + def get_bucket(self, idx): + feature_buckets = self.feat_bin_num + cumulative_buckets = [0] + for bucket in feature_buckets: + cumulative_buckets.append(cumulative_buckets[-1] + bucket) + + for i in range(1, len(cumulative_buckets)): + if idx < cumulative_buckets[i]: + fid = i - 1 + bid = idx - cumulative_buckets[i - 1] + return fid, bid + + raise ValueError("idx is out of range") + + def node_gain(self, g, h): + if isinstance(h, np.ndarray): + h[h == 0] = np.nan + score = g * g / (h + self.l2) + return score + + def node_weight(self, sum_grad, sum_hess): + weight = -(sum_grad / (sum_hess + self.l2)) + return weight + + def _extract_hist(self, histogram, pack_info=None): + tensor_hist: dict = histogram.extract_data() + g_all, h_all, cnt_all = None, None, None + for k, v in tensor_hist.items(): + cnt = v["cnt"].reshape((1, -1)) + + # if gh pack + if "gh" in v: + g = v["gh"][::, 0].reshape((1, -1)) + h = v["gh"][::, 1].reshape((1, -1)) + if pack_info is None: + raise ValueError("must provide pack info for gh packing computing") + g = g - pack_info["g_offset"] * cnt + else: + g = v["g"].reshape((1, -1)) + h = v["h"].reshape((1, -1)) + + if g_all is None: + g_all = g + else: + g_all = torch.vstack([g_all, g]) + if h_all is None: + h_all = h + else: + h_all = torch.vstack([h_all, h]) + if cnt_all is None: + cnt_all = cnt + else: + cnt_all = torch.vstack([cnt_all, cnt]) + + return g_all, h_all, cnt_all + + def _make_sum_tensor(self, nodes): + g_sum, h_sum, cnt_sum = [], [], [] + for node in nodes: + g_sum.append(node.grad) + h_sum.append(node.hess) + cnt_sum.append(node.sample_num) + + return ( + torch.Tensor(g_sum).reshape((len(nodes), 1)), + torch.Tensor(h_sum).reshape((len(nodes), 1)), + torch.Tensor(cnt_sum).reshape((len(nodes), 1)), + ) + + def _compute_min_leaf_mask(self, l_cnt, r_cnt): + min_leaf_node_mask_l = l_cnt < self.min_leaf_node + min_leaf_node_mask_r = r_cnt < self.min_leaf_node + union_mask_0 = torch.logical_or(min_leaf_node_mask_l, min_leaf_node_mask_r) + return union_mask_0 + + def _compute_gains(self, g, h, cnt, g_sum, h_sum, cnt_sum, hist_mask=None): + l_g, l_h, l_cnt = g, h, cnt + + r_g, r_h = g_sum - l_g, h_sum - l_h + r_cnt = cnt_sum - l_cnt + + # filter split + # leaf count + union_mask_0 = self._compute_min_leaf_mask(l_cnt, r_cnt) + # min child weight + min_child_weight_mask_l = l_h < self.min_child_weight + min_child_weight_mask_r = r_h < self.min_child_weight + union_mask_1 = torch.logical_or(min_child_weight_mask_l, min_child_weight_mask_r) + if hist_mask is not None: + mask = torch.logical_or(union_mask_0, hist_mask) + else: + mask = union_mask_0 + mask 
= torch.logical_or(mask, union_mask_1) + rs = self.node_gain(l_g, l_h) + self.node_gain(r_g, r_h) - self.node_gain(g_sum, h_sum) + rs[torch.isnan(rs)] = float("-inf") + rs[rs < self.min_impurity_split] = float("-inf") + rs[mask] = float("-inf") + + return rs + + def _find_best_splits( + self, node_hist, sitename, cur_layer_nodes, reverse_node_map, recover_bucket=True, pack_info=None + ): + """ + recover_bucket: if node_hist is guest hist, can get the fid and bid of the split info + but for node_hist from host sites, histograms are shuffled, so can not get the fid and bid, + only hosts know them. + """ + l_g, l_h, l_cnt = self._extract_hist(node_hist, pack_info) + g_sum, h_sum, cnt_sum = self._make_sum_tensor(cur_layer_nodes) + rs = self._compute_gains(l_g, l_h, l_cnt, g_sum, h_sum, cnt_sum) + + # reduce + best = rs.max(dim=-1) + best_gain = best[0] + best_idx = best[1] + logger.debug("best_idx: {}".format(best_idx)) + logger.debug("best_gain: {}".format(best_gain)) + + split_infos = [] + node_idx = 0 + for idx, gain in zip(best_idx, best_gain): + idx_ = int(idx.detach().cpu().item()) + if gain == float("-inf") or cnt_sum[node_idx] < self.min_sample_split: + split_infos.append(None) + logger.info("Node {} can not be further split".format(reverse_node_map[node_idx])) + else: + split_info = SplitInfo( + gain=float(gain), + sum_grad=float(l_g[node_idx][idx_]), + sum_hess=float(l_h[node_idx][idx_]), + sample_count=int(l_cnt[node_idx][idx_]), + sitename=sitename, + ) + if recover_bucket: + fid, bid = self.get_bucket(idx_) + split_info.best_fid = fid + split_info.best_bid = bid + else: + split_info.split_id = idx_ + split_infos.append(split_info) + node_idx += 1 + + return split_infos + + def _merge_splits(self, guest_splits, host_splits_list): + splits = [] + for node_idx in range(len(guest_splits)): + best_gain = -np.inf + best_splitinfo = None + guest_splitinfo: SplitInfo = guest_splits[node_idx] + if guest_splitinfo is not None and guest_splitinfo.gain > best_gain: + best_gain = guest_splitinfo.gain + best_splitinfo = guest_splitinfo + + for host_idx in range(len(host_splits_list)): + host_splits = host_splits_list[host_idx] + host_splitinfo: SplitInfo = host_splits[node_idx] + if host_splitinfo is not None and host_splitinfo.gain > best_gain: + best_gain = host_splitinfo.gain + best_splitinfo = host_splitinfo + splits.append(best_splitinfo) + + return splits + + def _recover_pack_split(self, hist: DistributedHistogram, schema, decode_schema=None): + host_hist = hist.decrypt(schema[0], schema[1], decode_schema) + return host_hist + + def _guest_split(self, ctx: Context, stat_rs, cur_layer_node, node_map, sk, coder, gh_pack, pack_info): + if sk is None or coder is None: + raise ValueError("sk or coder is None, not able to decode host split points") + + histogram = stat_rs.decrypt({}, {}, None) + sitename = ctx.local.name + reverse_node_map = {v: k for k, v in node_map.items()} + + # find local best splits + guest_best_splits = self._find_best_splits( + histogram, sitename, cur_layer_node, reverse_node_map, recover_bucket=True + ) + # find best splits from host parties + host_histograms = ctx.hosts.get("hist") + + host_splits = [] + if gh_pack: + decrypt_schema = ({"gh": sk}, {"gh": (coder, torch.int64)}) + # (coder, pack_num, offset_bit, precision, total_num) + if pack_info is not None: + decode_schema = { + "gh": ( + coder, + pack_info["pack_num"], + pack_info["shift_bit"], + pack_info["precision"], + pack_info["total_pack_num"], + ) + } + else: + raise ValueError("pack info is not provided") 
+ else: + decrypt_schema = ({"g": sk, "h": sk}, {"g": (coder, torch.float32), "h": (coder, torch.float32)}) + decode_schema = None + + for idx, hist in enumerate(host_histograms): + host_sitename = ctx.hosts[idx].name + host_hist = self._recover_pack_split(hist, decrypt_schema, decode_schema) + # logger.debug("splitting host") + host_split = self._find_best_splits( + host_hist, host_sitename, cur_layer_node, reverse_node_map, recover_bucket=False, pack_info=pack_info + ) + host_splits.append(host_split) + + # logger.debug("host splits are {}".format(host_splits)) + best_splits = self._merge_splits(guest_best_splits, host_splits) + # logger.debug("guest splits are {}".format(guest_best_splits)) + # logger.debug("best splits are {}".format(best_splits)) + return best_splits + + def _host_split(self, ctx: Context, en_histogram, cur_layer_node): + ctx.guest.put("hist", en_histogram) + + def split( + self, + ctx: Context, + histogram_statistic_result, + cur_layer_node, + node_map, + sk=None, + coder=None, + gh_pack=None, + pack_info=None, + ): + if ctx.is_on_guest: + if sk is None or coder is None: + raise ValueError("sk or coder is None, not able to decode host split points") + assert gh_pack is not None and isinstance( + gh_pack, bool + ), "gh_pack should be bool, indicating if the gh is packed" + if not gh_pack: + logger.info("not using gh pack to split") + return self._guest_split( + ctx, histogram_statistic_result, cur_layer_node, node_map, sk, coder, gh_pack, pack_info + ) + elif ctx.is_on_host: + return self._host_split(ctx, histogram_statistic_result, cur_layer_node) + else: + raise ValueError("illegal role {}".format(ctx.role)) diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/subsample.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/subsample.py new file mode 100644 index 0000000000..dcbff7dd99 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/subsample.py @@ -0,0 +1,57 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
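All splitters above rank candidate splits with the same second-order gain: for a node with gradient/hessian totals G and H, a candidate left partition (G_L, H_L) and L2 regularization l2, the score is G_L^2/(H_L + l2) + G_R^2/(H_R + l2) - G^2/(H + l2), where the right side follows by subtraction and candidates failing the min-leaf, min-child-weight or min-impurity checks are masked to -inf. A toy numeric sketch of that computation, with made-up sums and the default l2 of 0.1:

```python
# Toy illustration of the node_gain/_compute_gains logic in splitter.py above;
# the gradient/hessian numbers are assumed for the example.
import numpy as np

l2 = 0.1
G, H = 10.0, 8.0                      # parent node gradient/hessian totals
l_g = np.array([2.0, 6.0, 9.0])       # cumulative left-side gradient sums per candidate bin
l_h = np.array([1.0, 3.0, 6.0])       # cumulative left-side hessian sums per candidate bin
r_g, r_h = G - l_g, H - l_h           # right side obtained by subtraction

def node_gain(g, h):
    return g * g / (h + l2)

gains = node_gain(l_g, l_h) + node_gain(r_g, r_h) - node_gain(G, H)
print(gains)                                        # ~[0.30, 2.40, 1.41]
print("best candidate bin:", int(gains.argmax()))   # bin 1
```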
+from fate.arch.dataframe import DataFrame +import pandas as pd +import numpy as np + + +def goss_sampling(train_data: DataFrame, gh: DataFrame, top_rate=0.1, other_rate=0.1): + sample_num = len(train_data) + gh_df: pd.DataFrame = gh.as_pd_df() + id_list = np.array(gh_df["sample_id"]) + g_arr = np.array(gh_df["g"]).astype(np.float64) + h_arr = np.array(gh_df["h"]).astype(np.float64) + + g_sum_arr = np.abs(g_arr).sum(axis=1) # if it is multi-classification case, we need to sum g + abs_g_list_arr = g_sum_arr + sorted_idx = np.argsort(-abs_g_list_arr, kind="stable") # stable sample result + + a_part_num = int(sample_num * top_rate) + b_part_num = int(sample_num * other_rate) + + if a_part_num == 0 or b_part_num == 0: + raise ValueError("subsampled result is 0: top sample {}, other sample {}".format(a_part_num, b_part_num)) + + # index of a part + a_sample_idx = sorted_idx[:a_part_num] + + # index of b part + rest_sample_idx = sorted_idx[a_part_num:] + b_sample_idx = np.random.choice(rest_sample_idx, size=b_part_num, replace=False) + + # small gradient sample weights + amplify_weights = (1 - top_rate) / other_rate + g_arr[b_sample_idx] *= amplify_weights + h_arr[b_sample_idx] *= amplify_weights + + # get selected sample + a_idx_set, b_idx_set = set(list(a_sample_idx)), set(list(b_sample_idx)) + idx_set = a_idx_set.union(b_idx_set) + selected_idx = np.array(list(idx_set)) + selected_g, selected_h = g_arr[selected_idx], h_arr[selected_idx] + selected_id = id_list[selected_idx] + + new_gh = None + subsample_data = None diff --git a/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py new file mode 100644 index 0000000000..ee34552976 --- /dev/null +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py @@ -0,0 +1,89 @@ +import pandas as pd +from fate.arch.dataframe import PandasReader +import logging +from fate.ml.ensemble.learner.decision_tree.tree_core.loss import BCELoss, L2Loss, CELoss + + +# Get the root logger +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +ch.setFormatter(formatter) + +logger.addHandler(ch) +arbiter = ("arbiter", "10000") +guest = ("guest", "10000") +host = ("host", "9999") +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]) + ) + + +ctx = create_ctx(guest) + +df = pd.read_csv("../../../../../../../../examples/data/breast_hetero_guest.csv") +df["sample_id"] = [i for i in range(len(df))] + +df_reg = pd.read_csv("../../../../../../../../examples/data/student_hetero_guest.csv") +df_reg["sample_id"] = [i for i in range(len(df_reg))] + +df_multi = pd.read_csv("../../../../../../../../examples/data/student_hetero_guest.csv") +df_multi["sample_id"] = [i for i in range(len(df_multi))] + + 
+reader = PandasReader(sample_id_name="sample_id", match_id_name="id", label_name="y", dtype="object") + +data = reader.to_frame(ctx, df) +data_reg = reader.to_frame(ctx, df_reg) +data_multi = reader.to_frame(ctx, df_multi) + +# test loss here +loss_bce = BCELoss() +label = data.label +init_score = loss_bce.initialize(label) +predict = loss_bce.predict(init_score) +loss = loss_bce.compute_loss(label, predict) +empty_gh = data.create_frame() +loss_bce.compute_grad(empty_gh, label, predict) +loss_bce.compute_hess(empty_gh, label, predict) + +# loss_l2 = L2Loss() +# label = data_reg.label +# init_score = loss_l2.initialize(label) +# predict = loss_l2.predict(init_score) +# loss = loss_l2.compute_loss(label, predict) +# empty_gh = data_reg.create_frame() +# loss_l2.compute_grad(empty_gh, label, predict) +# loss_l2.compute_hess(empty_gh, label, predict) + + +# loss = CELoss() +# label = data_multi.label +# init_score = loss.initialize(label, class_num=4) +# predict = loss.predict(init_score) +# loss = loss.compute_loss(label, predict) +# empty_gh = data_reg.create_frame() +# loss.compute_grad(empty_gh, label, predict) +# loss.compute_hess(empty_gh, label, predict) diff --git a/python/fate/arch/context/io/data/dataframe.py b/python/fate/ml/ensemble/utils/__init__.py similarity index 100% rename from python/fate/arch/context/io/data/dataframe.py rename to python/fate/ml/ensemble/utils/__init__.py diff --git a/python/fate/arch/tensor/ops/_binary_ops.py b/python/fate/ml/ensemble/utils/binning.py similarity index 52% rename from python/fate/arch/tensor/ops/_binary_ops.py rename to python/fate/ml/ensemble/utils/binning.py index f546b3b142..c3d689aa60 100644 --- a/python/fate/arch/tensor/ops/_binary_ops.py +++ b/python/fate/ml/ensemble/utils/binning.py @@ -12,46 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._ops import auto_binary_op +from fate.arch.dataframe import DataFrame +import numpy as np +import pandas as pd +import torch as t -@auto_binary_op -def add(x, y, *args, **kwargs): - """""" - ... +def _process_dataframe(df): + result_dict = {} + for column in df.columns: + unique_values = df[column].unique() + sorted_values = sorted(unique_values) + result_dict[column] = sorted_values -@auto_binary_op -def sub(x, y, *args, **kwargs): - """""" - ... + return result_dict -@auto_binary_op -def mul(x, y, *args, **kwargs): - """""" - ... +def binning(data: DataFrame, max_bin=32): + quantile = [i / max_bin for i in range(0, max_bin)] + quantile_values = data.quantile(quantile) + result_dict = _process_dataframe(quantile_values) - -@auto_binary_op -def div(x, y, *args, **kwargs): - """""" - ... - - -@auto_binary_op -def pow(x, y, *args, **kwargs): - """""" - ... - - -@auto_binary_op -def remainder(x, y, *args, **kwargs): - """""" - ... - - -@auto_binary_op -def fmod(x, y, *args, **kwargs): - "element wise remainder of division" - ... + return result_dict diff --git a/python/fate/ml/evaluation/__init__.py b/python/fate/ml/evaluation/__init__.py index 57c42b0685..290575b2d2 100644 --- a/python/fate/ml/evaluation/__init__.py +++ b/python/fate/ml/evaluation/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .evaluation import BinaryEvaluator + diff --git a/python/fate/ml/evaluation/classification.py b/python/fate/ml/evaluation/classification.py new file mode 100644 index 0000000000..0361f36090 --- /dev/null +++ b/python/fate/ml/evaluation/classification.py @@ -0,0 +1,792 @@ +import sys +import copy +import pandas as pd +from typing import Dict +import numpy as np +import torch +from fate.ml.evaluation.metric_base import Metric +from sklearn.metrics import roc_auc_score +from sklearn.metrics import accuracy_score +from sklearn.metrics import recall_score, precision_score, f1_score +from fate.ml.evaluation.metric_base import EvalResult + + + +""" +Single Value Metrics +""" + +class AUC(Metric): + + metric_name = 'auc' + + def __init__(self): + super().__init__() + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + auc_score = roc_auc_score(label, predict) + return EvalResult(self.metric_name, auc_score) + + +class BinaryMetricWithThreshold(Metric): + + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + +class MultiAccuracy(Metric): + + metric_name = 'multi_accuracy' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict, flatten=False) + label = self.to_np_format(label).astype(np.int32) + if predict.shape != label.shape: + predict = predict.argmax(axis=-1).astype(np.int32) + acc = accuracy_score(label, predict) + return EvalResult(self.metric_name, acc) + + +class MultiRecall(Metric): + + metric_name = 'multi_recall' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict, flatten=False) + label = self.to_np_format(label) + if predict.shape != label.shape: + predict = predict.argmax(axis=-1) + recall = recall_score(label, predict, average='macro') + return EvalResult(self.metric_name, recall) + + +class MultiPrecision(Metric): + + metric_name = 'multi_precision' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict, flatten=False) + label = self.to_np_format(label) + if predict.shape != label.shape: + predict = predict.argmax(axis=-1) + precision = precision_score(label, predict, average='macro') + return EvalResult(self.metric_name, precision) + + +class BinaryAccuracy(MultiAccuracy, BinaryMetricWithThreshold): + + metric_name = 'binary_accuracy' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + predict = (predict > self.threshold).astype(int) + label = self.to_np_format(label) + acc = accuracy_score(label, predict) + return EvalResult(self.metric_name, acc) + + +class BinaryRecall(MultiRecall, BinaryMetricWithThreshold): + + metric_name = 'binary_recall' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + predict = (predict > self.threshold).astype(int) + label = self.to_np_format(label) + recall = recall_score(label, predict) + return EvalResult(self.metric_name, recall) + + +class BinaryPrecision(MultiPrecision, BinaryMetricWithThreshold): + + metric_name = 'binary_precision' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + predict = (predict > self.threshold).astype(int) + label = self.to_np_format(label) + precision = precision_score(label, predict) + return EvalResult(self.metric_name, precision) + + +class MultiF1Score(Metric): + + metric_name = 'multi_f1_score' + + def __init__(self, 
average='micro'): + super().__init__() + self.average = average + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict, flatten=False) + label = self.to_np_format(label) + if predict.shape != label.shape: + predict = predict.argmax(axis=-1) + f1 = f1_score(label, predict, average=self.average) + return EvalResult(self.metric_name, f1) + + +class BinaryF1Score(MultiF1Score, BinaryMetricWithThreshold): + + metric_name = 'binary_f1_score' + + def __init__(self, threshold=0.5, average='binary'): + super().__init__(average) + self.threshold = threshold + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + predict = (predict > self.threshold).astype(int) + label = self.to_np_format(label) + f1 = f1_score(label, predict, average=self.average) + return EvalResult(self.metric_name, f1) + + +""" +Functions for Metrics with Cruve/Table Results +""" + +ROUND_NUM = 6 + + +def neg_pos_count(labels: np.ndarray, pos_label: int): + pos_num = ((labels == pos_label) + 0).sum() + neg_num = len(labels) - pos_num + return pos_num, neg_num + + +def sort_score_and_label(labels: np.ndarray, pred_scores: np.ndarray): + labels = np.array(labels) + pred_scores = np.array(pred_scores) + + sort_idx = np.flip(pred_scores.argsort()) + sorted_labels = labels[sort_idx] + sorted_scores = pred_scores[sort_idx] + + return sorted_labels, sorted_scores + + +class _ConfusionMatrix(object): + + @staticmethod + def compute(sorted_labels: list, sorted_pred_scores: list, score_thresholds: list, ret: list, pos_label=1): + + for ret_type in ret: + assert ret_type in ['tp', 'tn', 'fp', 'fn'] + + sorted_labels = np.array(sorted_labels) + sorted_scores = np.array(sorted_pred_scores) + sorted_labels[sorted_labels != pos_label] = 0 + sorted_labels[sorted_labels == pos_label] = 1 + score_thresholds = np.array([score_thresholds]).transpose() + pred_labels = (sorted_scores > score_thresholds) + 0 + + ret_dict = {} + if 'tp' in ret or 'tn' in ret: + match_arr = (pred_labels + sorted_labels) + if 'tp' in ret: + tp_num = (match_arr == 2).sum(axis=-1) + ret_dict['tp'] = tp_num + if 'tn' in ret: + tn_num = (match_arr == 0).sum(axis=-1) + ret_dict['tn'] = tn_num + + if 'fp' in ret or 'fn' in ret: + match_arr = (sorted_labels - pred_labels) + if 'fp' in ret: + fp_num = (match_arr == -1).sum(axis=-1) + ret_dict['fp'] = fp_num + if 'fn' in ret: + fn_num = (match_arr == 1).sum(axis=-1) + ret_dict['fn'] = fn_num + + return ret_dict + + +class ThresholdCutter(object): + + @staticmethod + def cut_by_step(sorted_scores, steps=0.01): + assert isinstance(steps, float) and (0 < steps < 1) + thresholds = list(set(sorted_scores)) + thresholds, cuts = ThresholdCutter.__filt_threshold(thresholds, 0.01) + score_threshold = thresholds + + return score_threshold, cuts + + @staticmethod + def fixed_interval_threshold(steps=0.01): + intervals = np.array([i for i in range(0, 100)]) + intervals = intervals * steps + return intervals + + @staticmethod + def cut_by_index(sorted_scores): + cuts = np.array([c / 100 for c in range(100)]) + data_size = len(sorted_scores) + indexs = [int(data_size * cut) for cut in cuts] + score_threshold = [sorted_scores[idx] for idx in indexs] + return score_threshold, cuts + + @staticmethod + def __filt_threshold(thresholds, step): + cuts = list(map(float, np.arange(0, 1, step))) + size = len(list(thresholds)) + thresholds.sort(reverse=True) + index_list = [int(size * cut) for cut in cuts] + new_thresholds = [thresholds[idx] for idx in index_list] + + 
return new_thresholds, cuts + + @staticmethod + def cut_by_quantile(scores, quantile_list=None, interpolation='nearest', remove_duplicate=True): + + if quantile_list is None: # default is 20 intervals + quantile_list = [round(i * 0.05, 3) for i in range(20)] + [1.0] + quantile_val = np.quantile(scores, quantile_list, interpolation=interpolation) + if remove_duplicate: + quantile_val = sorted(list(set(quantile_val))) + else: + quantile_val = sorted(list(quantile_val)) + + if len(quantile_val) == 1: + quantile_val = [np.min(scores), np.max(scores)] + + return quantile_val + +class BiClassMetric(object): + + def __init__(self, cut_method='step', remove_duplicate=False, pos_label=1): + assert cut_method in ['step', 'quantile'] + self.cut_method = cut_method + self.remove_duplicate = remove_duplicate # available when cut_method is quantile + self.pos_label = pos_label + + def prepare_confusion_mat(self, labels, scores, add_to_end=True, ): + import logging + logger = logging.getLogger(__name__) + logger.info('labels are {}, scores are {}'.format(labels, scores)) + sorted_labels, sorted_scores = sort_score_and_label(labels, scores) + + score_threshold, cuts = None, None + + if self.cut_method == 'step': + score_threshold, cuts = ThresholdCutter.cut_by_step(sorted_scores, steps=0.01) + if add_to_end: + score_threshold.append(min(score_threshold) - 0.001) + cuts.append(1) + + elif self.cut_method == 'quantile': + score_threshold = ThresholdCutter.cut_by_quantile(sorted_scores, remove_duplicate=self.remove_duplicate) + score_threshold = list(np.flip(score_threshold)) + + confusion_mat = _ConfusionMatrix.compute(sorted_labels, sorted_scores, score_threshold, + ret=['tp', 'fp', 'fn', 'tn'], pos_label=self.pos_label) + + return confusion_mat, score_threshold, cuts + + def compute(self, labels, scores, ): + confusion_mat, score_threshold, cuts = self.prepare_confusion_mat(labels, scores, ) + metric_scores = self.compute_metric_from_confusion_mat(confusion_mat) + return list(metric_scores), score_threshold, cuts + + def compute_metric_from_confusion_mat(self, *args): + raise NotImplementedError() + + +""" +Metrics with Cruve/Table Results +""" + +class KS(Metric): + + metric_name = 'ks' + + def __init__(self): + super().__init__() + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + + sorted_labels, sorted_scores = sort_score_and_label(label, predict) + threshold, cuts = ThresholdCutter.cut_by_index(sorted_scores) + confusion_mat = _ConfusionMatrix.compute(sorted_labels, sorted_scores, threshold, ret=['tp', 'fp'], + pos_label=1) + pos_num, neg_num = neg_pos_count(sorted_labels, pos_label=1) + + assert pos_num > 0 and neg_num > 0, "error when computing KS metric, pos sample number and neg sample number" \ + "must be larger than 0" + + tpr_arr = confusion_mat['tp'] / pos_num + fpr_arr = confusion_mat['fp'] / neg_num + + tpr = np.append(tpr_arr, np.array([1.0])) + fpr = np.append(fpr_arr, np.array([1.0])) + cuts = np.append(cuts, np.array([1.0])) + threshold.append(0.0) + ks_curve = tpr[:-1] - fpr[:-1] + ks_val = np.max(ks_curve) + + return EvalResult(self.metric_name, ks_val), \ + EvalResult(self.metric_name + '_table', pd.DataFrame({'tpr': tpr, 'fpr': fpr, 'threshold': threshold, 'cuts': cuts})) + + +class ConfusionMatrix(Metric): + + metric_name = 'confusion_matrix' + + def __init__(self): + super().__init__() + + def __call__(self, predict, label, **kwargs): + + predict = self.to_np_format(predict) + label = 
self.to_np_format(label) + + sorted_labels, sorted_scores = sort_score_and_label(label, predict) + threshold, cuts = ThresholdCutter.cut_by_index(sorted_scores) + confusion_mat = _ConfusionMatrix.compute(sorted_labels, sorted_scores, threshold, ret=['tp', 'tn', 'fp', 'fn'], + pos_label=1) + confusion_mat['cuts'] = cuts + confusion_mat['threshold'] = threshold + return EvalResult(self.metric_name, pd.DataFrame(confusion_mat)) + + +class Lift(Metric, BiClassMetric): + + metric_name = 'lift' + + def __init__(self, *args, **kwargs): + Metric.__init__(self) + BiClassMetric.__init__(self, cut_method='step', remove_duplicate=False, pos_label=1) + + @staticmethod + def _lift_helper(val): + + tp, fp, fn, tn, labels_num = val[0], val[1], val[2], val[3], val[4] + + lift_x_type, lift_y_type = [], [] + + for label_type in ['1', '0']: + + if label_type == '0': + tp, tn = tn, tp + fp, fn = fn, fp + + if labels_num == 0: + lift_x = 1 + denominator = 1 + else: + lift_x = (tp + fp) / labels_num + denominator = (tp + fn) / labels_num + + if tp + fp == 0: + numerator = 1 + else: + numerator = tp / (tp + fp) + + if denominator == 0: + lift_y = sys.float_info.max + else: + lift_y = numerator / denominator + + lift_x_type.insert(0, lift_x) + lift_y_type.insert(0, lift_y) + + return lift_x_type, lift_y_type + + def compute_metric_from_confusion_mat(self, confusion_mat, labels_len, ): + + labels_nums = np.zeros(len(confusion_mat['tp'])) + labels_len + + rs = map(self._lift_helper, zip(confusion_mat['tp'], confusion_mat['fp'], + confusion_mat['fn'], confusion_mat['tn'], labels_nums)) + + rs = list(rs) + + lifts_x, lifts_y = [i[0] for i in rs], [i[1] for i in rs] + + return lifts_y, lifts_x + + def __call__(self, predict, label, **kwargs): + + predict = self.to_np_format(predict) + label = self.to_np_format(label) + confusion_mat, score_threshold, cuts = self.prepare_confusion_mat(label, predict, add_to_end=False, ) + + lifts_y, lifts_x = self.compute_metric_from_confusion_mat(confusion_mat, len(label), ) + + return EvalResult(self.metric_name, + pd.DataFrame({ + 'liftx': lifts_x, + 'lifty': lifts_y, + 'threshold': list(score_threshold) + }) + ) + + +class Gain(Metric, BiClassMetric): + + metric_name = 'gain' + + def __init__(self, *args, **kwargs): + Metric.__init__(self) + BiClassMetric.__init__(self, cut_method='step', remove_duplicate=False, pos_label=1) + + @staticmethod + def _gain_helper(val): + + tp, fp, fn, tn, num_label = val[0], val[1], val[2], val[3], val[4] + + gain_x_type, gain_y_type = [], [] + + for pos_label in ['1', '0']: + + if pos_label == '0': + tp, tn = tn, tp + fp, fn = fn, fp + + if num_label == 0: + gain_x = 1 + else: + gain_x = float((tp + fp) / num_label) + + num_positives = tp + fn + if num_positives == 0: + gain_y = 1 + else: + gain_y = float(tp / num_positives) + + gain_x_type.insert(0, gain_x) + gain_y_type.insert(0, gain_y) + + return gain_x_type, gain_y_type + + def compute_metric_from_confusion_mat(self, confusion_mat, labels_len): + + labels_nums = np.zeros(len(confusion_mat['tp'])) + labels_len + + rs = map(self._gain_helper, zip(confusion_mat['tp'], confusion_mat['fp'], + confusion_mat['fn'], confusion_mat['tn'], labels_nums)) + + rs = list(rs) + + gain_x, gain_y = [i[0] for i in rs], [i[1] for i in rs] + + return gain_y, gain_x + + def __call__(self, predict, label, **kwargs): + + predict = self.to_np_format(predict) + label = self.to_np_format(label) + confusion_mat, score_threshold, cuts = self.prepare_confusion_mat(label, predict, add_to_end=False, ) + + gain_y, gain_x = 
self.compute_metric_from_confusion_mat(confusion_mat, len(label)) + + return EvalResult(self.metric_name, + pd.DataFrame({ + 'gainx': gain_x, + 'gainy': gain_y, + 'threshold': list(score_threshold) + }) + ) + + + +class BiClassPrecisionTable(Metric, BiClassMetric): + """ + Compute binary classification precision using multiple thresholds + """ + metric_name = 'biclass_precision_table' + + def __init__(self, *args, **kwargs): + Metric.__init__(self) + BiClassMetric.__init__(self, cut_method='step', remove_duplicate=False, pos_label=1) + + def compute_metric_from_confusion_mat(self, confusion_mat, impute_val=1.0): + numerator = confusion_mat['tp'] + denominator = (confusion_mat['tp'] + confusion_mat['fp']) + zero_indexes = (denominator == 0) + denominator[zero_indexes] = 1 + precision_scores = numerator / denominator + precision_scores[zero_indexes] = impute_val # impute_val is for prettifying when drawing pr curves + + return precision_scores + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + p, threshold, cuts = self.compute(label, predict) + return EvalResult(self.metric_name, pd.DataFrame({ + 'p': p, + 'threshold': threshold, + 'cuts': cuts + })) + + + +class BiClassRecallTable(Metric, BiClassMetric): + """ + Compute binary classification recall using multiple thresholds + """ + metric_name = 'biclass_recall_table' + + def __init__(self, *args, **kwargs): + Metric.__init__(self) + BiClassMetric.__init__(self, cut_method='step', remove_duplicate=False, pos_label=1) + + def compute_metric_from_confusion_mat(self, confusion_mat, formatted=True): + recall_scores = confusion_mat['tp'] / (confusion_mat['tp'] + confusion_mat['fn']) + return recall_scores + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + r, threshold, cuts = self.compute(label, predict) + return EvalResult(self.metric_name, pd.DataFrame({ + 'r': r, + 'threshold': threshold, + 'cuts': cuts + })) + + +class BiClassAccuracyTable(Metric, BiClassMetric): + """ + Compute binary classification accuracy using multiple thresholds + """ + metric_name = 'biclass_accuracy_table' + + def __init__(self, *args, **kwargs): + Metric.__init__(self) + BiClassMetric.__init__(self, cut_method='step', remove_duplicate=False, pos_label=1) + + def compute(self, labels, scores, normalize=True): + confusion_mat, score_threshold, cuts = self.prepare_confusion_mat(labels, scores) + metric_scores = self.compute_metric_from_confusion_mat(confusion_mat, normalize=normalize) + return list(metric_scores), score_threshold[: len(metric_scores)], cuts[: len(metric_scores)] + + def compute_metric_from_confusion_mat(self, confusion_mat, normalize=True): + rs = (confusion_mat['tp'] + confusion_mat['tn']) / \ + (confusion_mat['tp'] + confusion_mat['tn'] + confusion_mat['fn'] + confusion_mat['fp']) if normalize \ + else (confusion_mat['tp'] + confusion_mat['tn']) + return rs[:-1] + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + accuracy, threshold, cuts = self.compute(label, predict) + return EvalResult(self.metric_name, pd.DataFrame({ + 'accuracy': accuracy, + 'threshold': threshold, + 'cuts': cuts + })) + + +class FScoreTable(Metric): + """ + Compute F score from bi-class confusion mat + """ + + metric_name = 'fscore_table' + + def __call__(self, predict, label, beta=1): + + predict = self.to_np_format(predict) + label 
= self.to_np_format(label) + + sorted_labels, sorted_scores = sort_score_and_label(label, predict) + _, cuts = ThresholdCutter.cut_by_step(sorted_scores, steps=0.01) + fixed_interval_threshold = ThresholdCutter.fixed_interval_threshold() + confusion_mat = _ConfusionMatrix.compute(sorted_labels, sorted_scores, + fixed_interval_threshold, + ret=['tp', 'fp', 'fn', 'tn']) + precision_computer = BiClassPrecisionTable() + recall_computer = BiClassRecallTable() + p_score = precision_computer.compute_metric_from_confusion_mat(confusion_mat) + r_score = recall_computer.compute_metric_from_confusion_mat(confusion_mat) + beta_2 = beta * beta + denominator = (beta_2 * p_score + r_score) + denominator[denominator == 0] = 1e-6 # in case denominator is 0 + numerator = (1 + beta_2) * (p_score * r_score) + f_score = numerator / denominator + + return EvalResult(self.metric_name, pd.DataFrame({ + 'f_score': f_score, + 'threshold': fixed_interval_threshold, + 'cuts': cuts + })) + + +class PSI(Metric): + + metric_name = 'psi' + + def __call__(self, predict: dict, label: dict, **kwargs) -> Dict: + + """ + train/validate scores: predicted scores on train/validate set + train/validate labels: true labels + debug: print debug message + if train&validate labels are not None, count positive sample percentage in every interval + """ + + str_intervals=False + round_num=3 + pos_label=1 + + if not isinstance(predict, dict) or (label is not None and not isinstance(label, dict)): + raise ValueError("Input 'predict' must be a dictionary, and 'label' must be either None or a dictionary.") + + train_scores = predict.get('train_scores') + validate_scores = predict.get('validate_scores') + + if train_scores is None or validate_scores is None: + raise ValueError( + "Input 'predict' should contain the following keys: 'train_scores', 'validate_scores'. " + "Please make sure both keys are present." 
+ ) + + train_labels = label.get('train_labels') if label is not None else None + validate_labels = label.get('validate_labels') if label is not None else None + + train_scores = np.array(train_scores) + validate_scores = np.array(validate_scores) + quantile_points = ThresholdCutter().cut_by_quantile(train_scores) + + train_count = self.quantile_binning_and_count(train_scores, quantile_points) + validate_count = self.quantile_binning_and_count(validate_scores, quantile_points) + + train_pos_perc, validate_pos_perc = None, None + + if train_labels is not None and validate_labels is not None: + assert len(train_labels) == len(train_scores) and len(validate_labels) == len(validate_scores) + train_labels, validate_labels = np.array(train_labels), np.array(validate_labels) + train_pos_count = self.quantile_binning_and_count(train_scores[train_labels == pos_label], quantile_points) + validate_pos_count = self.quantile_binning_and_count(validate_scores[validate_labels == pos_label], + quantile_points) + + train_pos_perc = np.array(train_pos_count['count']) / np.array(train_count['count']) + validate_pos_perc = np.array(validate_pos_count['count']) / np.array(validate_count['count']) + + # handle special cases + train_pos_perc[train_pos_perc == np.inf] = -1 + validate_pos_perc[validate_pos_perc == np.inf] = -1 + train_pos_perc[np.isnan(train_pos_perc)] = 0 + validate_pos_perc[np.isnan(validate_pos_perc)] = 0 + + assert (train_count['interval'] == validate_count['interval']), 'train count interval is not equal to ' \ + 'validate count interval' + + expected_interval = np.array(train_count['count']) + actual_interval = np.array(validate_count['count']) + + expected_interval = expected_interval.astype(np.float) + actual_interval = actual_interval.astype(np.float) + + psi_scores, total_psi, expected_interval, actual_interval, expected_percentage, actual_percentage \ + = self.psi_score(expected_interval, actual_interval, len(train_scores), len(validate_scores)) + + intervals = train_count['interval'] if not str_intervals else PSI.intervals_to_str(train_count['interval'], + round_num=round_num) + + total_psi = EvalResult('total_psi', total_psi) + + if train_labels is None and validate_labels is None: + psi_table = EvalResult('psi_table', pd.DataFrame({ + 'psi_scores': psi_scores, + 'expected_interval': expected_interval, + 'actual_interval': actual_interval, + 'expected_percentage': expected_percentage, + 'actual_percentage': actual_percentage, + 'interval': intervals + })) + else: + psi_table = EvalResult('psi_table', pd.DataFrame({ + 'psi_scores': psi_scores, + 'expected_interval': expected_interval, + 'actual_interval': actual_interval, + 'expected_percentage': expected_percentage, + 'actual_percentage': actual_percentage, + 'train_pos_perc': train_pos_perc, + 'validate_pos_perc': validate_pos_perc, + 'interval': intervals + })) + + return psi_table, total_psi + + @staticmethod + def quantile_binning_and_count(scores, quantile_points): + """ + left edge and right edge of last interval are closed + """ + + assert len(quantile_points) >= 2 + + left_bounds = copy.deepcopy(quantile_points[:-1]) + right_bounds = copy.deepcopy(quantile_points[1:]) + + last_interval_left = left_bounds.pop() + last_interval_right = right_bounds.pop() + + bin_result_1, bin_result_2 = None, None + + if len(left_bounds) != 0 and len(right_bounds) != 0: + bin_result_1 = pd.cut(scores, pd.IntervalIndex.from_arrays(left_bounds, right_bounds, closed='left')) + + bin_result_2 = pd.cut(scores, 
pd.IntervalIndex.from_arrays([last_interval_left], [last_interval_right], + closed='both')) + + count1 = None if bin_result_1 is None else bin_result_1.value_counts().reset_index() + count2 = bin_result_2.value_counts().reset_index() + + # if predict scores are the same, count1 will be None, only one interval exists + final_interval = list(count1['index']) + list(count2['index']) if count1 is not None else list(count2['index']) + final_count = list(count1[0]) + list(count2[0]) if count1 is not None else list(count2[0]) + rs = {'interval': final_interval, 'count': final_count} + + return rs + + @staticmethod + def interval_psi_score(val): + expected, actual = val[0], val[1] + return (actual - expected) * np.log(actual / expected) + + @staticmethod + def intervals_to_str(intervals, round_num=3): + str_intervals = [] + for interval in intervals: + left_bound, right_bound = '[', ']' + if interval.closed == 'left': + right_bound = ')' + elif interval.closed == 'right': + left_bound = '(' + str_intervals.append("{}{}, {}{}".format(left_bound, round(interval.left, round_num), + round(interval.right, round_num), right_bound)) + + return str_intervals + + @staticmethod + def psi_score(expected_interval: np.ndarray, actual_interval: np.ndarray, expect_total_num, actual_total_num, + debug=False): + + expected_interval[expected_interval == 0] = 1e-6 # in case no overlap samples + + actual_interval[actual_interval == 0] = 1e-6 # in case no overlap samples + + expected_percentage = expected_interval / expect_total_num + actual_percentage = actual_interval / actual_total_num + + if debug: + print(expected_interval) + print(actual_interval) + print(expected_percentage) + print(actual_percentage) + + psi_scores = list(map(PSI.interval_psi_score, zip(expected_percentage, actual_percentage))) + psi_scores = np.array(psi_scores) + total_psi = psi_scores.sum() + return psi_scores, total_psi, expected_interval, actual_interval, expected_percentage, actual_percentage \ No newline at end of file diff --git a/python/fate/ml/evaluation/evaluation.py b/python/fate/ml/evaluation/evaluation.py deleted file mode 100644 index 9383543d84..0000000000 --- a/python/fate/ml/evaluation/evaluation.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from sklearn import metrics - - -class BinaryEvaluator(object): - def __init__(self): - self._auc = None - - def fit(self, ctx, y_true, y_pred): - fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) - self._auc = metrics.auc(fpr, tpr) - - self._report(ctx) - - def _report(self, ctx): - ctx.metrics.log_auc("auc", self._auc) diff --git a/python/fate/ml/evaluation/metric_base.py b/python/fate/ml/evaluation/metric_base.py new file mode 100644 index 0000000000..20766bcb8b --- /dev/null +++ b/python/fate/ml/evaluation/metric_base.py @@ -0,0 +1,134 @@ +from typing import Dict +from transformers import EvalPrediction +import pandas as pd +import torch +import numpy as np +from typing import Union +import json +import logging + +logger = logging.getLogger(__name__) + + +SINGLE_VALUE = 'single_value' +TABLE_VALUE = 'table_value' + + +class EvalResult(object): + + def __init__(self, metric_name: str, result: Union[int, float, pd.DataFrame]): + self.metric_name = metric_name + assert isinstance(self.metric_name, str), "metric_name must be a string." + if isinstance(result, (int, float)): + self.result = float(result) + self.result_type = SINGLE_VALUE + elif isinstance(result, pd.DataFrame): + if len(result.shape) == 2: + self.result = result + self.result_type = TABLE_VALUE + else: + raise ValueError("DataFrame must be a 2D table.") + else: + raise TypeError("Invalid type for result. Expected int, float or DataFrame.") + + def __repr__(self) -> str: + return self.result.__repr__() + + def to_dict(self): + return { + "metric": self.metric_name, + # "result_type": self.result_type, + "val": self.result.to_dict(orient='list') if self.result_type == TABLE_VALUE else self.result + } + + def to_json(self): + if self.result_type == TABLE_VALUE: + return self.result.to_json(orient='split') + else: + return json.dumps(self.to_dict()) + + def get_raw_data(self): + return self.result + + def __dict__(self): + return self.to_dict() + + +class Metric(object): + + metric_name = None + + def __init__(self, *args, **kwargs): + pass + + def __call__(self, predict, label, **kwargs) -> EvalResult: + pass + + def to_np_format(self, data, flatten=True): + + if isinstance(data, list): + ret = np.array(data) + elif isinstance(data, torch.Tensor): + ret = data.detach().cpu().numpy() + elif isinstance(data, pd.Series) or isinstance(data, pd.DataFrame): + ret = np.array(data.values.tolist()) + else: + ret = data + + if flatten: + ret = ret.flatten() + + return ret.astype(np.float64) + + +class MetricEnsemble(object): + + def __init__(self, to_dict=True) -> None: + self._metrics = [] + self._metric_suffix = set() + self._to_dict = to_dict + + def add_metric(self, metric: Metric): + self._metrics.append(metric) + return self + + def _parse_input(self, eval_rs): + if isinstance(eval_rs, EvalPrediction): + # parse hugging face format + predict = eval_rs.predictions + label = eval_rs.label_ids + input_ = eval_rs.inputs + + elif isinstance(eval_rs, tuple) and len(eval_rs) == 2: + # conventional format + predict, label = eval_rs + input_ = None + else: + raise ValueError('Unknown eval_rs format: {}. 
Expected input formats are either ' + 'an instance of EvalPrediction or a 2-tuple (predict, label).'.format(type(eval_rs))) + + return predict, label, input_ + + def __call__(self, eval_rs=None, predict=None, label=None, **kwargs) -> Dict: + + metric_result = [] + + if eval_rs is not None: + predict, label, input_ = self._parse_input(eval_rs) + + for metric in self._metrics: + rs = metric(predict, label) + if isinstance(rs, tuple): + new_rs = [r.to_dict() for r in rs] + rs = new_rs + elif isinstance(rs, EvalResult): + rs = rs.to_dict() + else: + raise ValueError('cannot parse metric result: {}'.format(rs)) + metric_result.append(rs) + return metric_result + + def fit(self, eval_rs=None, predict=None, label=None, **kwargs) -> Dict: + return self.__call__(eval_rs, predict, label, **kwargs) + + diff --git a/python/fate/ml/evaluation/regression.py b/python/fate/ml/evaluation/regression.py new file mode 100644 index 0000000000..77631f9ee0 --- /dev/null +++ b/python/fate/ml/evaluation/regression.py @@ -0,0 +1,49 @@ +from typing import Dict +import numpy as np +from fate.ml.evaluation.metric_base import Metric +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from fate.ml.evaluation.metric_base import EvalResult + + +class RMSE(Metric): + + metric_name = 'rmse' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + rmse = np.sqrt(mean_squared_error(label, predict)) + return EvalResult(self.metric_name, rmse) + + +class MSE(Metric): + + metric_name = 'mse' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + mse = mean_squared_error(label, predict) + return EvalResult(self.metric_name, mse) + + +class MAE(Metric): + + metric_name = 'mae' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + mae = mean_absolute_error(label, predict) + return EvalResult(self.metric_name, mae) + + +class R2Score(Metric): + + metric_name = 'r2_score' + + def __call__(self, predict, label, **kwargs) -> Dict: + predict = self.to_np_format(predict) + label = self.to_np_format(label) + r2 = r2_score(label, predict) + return EvalResult(self.metric_name, r2) diff --git a/python/fate/ml/evaluation/test/test_classi_metrci.py b/python/fate/ml/evaluation/test/test_classi_metrci.py new file mode 100644 index 0000000000..fa96a1e360 --- /dev/null +++ b/python/fate/ml/evaluation/test/test_classi_metrci.py @@ -0,0 +1,74 @@ +import unittest +import numpy as np +from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, precision_score, f1_score +from sklearn.preprocessing import LabelBinarizer +from fate.ml.evaluation.classification import * + + +class TestMetric(unittest.TestCase): + + def test_AUC(self): + auc_metric = AUC() + predict = np.random.random_sample(1000) + label = np.random.randint(0, 2, 1000) + result = auc_metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, 'auc') + self.assertAlmostEqual(result.result, roc_auc_score(label, predict), places=7) + + def test_MultiAccuracy(self): + multi_acc_metric = MultiAccuracy() + predict = np.random.random_sample((1000, 3)) + label = np.random.randint(0, 3, 1000) + result = multi_acc_metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, 'multi_accuracy') + self.assertAlmostEqual(result.result, accuracy_score(label, 
predict.argmax(axis=-1)), places=7) + + def test_MultiRecall(self): + multi_recall_metric = MultiRecall() + predict = np.random.random_sample((1000, 3)) + label = np.random.randint(0, 3, 1000) + result = multi_recall_metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, 'multi_recall') + self.assertAlmostEqual(result.result, recall_score(label, predict.argmax(axis=-1), average='micro'), places=7) + + def test_MultiPrecision(self): + multi_precision_metric = MultiPrecision() + predict = np.random.random_sample((1000, 3)) + label = np.random.randint(0, 3, 1000) + result = multi_precision_metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, 'multi_precision') + self.assertAlmostEqual(result.result, precision_score(label, predict.argmax(axis=-1), average='micro'), places=7) + + def test_MultiF1Score(self): + multi_f1_metric = MultiF1Score() + predict = np.random.random_sample((1000, 3)) + label = np.random.randint(0, 3, 1000) + result = multi_f1_metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, 'multi_f1_score') + self.assertAlmostEqual(result.result, f1_score(label, predict.argmax(axis=-1), average='micro'), places=7) + + def test_BinaryMetrics(self): + for metric_class, sklearn_metric in [ + (BinaryAccuracy, accuracy_score), + (BinaryRecall, recall_score), + (BinaryPrecision, precision_score), + (BinaryF1Score, f1_score) + ]: + with self.subTest(metric_class=metric_class): + metric = metric_class() + predict = np.random.random_sample(1000) + label = np.random.randint(0, 2, 1000) + binary_predict = (predict > metric.threshold).astype(int) + result = metric(predict, label) + print(result.to_dict()) + self.assertEqual(result.metric_name, metric.metric_name) + self.assertAlmostEqual(result.result, sklearn_metric(label, binary_predict), places=7) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/fate/ml/evaluation/test/test_classi_table_metric.py b/python/fate/ml/evaluation/test/test_classi_table_metric.py new file mode 100644 index 0000000000..1ecedc437b --- /dev/null +++ b/python/fate/ml/evaluation/test/test_classi_table_metric.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, precision_score, f1_score +from sklearn.preprocessing import LabelBinarizer +from fate.ml.evaluation.classification import * + + +def generate_predict_and_label(num): + predict = np.random.random_sample(num) + label = np.random.randint(0, 2, num) + return predict, label + + +class TestMetric(unittest.TestCase): + + def test_KS(self): + ks_metric = KS() + predict, label = generate_predict_and_label(1000) + result = ks_metric(predict, label) + print(result) + print(result[0].to_dict()) + print(result[1].to_dict()) + + def test_confusion_matrix(self): + confusion_matrix_metric = ConfusionMatrix() + predict, label = generate_predict_and_label(1000) + result = confusion_matrix_metric(predict, label) + print(result.to_dict()) + + def test_gain(self): + gain_metric = Gain() + predict, label = generate_predict_and_label(1000) + result = gain_metric(predict, label) + print(result.to_dict()) + + def test_lift(self): + lift_metric = Lift() + predict, label = generate_predict_and_label(1000) + result = lift_metric(predict, label) + print(result.to_dict()) + + def test_bi_acc(self): + bi_acc_metric = BiClassAccuracyTable() + predict, label = generate_predict_and_label(1000) + result = bi_acc_metric(predict, label) + 
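+ # Editor's note (illustrative, not part of the original patch): table-style metrics such as
+ # BiClassAccuracyTable wrap a 2D pandas DataFrame in EvalResult, so to_dict() below returns
+ # a mapping of the form {"metric": <metric_name>, "val": {<column>: [...], ...}} as defined
+ # by EvalResult.to_dict() in metric_base.py; the exact column names come from the classification module.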
print(result.to_dict()) + + def test_psi(self): + psi_metric = PSI() + predict, label = generate_predict_and_label(1000) + predict2, label2 = generate_predict_and_label(1000) + result = psi_metric({'train_scores': predict, 'validate_scores': predict2}, {'train_labels': label, 'validate_labels': label2}) + print('result is {}'.format(result)) + print(result[0].to_dict()) + print(result[1].to_dict()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/fate/ml/evaluation/test/test_metric_ensembles.py b/python/fate/ml/evaluation/test/test_metric_ensembles.py new file mode 100644 index 0000000000..6d360e4bd1 --- /dev/null +++ b/python/fate/ml/evaluation/test/test_metric_ensembles.py @@ -0,0 +1,45 @@ +import numpy as np +import unittest +from fate.ml.evaluation import classification as classi +from fate.ml.evaluation.metric_base import MetricEnsemble +from fate.ml.evaluation.tool import get_binary_metrics, get_specified_metrics, get_single_val_binary_metrics + + +def generate_predict_and_label(num): + predict = np.random.random_sample(num) + label = np.random.randint(0, 2, num) + return predict, label + + +class TestMetric(unittest.TestCase): + + def test_binary_ensemble(self): + + binary_ensemble = get_binary_metrics() + + predict, label = generate_predict_and_label(1000) + result = binary_ensemble(predict=predict, label=label) + print(result) + print(type(result)) + + + def test_selected(self): + metrics = get_specified_metrics(['auc', 'ks', 'lift', 'gain', 'binary_accuracy']) + predict, label = generate_predict_and_label(1000) + result = metrics(predict=predict, label=label) + print(result) + print(type(result)) + + def test_single_val_binary_ensemble(self): + + binary_ensemble = get_single_val_binary_metrics(threshold=0.8) + + predict, label = generate_predict_and_label(1000) + result = binary_ensemble(predict=predict, label=label) + print(result) + print(type(result)) + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/python/fate/ml/evaluation/test/test_reg_metric.py b/python/fate/ml/evaluation/test/test_reg_metric.py new file mode 100644 index 0000000000..10e5566cf8 --- /dev/null +++ b/python/fate/ml/evaluation/test/test_reg_metric.py @@ -0,0 +1,51 @@ +import unittest +import numpy as np +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from fate.ml.evaluation.regression import * + +class TestMetric(unittest.TestCase): + + def test_RMSE(self): + rmse_metric = RMSE() + predict = np.random.random_sample(1000) * 10 + label = np.random.random_sample(1000) * 10 + result = rmse_metric(predict, label) + self.assertEqual(result.metric_name, 'rmse') + self.assertAlmostEqual(result.result, np.sqrt(mean_squared_error(label, predict)), places=7) + + def test_MSE(self): + mse_metric = MSE() + predict = np.random.random_sample(1000) * 10 + label = np.random.random_sample(1000) * 10 + result = mse_metric(predict, label) + self.assertEqual(result.metric_name, 'mse') + self.assertAlmostEqual(result.result, mean_squared_error(label, predict), places=7) + + def test_MAE(self): + mae_metric = MAE() + predict = np.random.random_sample(1000) * 10 + label = np.random.random_sample(1000) * 10 + result = mae_metric(predict, label) + self.assertEqual(result.metric_name, 'mae') + self.assertAlmostEqual(result.result, mean_absolute_error(label, predict), places=7) + + def test_R2Score(self): + r2_metric = R2Score() + predict = np.random.random_sample(1000) * 10 + label = np.random.random_sample(1000) * 10 + result = r2_metric(predict,
label) + self.assertEqual(result.metric_name, 'r2_score') + self.assertAlmostEqual(result.result, r2_score(label, predict), places=7) + + def test_to_dict(self): + metrics = [RMSE(), MSE(), MAE(), R2Score()] + predict = np.random.random_sample(1000) * 10 + label = np.random.random_sample(1000) * 10 + for m in metrics: + print(m(predict, label).to_dict()) + + + + +if __name__ == '__main__': + unittest.main() diff --git a/python/fate/ml/evaluation/tool.py b/python/fate/ml/evaluation/tool.py new file mode 100644 index 0000000000..9e48e6e5fb --- /dev/null +++ b/python/fate/ml/evaluation/tool.py @@ -0,0 +1,73 @@ +import inspect +from fate.ml.evaluation import classification as classi +from fate.ml.evaluation import regression as reg +from fate.ml.evaluation.metric_base import Metric, MetricEnsemble + + +def get_metric_names(modules): + result = {} + + for module in modules: + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj): + if hasattr(obj, 'metric_name') and issubclass(obj, Metric): + metric_name = getattr(obj, 'metric_name') + if metric_name is not None: + result[metric_name] = obj + + return result + + +def all_available_metrics(): + return get_metric_names([classi, reg]) + + +def get_single_val_binary_metrics(threshold=0.5): + + binary_ensembles = MetricEnsemble() + binary_ensembles.add_metric(classi.AUC()).add_metric(classi.BinaryAccuracy(threshold=threshold)).add_metric(classi.BinaryF1Score(threshold=threshold)) + binary_ensembles.add_metric(classi.BinaryPrecision(threshold=threshold)).add_metric(classi.BinaryRecall(threshold=threshold)) + return binary_ensembles + + +def get_binary_metrics(): + + binary_ensembles = MetricEnsemble() + binary_ensembles.add_metric(classi.AUC()).add_metric(classi.KS()).add_metric(classi.ConfusionMatrix()) + binary_ensembles.add_metric(classi.Gain()).add_metric(classi.Lift()) + binary_ensembles.add_metric(classi.BiClassPrecisionTable()).add_metric(classi.BiClassRecallTable()) + binary_ensembles.add_metric(classi.BiClassAccuracyTable()).add_metric(classi.FScoreTable()) + return binary_ensembles + + +def get_multi_metrics(): + + multi_ensembles = MetricEnsemble() + multi_ensembles.add_metric(classi.MultiAccuracy()).add_metric(classi.MultiPrecision()).add_metric(classi.MultiRecall()) + + return multi_ensembles + + +def get_regression_metrics(): + + regression_ensembles = MetricEnsemble() + regression_ensembles.add_metric(reg.RMSE()).add_metric(reg.MAE()).add_metric(reg.MSE()).add_metric(reg.R2Score()) + return regression_ensembles + + +def get_special_metrics(): + # metrics that need special input format like PSI + ensembles = MetricEnsemble() + ensembles.add_metric(classi.PSI()) + return ensembles + + +def get_specified_metrics(metric_names: list): + ensembles = MetricEnsemble() + available_metrics = get_metric_names([classi, reg]) + for metric_name in metric_names: + if metric_name in available_metrics: + ensembles.add_metric(get_metric_names([classi, reg])[metric_name]()) + else: + raise ValueError(f"metric {metric_name} is not supported yet, supported metrics are \n {list(available_metrics.keys())}") + return ensembles \ No newline at end of file diff --git a/python/fate/arch/context/metric/__init__.py b/python/fate/ml/feature_binning/__init__.py similarity index 79% rename from python/fate/arch/context/metric/__init__.py rename to python/fate/ml/feature_binning/__init__.py index 400736af0a..37b4fd0aa7 100644 --- a/python/fate/arch/context/metric/__init__.py +++ b/python/fate/ml/feature_binning/__init__.py @@ -1,5 +1,5 @@ # -# 
Copyright 2019 The FATE Authors. All Rights Reserved. +# Copyright 2023 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,5 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._type import InCompleteMetrics, Metric, Metrics -from ._wrap import MetricsWrap + +from .hetero_feature_binning import HeteroBinningModuleGuest, HeteroBinningModuleHost diff --git a/python/fate/ml/feature_binning/hetero_feature_binning.py b/python/fate/ml/feature_binning/hetero_feature_binning.py new file mode 100644 index 0000000000..0568790993 --- /dev/null +++ b/python/fate/ml/feature_binning/hetero_feature_binning.py @@ -0,0 +1,460 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import numpy as np +import pandas as pd + +from fate.arch import Context +from fate.arch.histogram import HistogramBuilder +from ..abc.module import HeteroModule, Module + +logger = logging.getLogger(__name__) + + +class HeteroBinningModuleGuest(HeteroModule): + def __init__( + self, + method="quantile", + n_bins=10, + split_pt_dict=None, + bin_col=None, + transform_method=None, + category_col=None, + local_only=False, + error_rate=1e-6, + adjustment_factor=0.5, + ): + self.method = method + self.bin_col = bin_col + self.category_col = category_col + self.n_bins = n_bins + self._federation_bin_obj = None + # param check + if self.method in ["quantile", "bucket", "manual"]: + self._bin_obj = StandardBinning( + method, n_bins, split_pt_dict, bin_col, transform_method, category_col, error_rate, adjustment_factor + ) + else: + raise ValueError(f"{self.method} binning method not supported, please check") + self.local_only = local_only + + def set_transform_method(self, transform_method): + self._bin_obj.transform_method = transform_method + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + logger.info("Enter HeteroBinning fit.") + + train_data_binarized_label = train_data.label.get_dummies() + label_count = train_data_binarized_label.shape[1] + if label_count > 2: + raise ValueError( + f"More than 2 classes found in label column. " + f"HeteroBinning currently only supports binary data. Please check." 
+ ) + + self._bin_obj.fit(ctx, train_data) + + def compute_metrics(self, ctx: Context, binned_data): + # label_tensor = binned_data.label.as_tensor() + self._bin_obj.compute_metrics(binned_data) + if not self.local_only: + self.compute_federated_metrics(ctx, binned_data) + + def compute_federated_metrics(self, ctx: Context, binned_data): + logger.info(f"Start computing federated metrics.") + kit = ctx.cipher.phe.setup() + encryptor = kit.get_tensor_encryptor() + sk, pk, evaluator, coder = kit.sk, kit.pk, kit.evaluator, kit.coder + + label_tensor = binned_data.label.as_tensor() + ctx.hosts.put("enc_y", encryptor.encrypt_tensor(label_tensor)) + ctx.hosts.put("pk", pk) + ctx.hosts.put("evaluator", evaluator) + ctx.hosts.put("coder", coder) + host_col_bin = ctx.hosts.get("anonymous_col_bin") + host_event_non_event_count = ctx.hosts.get("event_non_event_count") + host_bin_sizes = ctx.hosts.get("feature_bin_sizes") + for i, (col_bin_list, bin_sizes, en_host_count_res) in enumerate(zip(host_col_bin, + host_bin_sizes, + host_event_non_event_count)): + host_event_non_event_count_hist = en_host_count_res.decrypt({"event_count": sk, + "non_event_count": sk}, + {"event_count": (coder, None), + "non_event_count": (coder, None)}) + host_event_non_event_count_hist = host_event_non_event_count_hist.reshape(bin_sizes) + summary_metrics, _ = self._bin_obj.compute_all_col_metrics(host_event_non_event_count_hist, + col_bin_list) + self._bin_obj.set_host_metrics(ctx.hosts[i], summary_metrics) + + def transform(self, ctx: Context, test_data): + transformed_data = self._bin_obj.transform(ctx, test_data) + return transformed_data + + def get_model(self): + model_info = self._bin_obj.to_model() + model = { + "data": model_info, + "meta": { + "method": self.method, + "metrics": ["iv"] if model_info.get("metrics_summary") else [], + "local_only": self.local_only, + "bin_col": self.bin_col, + "category_col": self.category_col, + "model_type": "binning", + "n_bins": self.n_bins, + }, + } + return model + + def restore(self, model): + self._bin_obj.restore(model) + + @classmethod + def from_model(cls, model) -> "HeteroBinningModuleGuest": + bin_obj = HeteroBinningModuleGuest( + method=model["meta"]["method"], + bin_col=model["meta"]["bin_col"], + category_col=model["meta"]["category_col"], + n_bins=model["meta"]["n_bins"], + ) + bin_obj.restore(model["data"]) + return bin_obj + + +class HeteroBinningModuleHost(HeteroModule): + def __init__( + self, + method="quantile", + n_bins=10, + split_pt_dict=None, + bin_col=None, + transform_method=None, + category_col=None, + local_only=False, + error_rate=1e-6, + adjustment_factor=0.5, + ): + self.method = method + self.n_bins = n_bins + self._federation_bin_obj = None + if self.method in ["quantile", "bucket", "manual"]: + self._bin_obj = StandardBinning( + method, n_bins, split_pt_dict, bin_col, transform_method, category_col, error_rate, adjustment_factor + ) + self.local_only = local_only + self.bin_col = bin_col + self.category_col = category_col + + def set_transform_method(self, new_transform_method): + self._bin_obj.transform_method = new_transform_method + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + logger.info("Enter HeteroBinning fit.") + self._bin_obj.fit(ctx, train_data) + + def compute_metrics(self, ctx: Context, binned_data): + if not self.local_only: + self.compute_federated_metrics(ctx, binned_data) + + def compute_federated_metrics(self, ctx: Context, binned_data): + logger.info(f"Start computing federated metrics.") + pk = 
ctx.guest.get("pk") + evaluator = ctx.guest.get("evaluator") + coder = ctx.guest.get("coder") + columns = binned_data.schema.columns.to_list() + # logger.info(f"self.bin_col: {self.bin_col}") + to_compute_col = self.bin_col + self.category_col + anonymous_col_bin = [binned_data.schema.anonymous_columns[columns.index(col)] for col in to_compute_col] + + ctx.guest.put("anonymous_col_bin", anonymous_col_bin) + encrypt_y = ctx.guest.get("enc_y") + # event count: + feature_bin_sizes = [self._bin_obj._bin_count_dict[col] for col in self.bin_col] + if self.category_col: + for col in self.category_col: + category_bin_size = binned_data[col].get_dummies().shape[1] + feature_bin_sizes.append(category_bin_size) + to_compute_data = binned_data[to_compute_col] + to_compute_data.rename( + columns=dict(zip(to_compute_data.schema.columns, to_compute_data.schema.anonymous_columns)) + ) + hist_targets = binned_data.create_frame() + hist_targets["event_count"] = encrypt_y + hist_targets["non_event_count"] = 1 + dtypes = hist_targets.dtypes + + hist_schema = {"event_count": {"type": "ciphertext", + "stride": 1, + "pk": pk, + "evaluator": evaluator, + "coder": coder, + "dtype": dtypes["event_count"], + }, + "non_event_count": {"type": "plaintext", + "stride": 1, + "dtype": dtypes["non_event_count"]} + } + hist = HistogramBuilder(num_node=1, + feature_bin_sizes=feature_bin_sizes, + value_schemas=hist_schema, + enable_cumsum=False) + event_non_event_count_hist = to_compute_data.distributed_hist_stat(histogram_builder=hist, + targets=hist_targets) + event_non_event_count_hist.i_sub_on_key("non_event_count", "event_count") + ctx.guest.put("event_non_event_count", (event_non_event_count_hist)) + ctx.guest.put("feature_bin_sizes", feature_bin_sizes) + + def transform(self, ctx: Context, test_data): + return self._bin_obj.transform(ctx, test_data) + + def get_model(self): + model_info = self._bin_obj.to_model() + model = { + "data": model_info, + "meta": { + "method": self.method, + "bin_col": self.bin_col, + "category_col": self.category_col, + "n_bins": self.n_bins, + "model_type": "binning", + }, + } + return model + + def restore(self, model): + self._bin_obj.restore(model) + + @classmethod + def from_model(cls, model) -> "HeteroBinningModuleHost": + bin_obj = HeteroBinningModuleHost( + method=model["meta"]["method"], + bin_col=model["meta"]["bin_col"], + category_col=model["meta"]["category_col"], + n_bins=model["meta"]["n_bins"], + ) + bin_obj.restore(model["data"]) + return bin_obj + + +class StandardBinning(Module): + def __init__( + self, method, n_bins, split_pt_dict, bin_col, transform_method, category_col, error_rate, adjustment_factor + ): + self.method = method + self.n_bins = n_bins + # cols to be binned + self.bin_col = bin_col + # cols to be treated as categorical feature + self.category_col = category_col + self.transform_method = transform_method + self.relative_error = error_rate + self.adjustment_factor = adjustment_factor + self._manual_split_pt_dict = split_pt_dict + # {col_name: [split_pts]}, ordered by split_pts + self._split_pt_dict = None + self._bin_idx_dict = None + # {col_name: [bin_num]}, ordered by split_pts + self._bin_count_dict = None + # {col_name: [woe]}, ordered by split_pts, for transform + self._woe_dict = None + # for prediction transform + self._train_woe_dict = None + # {col_name: {"iv_array": [], "woe": [], "event_count": []...}} + self._metrics_summary = None + self._host_metrics_summary = None + self._train_metrics_summary = None + self._train_host_metrics_summary = None 
+ + def set_host_metrics(self, host, metrics_summary): + if self._host_metrics_summary is None: + self._host_metrics_summary = {} + self._host_metrics_summary[host.name] = metrics_summary + + def fit(self, ctx: Context, train_data, validate_data=None, skip_none=False): + # only bin given `col_bin` cols + if self.bin_col is None: + self.bin_col = train_data.schema.columns.to_list() + select_data = train_data[self.bin_col] + + if self.method == "quantile": + q = list(np.arange(0, 1, 1 / self.n_bins)) + [1.0] + split_pt_df = select_data.quantile(q=q, relative_error=self.relative_error).drop(0) + elif self.method == "bucket": + split_pt_df = select_data.qcut(q=self.n_bins) + elif self.method == "manual": + split_pt_df = pd.DataFrame.from_dict(self._manual_split_pt_dict) + else: + raise ValueError(f"Unknown binning method {self.method} encountered. Please check") + # self._split_pt_dict = split_pt_df.to_dict() + self._split_pt_dict = split_pt_df + + def __get_col_bin_count(col): + count = len(col.unique()) + return count + + bin_count = split_pt_df.apply(__get_col_bin_count, axis=0) + self._bin_count_dict = bin_count.to_dict() + + def bucketize_data(self, train_data): + binned_df = train_data.bucketize(boundaries=self._split_pt_dict) + return binned_df + + def compute_all_col_metrics(self, event_non_event_count_hist, columns): + event_non_event_count = event_non_event_count_hist.to_dict(columns)[0] + non_event_count_dict = event_non_event_count.get("non_event_count") + event_count_dict = event_non_event_count.get("event_count") + + event_count, non_event_count = {}, {} + event_rate, non_event_rate = {}, {} + bin_woe, bin_iv, is_monotonic, iv = {}, {}, {}, {} + total_event_count, total_non_event_count = None, None + for col_name in event_count_dict.keys(): + col_event_count = pd.Series( + {bin_num: int(bin_count.data) for bin_num, bin_count in event_count_dict[col_name].items()} + ) + col_non_event_count = pd.Series( + {bin_num: int(bin_count.data) for bin_num, bin_count in non_event_count_dict[col_name].items()} + ) + if total_event_count is None: + total_event_count = col_event_count.sum() or 1 + total_non_event_count = col_non_event_count.sum() or 1 + col_event_rate = (col_event_count == 0) * self.adjustment_factor + col_event_count / total_event_count + col_non_event_rate = ( + col_non_event_count == 0 + ) * self.adjustment_factor + col_non_event_count / total_non_event_count + col_rate_ratio = col_event_rate / col_non_event_rate + col_bin_woe = col_rate_ratio.apply(lambda v: np.log(v)) + col_bin_iv = (col_event_rate - col_non_event_rate) * col_bin_woe + + event_count[col_name] = col_event_count.to_dict() + non_event_count[col_name] = col_non_event_count.to_dict() + event_rate[col_name] = col_event_rate.to_dict() + non_event_rate[col_name] = col_non_event_rate.to_dict() + bin_woe[col_name] = col_bin_woe.to_dict() + bin_iv[col_name] = col_bin_iv.to_dict() + is_monotonic[col_name] = col_bin_woe.is_monotonic_increasing or col_bin_woe.is_monotonic_decreasing + iv[col_name] = col_bin_iv[1:].sum() + + metrics_summary = {} + + metrics_summary["event_count"] = event_count + metrics_summary["non_event_count"] = non_event_count + metrics_summary["event_rate"] = event_rate + metrics_summary["non_event_rate"] = non_event_rate + metrics_summary["woe"] = bin_woe + metrics_summary["iv_array"] = bin_iv + metrics_summary["is_monotonic"] = is_monotonic + metrics_summary["iv"] = iv + return metrics_summary, bin_woe + + def compute_metrics(self, binned_data): + to_compute_col = self.bin_col + self.category_col 
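+ # Editor's note: binned columns and categorical columns are concatenated here so that the
+ # single distributed histogram pass below counts events per bin for both kinds of features;
+ # categorical bin sizes are taken from the number of dummy categories.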
+ to_compute_data = binned_data[to_compute_col] + + feature_bin_sizes = [self._bin_count_dict[col] for col in self.bin_col] + if self.category_col: + for col in self.category_col: + category_bin_size = binned_data[col].get_dummies().shape[1] + feature_bin_sizes.append(category_bin_size) + hist_targets = binned_data.create_frame() + hist_targets["event_count"] = binned_data.label + hist_targets["non_event_count"] = 1 + dtypes = hist_targets.dtypes + hist_schema = {"event_count": {"type": "plaintext", + "stride": 1, + "dtype": dtypes["event_count"]}, + "non_event_count": {"type": "plaintext", + "stride": 1, + "dtype": dtypes["non_event_count"]} + } + hist = HistogramBuilder(num_node=1, + feature_bin_sizes=feature_bin_sizes, + value_schemas=hist_schema, + enable_cumsum=False) + event_non_event_count_hist = to_compute_data.distributed_hist_stat(histogram_builder=hist, + targets=hist_targets) + event_non_event_count_hist.i_sub_on_key("non_event_count", "event_count") + event_non_event_count_hist = event_non_event_count_hist.decrypt({}, {}).reshape(feature_bin_sizes) + self._metrics_summary, self._woe_dict = self.compute_all_col_metrics(event_non_event_count_hist, + to_compute_col) + + def transform(self, ctx: Context, binned_data): + logger.debug(f"Given transform method: {self.transform_method}.") + if self.transform_method == "bin_idx" and self._bin_idx_dict: + return binned_data + elif self.transform_method == "woe": + if ctx.is_on_host: + raise ValueError(f"host does not support 'woe' transform method, please use 'bin_idx'.") + # predict: replace with woe from train phase + to_transform_data = binned_data[self.bin_col] + if self._train_woe_dict: + logger.debug(f"`train_woe_dict` provided, will transform to woe values from training phase.") + binned_data[self.bin_col] = to_transform_data.replace(self._train_woe_dict) + # return binned_data.replace(self._train_woe_dict, self.bin_col) + elif self._woe_dict: + binned_data[self.bin_col] = to_transform_data.replace(self._woe_dict) + # return binned_data.replace(self._woe_dict, self.bin_col) + else: + logger.warning( + f"to transform type {self.transform_method} encountered, but no bin tag dict provided. 
" + f"Please check" + ) + return binned_data + + def to_model(self): + return dict( + method=self.method, + bin_col=self.bin_col, + split_pt_dict=self._split_pt_dict.to_dict(), + bin_idx_dict=self._bin_idx_dict, + bin_count_dict=self._bin_count_dict, + metrics_summary=self._metrics_summary, + train_metrics_summary=self._train_metrics_summary, + host_metrics_summary=self._host_metrics_summary, + train_host_metrics_summary=self._train_host_metrics_summary, + woe_dict=self._woe_dict, + category_col=self.category_col, + adjustment_factor=self.adjustment_factor + # transform_method = self.transform_method, + ) + + def restore(self, model): + self.method = model["method"] + self.bin_col = model["bin_col"] + # self.transform_method = model["transform_method"] + self._split_pt_dict = pd.DataFrame.from_dict(model["split_pt_dict"]) + self._bin_idx_dict = model["bin_idx_dict"] + self._bin_count_dict = model["bin_count_dict"] + # load predict model + if model.get("train_metrics_summary"): + self._metrics_summary = model["metrics_summary"] + self._train_metrics_summary = model["train_metrics_summary"] + else: + self._train_metrics_summary = model["metrics_summary"] + if model.get("train_host_metrics_summary"): + self._host_metrics_summary = model["host_metrics_summary"] + self._train_host_metrics_summary = model["train_host_metrics_summary"] + else: + self._train_host_metrics_summary = model["host_metrics_summary"] + if model.get("train_woe_dict"): + self._woe_dict = model["woe_dict"] + self._train_woe_dict = model["train_woe_dict"] + else: + self._train_woe_dict = model["woe_dict"] + + self.category_col = model["category_col"] + self.adjustment_factor = model["adjustment_factor"] diff --git a/python/fate/ml/feature_scale/feature_scale.py b/python/fate/ml/feature_scale/feature_scale.py deleted file mode 100644 index fe37dd2f37..0000000000 --- a/python/fate/ml/feature_scale/feature_scale.py +++ /dev/null @@ -1,75 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import json -import logging - -import pandas as pd -from fate.interface import Context - -from ..abc.module import Module - -logger = logging.getLogger(__name__) - - -class FeatureScale(Module): - def __init__(self, method="standard"): - self.method = method - self._scaler = None - if self.method == "standard": - self._scaler = StandardScaler() - - def fit(self, ctx: Context, train_data, validate_data=None) -> None: - self._scaler.fit(ctx, train_data) - - def transform(self, ctx: Context, test_data): - return self._scaler.transform(ctx, test_data) - - def to_model(self): - scaler_info = self._scaler.to_model() - return dict(scaler_info=scaler_info, method=self.method) - - def restore(self, model): - self._scaler.from_model(model) - - @classmethod - def from_model(cls, model) -> "FeatureScale": - scaler = FeatureScale(model["method"]) - scaler.restore(model["scaler_info"]) - return scaler - - -class StandardScaler(Module): - def __init__(self): - self._mean = None - self._std = None - - def fit(self, ctx: Context, train_data, validate_data=None) -> None: - self._mean = train_data.mean() - self._std = train_data.std() - - def transform(self, ctx: Context, test_data): - return (test_data - self._mean) / self._std - - def to_model(self): - return dict( - mean=self._mean.to_json(), - mean_dtype=self._mean.dtype.name, - std=self._std.to_json(), - std_dtype=self._std.dtype.name, - ) - - def from_model(self, model): - self._mean = pd.Series(json.loads(model["mean"]), dtype=model["mean_dtype"]) - self._std = pd.Series(json.loads(model["std"]), dtype=model["std_dtype"]) diff --git a/python/fate/ml/feature_selection/__init__.py b/python/fate/ml/feature_selection/__init__.py new file mode 100644 index 0000000000..3eb1f74a6b --- /dev/null +++ b/python/fate/ml/feature_selection/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hetero_feature_selection import HeteroSelectionModuleHost, HeteroSelectionModuleGuest diff --git a/python/fate/ml/feature_selection/hetero_feature_selection.py b/python/fate/ml/feature_selection/hetero_feature_selection.py new file mode 100644 index 0000000000..5e38691dcc --- /dev/null +++ b/python/fate/ml/feature_selection/hetero_feature_selection.py @@ -0,0 +1,560 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
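+
+ """Editor's illustrative overview; not part of the original patch.
+
+ Guest and host each run a cascade of selection filters ("manual", "iv", "statistics");
+ every filter refines the boolean column mask produced by the one before it. A rough
+ usage sketch with made-up values (parameter names follow the constructors below):
+
+     binning_model = ...  # model dict exported by HeteroBinningModuleGuest.get_model()
+     selector = HeteroSelectionModuleGuest(
+         method=["iv"],
+         iv_param={"metrics": ["iv"], "filter_type": ["top_k"], "threshold": [10],
+                   "take_high": [True], "select_federated": True},
+         input_models=[binning_model],
+     )
+     selector.fit(ctx, train_data)
+     selected_data = selector.transform(ctx, validate_data)
+ """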
+ +import copy +import json +import logging +import math +import random + +import numpy as np +import pandas as pd + +from fate.arch import Context +from ..abc.module import HeteroModule, Module + +logger = logging.getLogger(__name__) + +DEFAULT_METRIC = {"iv": ["iv"], "statistics": ["mean"]} + + +class HeteroSelectionModuleGuest(HeteroModule): + def __init__( + self, + method=None, + select_col=None, + input_models=None, + iv_param=None, + statistic_param=None, + manual_param=None, + keep_one=True, + ): + self.method = method + self.select_col = select_col + self.iv_param = iv_param + self.statistic_param = statistic_param + self.manual_param = manual_param + self.keep_one = keep_one + + # keep selection history + self._inner_method = [] + self._selection_obj = [] + + self.isometric_model_dict = None + if input_models: + isometric_model_dict = {} + for model in input_models: + model_type = model["meta"].get("model_type") + if model_type is None: + raise ValueError(f"Missing 'model_type' in input model") + isometric_model_dict[model_type] = model + self.isometric_model_dict = isometric_model_dict + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + # logger.info(f"isometric_model_dict: {self.isometric_model_dict}") + if self.select_col is None: + self.select_col = train_data.schema.columns.to_list() + + select_data = train_data[self.select_col] + header = select_data.schema.columns.to_list() + for i, filter_type in enumerate(self.method): + if filter_type == "manual": + selection_obj = ManualSelection( + method=filter_type, header=header, param=self.manual_param, keep_one=self.keep_one + ) + elif filter_type == "iv": + model = self.isometric_model_dict.get("binning", None) + if model is None: + raise ValueError(f"Cannot find binning model in input, please check") + selection_obj = StandardSelection( + method=filter_type, header=header, param=self.iv_param, model=model, keep_one=self.keep_one + ) + elif filter_type == "statistics": + model = self.isometric_model_dict.get("statistics", None) + if model is None: + raise ValueError(f"Cannot find statistics model in input, please check") + selection_obj = StandardSelection( + method=filter_type, header=header, param=self.statistic_param, model=model, keep_one=self.keep_one + ) + else: + raise ValueError(f"{filter_type} selection method not supported, please check") + self._selection_obj.append(selection_obj) + self._inner_method.append(filter_type) + + prev_selection_obj = None + for method, selection_obj in zip(self._inner_method, self._selection_obj): + if prev_selection_obj: + selection_obj.set_prev_selected_mask(copy.deepcopy(prev_selection_obj.selected_mask)) + if isinstance(selection_obj, StandardSelection) and isinstance(prev_selection_obj, StandardSelection): + selection_obj.set_host_prev_selected_mask(copy.deepcopy(prev_selection_obj._host_selected_mask)) + selection_obj.fit(ctx, select_data) + if method == "iv": + if self.iv_param.get("select_federated"): + HeteroSelectionModuleGuest.sync_select_federated(ctx, selection_obj) + prev_selection_obj = selection_obj + + @staticmethod + def sync_select_federated(ctx: Context, selection_obj): + logger.info(f"Sync federated selection.") + for i, host in enumerate(ctx.hosts): + federated_mask = selection_obj._host_selected_mask[host.name] + ctx.hosts[i].put(f"selected_mask_{selection_obj.method}", federated_mask) + + def transform(self, ctx: Context, test_data): + transformed_data = self._selection_obj[-1].transform(ctx, test_data) + return transformed_data + + def 
get_model(self): + # all selection obj need to be recorded for display of cascade order + selection_obj_list = [] + for selection_obj in self._selection_obj: + selection_obj_list.append(selection_obj.to_model()) + data = {"selection_obj_list": json.dumps(selection_obj_list), "inner_method": self._inner_method} + meta = {"method": self.method, "select_col": self.select_col, "keep_one": self.keep_one} + return {"data": data, "meta": meta} + + def restore(self, model): + selection_obj_list = [] + selection_obj_model_list = json.loads(model["selection_obj_list"]) + for i, selection_model in enumerate(selection_obj_model_list): + if selection_model["method"] in ["manual"]: + selection_obj = ManualSelection(method=self._inner_method[i]) + else: + selection_obj = StandardSelection(method=self._inner_method[i]) + selection_obj.restore(selection_model) + selection_obj_list.append(selection_obj) + self._selection_obj = selection_obj_list + + @classmethod + def from_model(cls, model) -> "HeteroSelectionModuleGuest": + selection_obj = HeteroSelectionModuleGuest(model["meta"]["method"], model["meta"]["select_col"]) + selection_obj._inner_method = model["data"]["inner_method"] + selection_obj.restore(model["data"]) + return selection_obj + + +class HeteroSelectionModuleHost(HeteroModule): + def __init__( + self, + method=None, + select_col=None, + input_models=None, + iv_param=None, + statistic_param=None, + manual_param=None, + keep_one=True, + ): + self.method = method + self.iv_param = iv_param + self.statistic_param = statistic_param + self.manual_param = manual_param + self.keep_one = keep_one + self.select_col = select_col + # keep selection history + self._inner_method = [] + self._selection_obj = [] + + self.isometric_model_dict = None + if input_models: + isometric_model_dict = {} + for model in input_models: + model_type = model["meta"].get("model_type") + if model_type is None: + raise ValueError(f"Missing 'model_type' in input model") + isometric_model_dict[model_type] = model + self.isometric_model_dict = isometric_model_dict + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + if self.select_col is None: + self.select_col = train_data.schema.columns.to_list() + select_data = train_data[self.select_col] + header = select_data.schema.columns.to_list() + for i, filter_type in enumerate(self.method): + if filter_type == "manual": + selection_obj = ManualSelection( + method=filter_type, header=header, param=self.manual_param, keep_one=self.keep_one + ) + elif filter_type == "iv": + model = self.isometric_model_dict.get("binning", None) + if model is None: + raise ValueError(f"Cannot find binning model in input, please check") + selection_obj = StandardSelection( + method=filter_type, header=header, param=self.iv_param, model=model, keep_one=self.keep_one + ) + elif filter_type == "statistics": + model = self.isometric_model_dict.get("statistics", None) + if model is None: + raise ValueError(f"Cannot find statistics model in input, please check") + selection_obj = StandardSelection( + method=filter_type, header=header, param=self.statistic_param, model=model, keep_one=self.keep_one + ) + + else: + raise ValueError(f"{type} selection method not supported, please check") + self._selection_obj.append(selection_obj) + self._inner_method.append(filter_type) + + prev_selection_obj = None + for method, selection_obj in zip(self._inner_method, self._selection_obj): + if prev_selection_obj: + selection_obj.set_prev_selected_mask(copy.deepcopy(prev_selection_obj.selected_mask)) + 
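+ # Editor's note: filters run in the order listed in `method`; each one starts from the mask
+ # left by the previous filter, so the selected column set can only shrink, apart from the
+ # keep_one fallback.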
selection_obj.fit(ctx, train_data, validate_data) + if method == "iv": + if self.iv_param.get("select_federated"): + HeteroSelectionModuleHost.sync_select_federated(ctx, selection_obj, train_data) + prev_selection_obj = selection_obj + + @staticmethod + def sync_select_federated(ctx: Context, selection_obj, data): + cur_selected_mask = ctx.guest.get(f"selected_mask_{selection_obj.method}") + # logger.info(f"cur_selected_mask: {cur_selected_mask}") + columns, anonymous_columns = data.schema.columns, data.schema.anonymous_columns + # logger.info(f"anonymous columns: {data.schema.anonymous_columns}") + new_index = [columns[anonymous_columns.get_loc(col)] for col in cur_selected_mask.index] + cur_selected_mask.index = new_index + prev_selected_mask = selection_obj._prev_selected_mask[selection_obj._prev_selected_mask] + missing_col = set(prev_selected_mask.index).difference(set(new_index)) + if missing_col: + raise ValueError(f"results for columns: {missing_col} not found in received selection result.") + cur_selected_mask = [cur_selected_mask.get(col, False) for col in selection_obj._header] + selected_mask = selection_obj._prev_selected_mask & cur_selected_mask + selection_obj.set_selected_mask(selected_mask) + + def transform(self, ctx: Context, test_data): + transformed_data = self._selection_obj[-1].transform(ctx, test_data) + return transformed_data + + def get_model(self): + # all selection history need to be recorded for display + selection_obj_list = [] + for selection_obj in self._selection_obj: + selection_obj_list.append(selection_obj.to_model()) + + data = {"selection_obj_list": json.dumps(selection_obj_list), "inner_method": self._inner_method} + meta = {"method": self.method, "select_col": self.select_col, "keep_one": self.keep_one} + return {"data": data, "meta": meta} + + def restore(self, model): + selection_obj_list = [] + selection_obj_model_list = json.loads(model["selection_obj_list"]) + for i, selection_model in enumerate(selection_obj_model_list): + if selection_model["method"] in ["manual"]: + selection_obj = ManualSelection(method=self._inner_method[i]) + else: + selection_obj = StandardSelection(method=self._inner_method[i]) + selection_obj.restore(selection_model) + selection_obj_list.append(selection_obj) + self._selection_obj = selection_obj_list + + @classmethod + def from_model(cls, model) -> "HeteroSelectionModuleHost": + selection_obj = HeteroSelectionModuleHost(model["meta"]["method"], model["meta"]["select_col"]) + selection_obj._inner_method = model["data"]["inner_method"] + selection_obj.restore(model["data"]) + return selection_obj + + +class ManualSelection(Module): + def __init__(self, method, param=None, header=None, model=None, keep_one=True): + assert method == "manual", f"Manual Selection only accepts 'manual' as `method`, received {method} instead." 
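+ # Editor's note: `param` is expected to carry optional "filter_out_col" and "keep_col"
+ # lists; they are read in fit() below and applied on top of the mask inherited from any
+ # previous filter.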
+ self.method = method + self.param = param + self.model = model + self.keep_one = keep_one + self._header = header + self._prev_selected_mask = None + self._selected_mask = None + if header is None: + self._prev_selected_mask = None + else: + self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) + + @property + def selected_mask(self): + return self._selected_mask + + def set_selected_mask(self, mask): + self._selected_mask = mask + + def set_prev_selected_mask(self, mask): + self._prev_selected_mask = mask + + def fit(self, ctx: Context, train_data, validate_data=None): + header = train_data.schema.columns.to_list() + if self._header is None: + self._header = header + self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) + + filter_out_col = self.param.get("filter_out_col", None) + keep_col = self.param.get("keep_col", None) + if filter_out_col is None: + filter_out_col = [] + if keep_col is None: + keep_col = [] + if len(filter_out_col) >= len(header): + raise ValueError("`filter_out_col` should not be all columns") + filter_out_col = set(filter_out_col) + # keep_col = set(keep_col) + missing_col = (filter_out_col.union(set(keep_col))).difference(set(self._prev_selected_mask.index)) + if missing_col: + raise ValueError( + f"columns {missing_col} given in `filter_out_col` & `keep_col` " f"not found in `select_col` or header" + ) + filter_out_mask = pd.Series( + [False if col in filter_out_col else True for col in self._header], index=self._header + ) + # keep_mask = [True if col in keep_col else False for col in self._header] + selected_mask = self._prev_selected_mask & filter_out_mask + selected_mask.loc[keep_col] = True + self._selected_mask = selected_mask + if self.keep_one: + StandardSelection._keep_one(self._selected_mask, self._header) + + def transform(self, ctx: Context, transform_data): + logger.debug(f"Start transform") + drop_cols = set(self._selected_mask[~self._selected_mask].index) + select_cols = [col for col in transform_data.schema.columns.to_list() if col not in drop_cols] + return transform_data[select_cols] + + def to_model(self): + return dict(method=self.method, keep_one=self.keep_one, selected_mask=self._selected_mask.to_dict()) + + def restore(self, model): + self.method = model["method"] + self.keep_one = model["keep_one"] + self._selected_mask = pd.Series(model["selected_mask"], dtype=bool) + + +class StandardSelection(Module): + def __init__(self, method, header=None, param=None, model=None, keep_one=True): + self.method = method + self.param = param + self.filter_conf = {} + + if param is not None: + for metric_name, filter_type, threshold, take_high in zip( + self.param.get("metrics", DEFAULT_METRIC.get(method)), + self.param.get("filter_type", ["threshold"]), + self.param.get("threshold", [1.0]), + self.param.get("take_high", [True]), + ): + metric_conf = self.filter_conf.get(metric_name, {}) + metric_conf["filter_type"] = metric_conf.get("filter_type", []) + [filter_type] + metric_conf["threshold"] = metric_conf.get("threshold", []) + [threshold] + metric_conf["take_high"] = metric_conf.get("take_high", []) + [take_high] + self.filter_conf[metric_name] = metric_conf + + self.model = self.convert_model(model) + self.keep_one = keep_one + self._header = header + self._selected_mask = None + self._all_selected_mask = None + if header is None: + self._prev_selected_mask = None + else: + self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) + self._host_selected_mask = 
{} + self._all_host_selected_mask = {} + self._host_prev_selected_mask = {} + self._all_metrics = None + self._all_host_metrics = {} + + @staticmethod + def convert_model(input_model): + return input_model + + @property + def selected_mask(self): + return self._selected_mask + + def set_selected_mask(self, mask): + self._selected_mask = mask + + def set_host_prev_selected_mask(self, mask): + self._host_prev_selected_mask = mask + + def set_prev_selected_mask(self, mask): + self._prev_selected_mask = mask + + def fit(self, ctx: Context, train_data, validate_data=None): + if self._header is None: + header = train_data.schema.columns.to_list() + self._header = header + self._prev_selected_mask = pd.Series(np.ones(len(header)), dtype=bool, index=header) + + metric_names = self.param.get("metrics", []) + # local only + if self.method in ["statistics"]: + for metric_name in metric_names: + if metric_name not in self.model.get("meta", {}).get("metrics", {}): + raise ValueError( + f"metric {metric_name} not found in given statistic model with metrics: " + f"{self.model.get('metrics', {})}, please check" + ) + model_data = self.model.get("data", {}) + metrics_all = pd.DataFrame(model_data.get("metrics_summary", {})).loc[metric_names] + self._all_metrics = metrics_all + missing_col = set(self._prev_selected_mask[self._prev_selected_mask].index).difference( + set(metrics_all.columns) + ) + if missing_col: + raise ValueError( + f"metrics for columns {missing_col} from `select_col` or header not found in given model." + ) + + mask_all = self.apply_filter(metrics_all, self.filter_conf) + self._all_selected_mask = mask_all + cur_selected_mask = mask_all.all(axis=0) + cur_selected_mask = [cur_selected_mask.get(col, False) for col in self._header] + self._selected_mask = self._prev_selected_mask & cur_selected_mask + if self.keep_one: + self._keep_one(self._selected_mask, self._prev_selected_mask, self._header) + # federated selection possible + elif self.method == "iv": + # host does not perform local iv selection + if ctx.is_on_host: + return + model_data = self.model.get("data", {}) + iv_metrics = pd.Series(model_data["metrics_summary"]["iv"]) + # metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0) + metrics_all = StandardSelection.convert_series_metric_to_dataframe(iv_metrics, "iv") + self._all_metrics = metrics_all + mask_all = self.apply_filter(metrics_all, self.filter_conf) + self._all_selected_mask = mask_all + cur_selected_mask = mask_all.all(axis=0) + cur_selected_mask = [cur_selected_mask.get(col, False) for col in self._header] + self._selected_mask = self._prev_selected_mask & cur_selected_mask + if self.keep_one: + self._keep_one(self._selected_mask, self._prev_selected_mask, self._header) + if self.param.get("select_federated", True): + host_metrics_summary = self.model.get("data", {}).get("host_metrics_summary") + # logger.info(f"host metrics summary: {host_metrics_summary}") + if host_metrics_summary is None: + raise ValueError(f"Host metrics not found in provided model, please check.") + for host, host_metrics in host_metrics_summary.items(): + iv_metrics = pd.Series(host_metrics["iv"]) + # metrics_all = pd.DataFrame(iv_metrics).T.rename({0: "iv"}, axis=0) + metrics_all = StandardSelection.convert_series_metric_to_dataframe(iv_metrics, "iv") + self._all_host_metrics[host] = metrics_all + host_mask_all = self.apply_filter(metrics_all, self.filter_conf) + self._all_host_selected_mask[host] = host_mask_all + host_selected_mask = host_mask_all.all(axis=0) + if self.keep_one: + 
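+ # Editor's note: guarantee each host keeps at least one column so that no party ends up
+ # with an empty feature set after federated selection.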
self._keep_one(host_selected_mask) + self._host_selected_mask[host] = host_selected_mask + + @staticmethod + def _keep_one(cur_mask, prev_mask=None, select_col=None): + if sum(cur_mask) > 0: + return cur_mask + else: + if prev_mask is not None: + idx = random.choice(prev_mask[prev_mask].index) + elif select_col is not None: + idx = random.choice(select_col) + else: + idx = random.choice(cur_mask.index) + cur_mask[idx] = True + + @staticmethod + def convert_series_metric_to_dataframe(metrics, metric_name): + # logger.info(f"metrics: {metrics}") + return pd.DataFrame(metrics).T.rename({0: metric_name}, axis=0) + + @staticmethod + def apply_filter(metrics_all, filter_conf): + return metrics_all.apply(lambda r: StandardSelection.filter_multiple_metrics(r, filter_conf[r.name]), axis=1) + + @staticmethod + def filter_multiple_metrics(metrics, metric_conf): + filter_type_list = metric_conf["filter_type"] + threshold_list = metric_conf["threshold"] + take_high_list = metric_conf["take_high"] + result = pd.Series(np.ones(len(metrics.index)), index=metrics.index, dtype=bool) + for idx in range(len(filter_type_list)): + result &= StandardSelection.filter_metrics( + metrics, filter_type_list[idx], threshold_list[idx], take_high_list[idx] + ) + return result + + @staticmethod + def filter_metrics(metrics, filter_type, threshold, take_high=True): + if filter_type == "top_k": + return StandardSelection.filter_by_top_k(metrics, threshold, take_high) + elif filter_type == "threshold": + return StandardSelection.filter_by_threshold(metrics, threshold, take_high) + elif filter_type == "top_percentile": + return StandardSelection.filter_by_percentile(metrics, threshold, take_high) + else: + raise ValueError(f"filter_type {filter_type} not supported, please check") + + @staticmethod + def filter_by_top_k(metrics, k, take_high=True): + # strict top k + if k == 0: + return pd.Series(np.ones(len(metrics)), dtype=bool) + # stable sort + ordered_metrics = metrics.sort_values(ascending=not take_high, kind="mergesort") + select_k = ordered_metrics.index[:k] + return metrics.index.isin(select_k) + + @staticmethod + def filter_by_threshold(metrics, threshold, take_high=True): + if take_high: + return metrics >= threshold + else: + return metrics <= threshold + + @staticmethod + def filter_by_percentile(metrics, percentile, take_high=True): + top_k = math.ceil(len(metrics) * percentile) + return StandardSelection.filter_by_top_k(metrics, top_k, take_high) + + def transform(self, ctx: Context, transform_data): + logger.debug(f"Start transform") + drop_cols = set(self._selected_mask[~self._selected_mask].index) + cols = transform_data.schema.columns.to_list() + select_cols = [col for col in cols if col not in drop_cols] + return transform_data[select_cols] + + def to_model(self): + return dict( + method=self.method, + keep_one=self.keep_one, + all_selected_mask=self._all_selected_mask.to_dict() if self._all_selected_mask is not None else None, + all_metrics=self._all_metrics.to_dict() if self._all_metrics is not None else None, + all_host_metrics={k: v.to_dict() for k, v in self._all_host_metrics.items()} + if self._all_host_metrics is not None + else None, + selected_mask=self._selected_mask.to_dict(), + host_selected_mask={k: v.to_dict() for k, v in self._host_selected_mask.items()} + if self._host_selected_mask is not None + else None, + all_host_selected_mask={k: v.to_dict() for k, v in self._all_host_selected_mask.items()} + if self._all_host_selected_mask is not None + else None, + ) + + def restore(self, model): + 
self.method = model["method"] + self.keep_one = model["keep_one"] + self._selected_mask = pd.Series(model["selected_mask"], dtype=bool) + self._all_selected_mask = pd.DataFrame(model["all_selected_mask"], dtype=bool) + self._all_metrics = pd.DataFrame(model["all_metrics"]) + self._host_selected_mask = {k: pd.Series(v, dtype=bool) for k, v in model["host_selected_mask"].items()} + self._all_host_selected_mask = { + k: pd.DataFrame(v, dtype=bool) for k, v in model["all_host_selected_mask"].items() + } + self._all_host_metrics = {k: pd.DataFrame(v) for k, v in model["all_host_metrics"].items()} diff --git a/python/fate/ml/glm/__init__.py b/python/fate/ml/glm/__init__.py new file mode 100644 index 0000000000..4495033129 --- /dev/null +++ b/python/fate/ml/glm/__init__.py @@ -0,0 +1,4 @@ +from .hetero.coordinated_linr import CoordinatedLinRModuleHost, CoordinatedLinRModuleGuest, CoordinatedLinRModuleArbiter +from .hetero.coordinated_lr import CoordinatedLRModuleHost, CoordinatedLRModuleGuest, CoordinatedLRModuleArbiter +from .homo.lr.client import HomoLRClient +from .homo.lr.server import HomoLRServer diff --git a/python/fate/ml/glm/hetero/__init__.py b/python/fate/ml/glm/hetero/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/interface/_tensor.py b/python/fate/ml/glm/hetero/coordinated_linr/__init__.py similarity index 81% rename from python/fate/interface/_tensor.py rename to python/fate/ml/glm/hetero/coordinated_linr/__init__.py index 42a655ad78..53d20fc9af 100644 --- a/python/fate/interface/_tensor.py +++ b/python/fate/ml/glm/hetero/coordinated_linr/__init__.py @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -class Tensor: - shape: List[int] - T: "Tensor" +from .arbiter import CoordinatedLinRModuleArbiter +from .guest import CoordinatedLinRModuleGuest +from .host import CoordinatedLinRModuleHost diff --git a/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py new file mode 100644 index 0000000000..2ffaf0d8b0 --- /dev/null +++ b/python/fate/ml/glm/hetero/coordinated_linr/arbiter.py @@ -0,0 +1,218 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
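The feature-selection logic above boils down to applying each configured filter to a row of metric values and AND-ing the resulting masks per column. A minimal, self-contained sketch of that composition, using plain pandas and made-up IV values (the column names, values, and helper names here are illustrative only, not the FATE API):

```python
import pandas as pd

# Hypothetical IV values for four feature columns.
iv = pd.Series({"x0": 0.02, "x1": 0.35, "x2": 0.50, "x3": 0.01})

def top_k_mask(metrics: pd.Series, k: int, take_high: bool = True) -> pd.Series:
    # k == 0 is treated as "keep everything", mirroring filter_by_top_k above
    if k == 0:
        return pd.Series(True, index=metrics.index)
    ordered = metrics.sort_values(ascending=not take_high, kind="mergesort")  # stable sort
    return pd.Series(metrics.index.isin(ordered.index[:k]), index=metrics.index)

def threshold_mask(metrics: pd.Series, threshold: float, take_high: bool = True) -> pd.Series:
    return metrics >= threshold if take_high else metrics <= threshold

# filter_conf equivalent: {"filter_type": ["top_k", "threshold"], "threshold": [3, 0.1], "take_high": [True, True]}
mask = top_k_mask(iv, 3) & threshold_mask(iv, 0.1)
print(mask.to_dict())  # only x1 and x2 survive both filters
```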
+import logging + +import torch + +from fate.arch import Context +from fate.arch.dataframe import DataLoader +from fate.ml.abc.module import HeteroModule +from fate.ml.utils._convergence import converge_func_factory +from fate.ml.utils._optimizer import LRScheduler, Optimizer, separate + +logger = logging.getLogger(__name__) + + +class CoordinatedLinRModuleArbiter(HeteroModule): + def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param): + self.epochs = epochs + self.batch_size = batch_size + self.early_stop = early_stop + self.tol = tol + self.learning_rate_param = learning_rate_param + self.optimizer_param = optimizer_param + + self.estimator = None + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + + def fit(self, ctx: Context) -> None: + kit = ctx.cipher.phe.setup() + encryptor = kit.get_tensor_encryptor() + decryptor = kit.get_tensor_decryptor() + ctx.hosts("encryptor").put(encryptor) + ctx.guest("encryptor").put(encryptor) + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = HeteroLinREstimatorArbiter( + epochs=self.epochs, + early_stop=self.early_stop, + tol=self.tol, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + ) + self.estimator = single_estimator + self.estimator.fit_model(ctx, decryptor) + + def get_model(self): + return { + "data": {"estimator": self.estimator.get_model()}, + "meta": { + "epochs": self.epochs, + "early_stop": self.early_stop, + "tol": self.tol, + "batch_size": self.batch_size, + "learning_rate_param": self.learning_rate_param, + "optimizer_param": self.optimizer_param, + }, + } + + @classmethod + def from_model(cls, model): + linr = CoordinatedLinRModuleArbiter( + model["meta"]["epochs"], + model["meta"]["early_stop"], + model["meta"]["tol"], + model["meta"]["batch_size"], + model["meta"]["optimizer_param"], + model["meta"]["learning_rate_param"], + ) + estimator = HeteroLinREstimatorArbiter() + estimator.restore(model["data"]["estimator"]) + linr.estimator = estimator + return linr + + +class HeteroLinREstimatorArbiter(HeteroModule): + def __init__( + self, epochs=None, early_stop=None, tol=None, batch_size=None, optimizer=None, learning_rate_scheduler=None + ): + self.epochs = epochs + self.batch_size = batch_size + self.early_stop = early_stop + self.tol = tol + self.optimizer = optimizer + self.lr_scheduler = learning_rate_scheduler + + if early_stop is not None: + self.converge_func = converge_func_factory(early_stop, tol) + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + + def fit_model(self, ctx, decryptor): + batch_loader = DataLoader( + dataset=None, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="arbiter", sync_arbiter=True + ) + logger.info(f"batch_num={batch_loader.batch_num}") + if self.optimizer.optimizer is None: + optimizer_ready = False + else: + optimizer_ready = True + # self.start_epoch = self.end_epoch + 1 + + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): + iter_loss = None + iter_g = None + self.optimizer.set_iters(i) + logger.info(f"self.optimizer set epoch {i}") + 
            for batch_ctx, _ in iter_ctx.on_batches.ctxs_zip(batch_loader):
+                g_guest_enc = batch_ctx.guest.get("g_enc")
+                g_guest = decryptor.decrypt_tensor(g_guest_enc)
+                size_list = [g_guest.size()[0]]
+                g_total = g_guest.squeeze()
+                # get torch tensor
+
+                host_g = batch_ctx.hosts.get("g_enc")
+                for host_idx, g_host_enc in enumerate(host_g):
+                    g = decryptor.decrypt_tensor(g_host_enc)
+                    size_list.append(g.size()[0])
+                    g_total = torch.hstack((g_total, g.squeeze()))
+                if not optimizer_ready:
+                    self.optimizer.init_optimizer(model_parameter_length=sum(size_list), dtype=g_total.dtype)
+                    self.lr_scheduler.init_scheduler(self.optimizer.optimizer)
+                    optimizer_ready = True
+                self.optimizer.step(g_total.unsqueeze(1))
+                delta_g = self.optimizer.get_delta_gradients()
+                delta_g_list = separate(delta_g, size_list)
+
+                delta_g_list_squeezed = []
+                batch_ctx.guest.put("g", delta_g_list[0])
+                delta_g_list_squeezed.append(delta_g_list[0].squeeze())
+                for host_idx, g_host in enumerate(delta_g_list[1:]):
+                    batch_ctx.hosts[host_idx].put("g", g_host)
+                    delta_g_list_squeezed.append(g_host.squeeze())
+                if iter_g is None:
+                    iter_g = torch.hstack(delta_g_list_squeezed)
+                else:
+                    iter_g += torch.hstack(delta_g_list_squeezed)
+
+                if len(host_g) == 1:
+                    loss = decryptor.decrypt_tensor(batch_ctx.guest.get("loss"))
+                    iter_loss = 0 if iter_loss is None else iter_loss
+                    iter_loss += loss
+                else:
+                    logger.info("Multiple hosts exist, do not compute loss.")
+
+            if iter_loss is not None:
+                iter_ctx.metrics.log_loss("linr_loss", iter_loss.tolist()[0])
+            if self.early_stop == "weight_diff":
+                self.is_converged = self.converge_func.is_converge(iter_g)
+            else:
+                if iter_loss is None:
+                    raise ValueError(
+                        "Multiple host situation, loss early stop function is not available."
+                        "You should use 'weight_diff' instead"
+                    )
+                self.is_converged = self.converge_func.is_converge(iter_loss)
+
+            iter_ctx.hosts.put("converge_flag", self.is_converged)
+            iter_ctx.guest.put("converge_flag", self.is_converged)
+
+            if self.is_converged:
+                self.end_epoch = i
+                break
+            if i < self.epochs - 1:
+                self.lr_scheduler.step()
+        if not self.is_converged:
+            self.end_epoch = self.epochs
+        logger.debug(f"Finish training at {self.end_epoch}th epoch.")
+
+    def get_model(self):
+        return {
+            "optimizer": self.optimizer.state_dict(),
+            "lr_scheduler": self.lr_scheduler.state_dict(),
+            "end_epoch": self.end_epoch,
+            "is_converged": self.is_converged,
+            "tol": self.tol,
+            "early_stop": self.early_stop,
+        }
+
+    def restore(self, model):
+        self.optimizer = Optimizer()
+        self.lr_scheduler = LRScheduler()
+        self.optimizer.load_state_dict(model["optimizer"])
+        self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer)
+        self.end_epoch = model["end_epoch"]
+        self.is_converged = model["is_converged"]
+        self.tol = model["tol"]
+        self.early_stop = model["early_stop"]
+        self.converge_func = converge_func_factory(self.early_stop, self.tol)
+        # self.start_epoch = model["end_epoch"] + 1
diff --git a/python/fate/ml/glm/hetero/coordinated_linr/guest.py b/python/fate/ml/glm/hetero/coordinated_linr/guest.py
new file mode 100644
index 0000000000..abf229a930
--- /dev/null
+++ b/python/fate/ml/glm/hetero/coordinated_linr/guest.py
@@ -0,0 +1,274 @@
+#
+# Copyright 2019 The FATE Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch + +from fate.arch import Context, dataframe +from fate.ml.abc.module import HeteroModule +from fate.ml.utils import predict_tools +from fate.ml.utils._model_param import ( + deserialize_param, + initialize_param, + serialize_param, +) +from fate.ml.utils._optimizer import LRScheduler, Optimizer + +logger = logging.getLogger(__name__) + + +class CoordinatedLinRModuleGuest(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer_param=None, learning_rate_param=None, init_param=None, + floating_point_precision=23): + self.epochs = epochs + self.batch_size = batch_size + self.optimizer_param = optimizer_param + self.learning_rate_param = learning_rate_param + self.init_param = init_param + self.floating_point_precision = floating_point_precision + + self.estimator = None + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + estimator = CoordinatedLinREstimatorGuest( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + floating_point_precision=self.floating_point_precision + ) + self.estimator = estimator + encryptor = ctx.arbiter("encryptor").get() + self.estimator.fit_model(ctx, encryptor, train_data, validate_data) + + def predict(self, ctx, test_data): + prob = self.estimator.predict(ctx, test_data) + return prob + + def get_model(self): + return { + "data": {"estimator": self.estimator.get_model()}, + "meta": { + "epochs": self.epochs, + "batch_size": self.batch_size, + "learning_rate_param": self.learning_rate_param, + "init_param": self.init_param, + "optimizer_param": self.optimizer_param, + "floating_point_precision": self.floating_point_precision, + }, + } + + @classmethod + def from_model(cls, model) -> "CoordinatedLinRModuleGuest": + linr = CoordinatedLinRModuleGuest( + optimizer_param=model["meta"]["optimizer_param"], + learning_rate_param=model["meta"]["learning_rate_param"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"] + ) + estimator = CoordinatedLinREstimatorGuest( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"] + ) + estimator.restore(model["data"]["estimator"]) + linr.estimator = estimator + + return linr + + +class CoordinatedLinREstimatorGuest(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer=None, 
learning_rate_scheduler=None, init_param=None, + floating_point_precision=23): + self.epochs = epochs + self.batch_size = batch_size + self.optimizer = optimizer + self.lr_scheduler = learning_rate_scheduler + self.init_param = init_param + self.floating_point_precision = floating_point_precision + self._fixpoint_precision = 2 ** floating_point_precision + + self.w = None + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + + def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X, Y, weight): + h = X.shape[0] + Xw = torch.matmul(X, w.detach()) + half_d = Xw - Y + if weight: + half_d = half_d * weight + batch_ctx.hosts.put("half_d", encryptor.encrypt_tensor(half_d, obfuscate=True)) + half_g = torch.matmul(X.T, half_d) + + Xw_h = batch_ctx.hosts.get("Xw_h")[0] + if weight: + Xw_h = Xw_h * weight + if self.floating_point_precision: + host_half_g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), Xw_h) + host_half_g = 1 / self._fixpoint_precision * host_half_g + else: + host_half_g = torch.matmul(X.T, Xw_h) + + loss = 0.5 / h * torch.matmul(half_d.T, half_d) + if self.optimizer.l1_penalty or self.optimizer.l2_penalty: + loss_norm = self.optimizer.loss_norm(w) + loss += loss_norm + + for Xw2_h in batch_ctx.hosts.get("Xw2_h"): + loss += 0.5 / h * Xw2_h + h_loss_list = batch_ctx.hosts.get("h_loss") + for h_loss in h_loss_list: + if h_loss is not None: + loss += h_loss + + batch_ctx.arbiter.put(loss=loss) + + # gradient + g = 1 / h * (half_g + host_half_g) + return g + + def centralized_compute_gradient(self, batch_ctx, w, X, Y, weight): + h = X.shape[0] + Xw = torch.matmul(X, w.detach()) + d = Xw - Y + + Xw_h_all = batch_ctx.hosts.get("Xw_h") + for Xw_h in Xw_h_all: + d += Xw_h + + if weight: + d = d * weight + batch_ctx.hosts.put(d=d) + + # gradient + if self.floating_point_precision: + g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), d) + g = 1 / (self._fixpoint_precision * h) * g + else: + g = 1 / h * torch.matmul(X.T, d) + return g + + def fit_model(self, ctx, encryptor, train_data, validate_data=None): + coef_count = train_data.shape[1] + logger.debug(f"init param: {self.init_param}") + if self.init_param.get("fit_intercept"): + logger.debug(f"add intercept to train data") + train_data["intercept"] = 1.0 + w = self.w + if self.w is None: + w = initialize_param(coef_count, **self.init_param) + self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) + self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True + ) + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 + is_centralized = len(ctx.hosts) > 1 + + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): + self.optimizer.set_iters(i) + logger.info(f"self.optimizer set epoch {i}") + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + Y = batch_data.label + weight = batch_data.weight + if is_centralized: + g = self.centralized_compute_gradient(batch_ctx, w, X, Y, weight) + else: + g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X, Y, weight) + g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept")) + batch_ctx.arbiter.put("g_enc", g) + g = batch_ctx.arbiter.get("g") + + w = self.optimizer.update_weights(w, g, self.init_param.get("fit_intercept"), self.lr_scheduler.lr) + # logger.info(f"w={w}") + 
self.is_converged = iter_ctx.arbiter("converge_flag").get() + if self.is_converged: + self.end_epoch = i + break + if i < self.epochs - 1: + self.lr_scheduler.step() + if not self.is_converged: + self.end_epoch = self.epochs + self.w = w + logger.debug(f"Finish training at {self.end_epoch}th epoch.") + + def predict(self, ctx, test_data): + pred_df = test_data.create_frame(with_label=True, with_weight=False) + if self.init_param.get("fit_intercept"): + test_data["intercept"] = 1.0 + X = test_data.values.as_tensor() + pred = torch.matmul(X, self.w) + for h_pred in ctx.hosts.get("h_pred"): + pred += h_pred + pred_df[predict_tools.PREDICT_SCORE] = pred + predict_result = predict_tools.compute_predict_details(pred_df, task_type=predict_tools.REGRESSION) + return predict_result + + def get_model(self): + """w = self.w.tolist() + intercept = None + if self.init_param.get("fit_intercept"): + w = w[:-1] + intercept = w[-1]""" + param = serialize_param(self.w, self.init_param.get("fit_intercept")) + return { + "param": param, + # "intercept": intercept, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + "fit_intercept": self.init_param.get("fit_intercept"), + } + + def restore(self, model): + """w = model["w"] + if model["fit_intercept"]: + w.append(model["intercept"]) + self.w = torch.tensor(w) + """ + self.w = deserialize_param(model["param"], model["fit_intercept"]) + self.optimizer = Optimizer() + self.lr_scheduler = LRScheduler() + self.optimizer.load_state_dict(model["optimizer"]) + self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] diff --git a/python/fate/ml/glm/hetero/coordinated_linr/host.py b/python/fate/ml/glm/hetero/coordinated_linr/host.py new file mode 100644 index 0000000000..98793f4816 --- /dev/null +++ b/python/fate/ml/glm/hetero/coordinated_linr/host.py @@ -0,0 +1,234 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
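The guest estimator above (and the host that follows) implement vertically partitioned linear regression: the residual d = X_g w_g + X_h w_h - y is assembled from encrypted shares, and each party only ever receives the gradient block for its own features. A toy single-process sketch of the underlying arithmetic, with made-up data and no encryption, assuming nothing about the FATE API:

```python
import torch

torch.manual_seed(0)
n = 8
X_guest, X_host = torch.randn(n, 3), torch.randn(n, 2)   # features held by different parties
y = torch.randn(n, 1)                                     # label, guest side only
w_guest, w_host = torch.zeros(3, 1), torch.zeros(2, 1)

# Full residual; in the protocol each party only sees its own share
# (half_d / Xw_h), exchanged in encrypted form.
d = X_guest @ w_guest + X_host @ w_host - y
g_guest = X_guest.T @ d / n          # gradient block returned to the guest
g_host = X_host.T @ d / n            # gradient block returned to the host
loss = 0.5 / n * (d.T @ d)           # the quantity the arbiter logs each epoch
```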
+import logging
+
+import torch
+
+from fate.arch import Context
+from fate.arch.dataframe import DataLoader
+from fate.ml.abc.module import HeteroModule
+from fate.ml.utils._model_param import (
+    deserialize_param,
+    initialize_param,
+    serialize_param,
+)
+from fate.ml.utils._optimizer import LRScheduler, Optimizer
+
+logger = logging.getLogger(__name__)
+
+
+class CoordinatedLinRModuleHost(HeteroModule):
+    def __init__(self, epochs=None, batch_size=None, optimizer_param=None, learning_rate_param=None, init_param=None,
+                 floating_point_precision=23):
+        self.epochs = epochs
+        self.optimizer_param = optimizer_param
+        self.learning_rate_param = learning_rate_param
+        self.batch_size = batch_size
+        self.init_param = init_param or {}
+        self.init_param["fit_intercept"] = False
+        self.floating_point_precision = floating_point_precision
+
+        self.estimator = None
+
+    def set_batch_size(self, batch_size):
+        self.batch_size = batch_size
+        self.estimator.batch_size = batch_size
+
+    def set_epochs(self, epochs):
+        self.epochs = epochs
+        self.estimator.epochs = epochs
+
+    def fit(self, ctx: Context, train_data, validate_data=None) -> None:
+        encryptor = ctx.arbiter("encryptor").get()
+        if self.estimator is None:
+            optimizer = Optimizer(
+                self.optimizer_param["method"],
+                self.optimizer_param["penalty"],
+                self.optimizer_param["alpha"],
+                self.optimizer_param["optimizer_params"],
+            )
+            lr_scheduler = LRScheduler(
+                self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"]
+            )
+            estimator = CoordinatedLinREstimatorHost(
+                epochs=self.epochs,
+                batch_size=self.batch_size,
+                optimizer=optimizer,
+                learning_rate_scheduler=lr_scheduler,
+                init_param=self.init_param,
+                floating_point_precision=self.floating_point_precision
+            )
+            self.estimator = estimator
+
+        self.estimator.fit_model(ctx, encryptor, train_data, validate_data)
+
+    def predict(self, ctx, test_data):
+        self.estimator.predict(ctx, test_data)
+
+    def get_model(self):
+        return {
+            "data": {"estimator": self.estimator.get_model()},
+            "meta": {
+                "epochs": self.epochs,
+                "batch_size": self.batch_size,
+                "learning_rate_param": self.learning_rate_param,
+                "init_param": self.init_param,
+                "optimizer_param": self.optimizer_param,
+                "floating_point_precision": self.floating_point_precision,
+            },
+        }
+
+    @classmethod
+    def from_model(cls, model) -> "CoordinatedLinRModuleHost":
+        linr = CoordinatedLinRModuleHost(
+            optimizer_param=model["meta"]["optimizer_param"],
+            learning_rate_param=model["meta"]["learning_rate_param"],
+            epochs=model["meta"]["epochs"],
+            batch_size=model["meta"]["batch_size"],
+            init_param=model["meta"]["init_param"],
+            floating_point_precision=model["meta"]["floating_point_precision"]
+        )
+        estimator = CoordinatedLinREstimatorHost(
+            epochs=model["meta"]["epochs"],
+            batch_size=model["meta"]["batch_size"],
+            init_param=model["meta"]["init_param"],
+            floating_point_precision=model["meta"]["floating_point_precision"]
+        )
+        estimator.restore(model["data"]["estimator"])
+        linr.estimator = estimator
+
+        return linr
+
+
+class CoordinatedLinREstimatorHost(HeteroModule):
+    def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate_scheduler=None, init_param=None,
+                 floating_point_precision=23):
+        self.epochs = epochs
+        self.optimizer = optimizer
+        self.lr_scheduler = learning_rate_scheduler
+        self.batch_size = batch_size
+        self.init_param = init_param
+        self.floating_point_precision = floating_point_precision
+        self._fixpoint_precision = 2 ** floating_point_precision
+
+        self.w = None
+        self.start_epoch = 0
+
self.end_epoch = -1 + self.is_converged = False + + def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X): + h = X.shape[0] + Xw_h = torch.matmul(X, w.detach()) + batch_ctx.guest.put("Xw_h", encryptor.encrypt_tensor(Xw_h, obfuscate=True)) + half_g = torch.matmul(X.T, Xw_h) + guest_half_d = batch_ctx.guest.get("half_d") + if self.floating_point_precision: + guest_half_g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), guest_half_d) + guest_half_g = 1 / self._fixpoint_precision * guest_half_g + else: + guest_half_g = torch.matmul(X.T, guest_half_d) + + batch_ctx.guest.put("Xw2_h", encryptor.encrypt_tensor(torch.matmul(Xw_h.T, Xw_h))) + loss_norm = self.optimizer.loss_norm(w) + if loss_norm is not None: + batch_ctx.guest.put("h_loss", encryptor.encrypt_tensor(loss_norm)) + else: + batch_ctx.guest.put(h_loss=loss_norm) + + g = 1 / h * (half_g + guest_half_g) + return g + + def centralized_compute_gradient(self, batch_ctx, encryptor, w, X): + h = X.shape[0] + Xw_h = torch.matmul(X, w.detach()) + batch_ctx.guest.put("Xw_h", encryptor.encrypt_tensor(Xw_h, obfuscate=True)) + + d = batch_ctx.guest.get("d") + if self.floating_point_precision: + g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), d) + g = 1 / (self._fixpoint_precision * h) * g + else: + g = 1 / h * torch.matmul(X.T, d) + return g + + def fit_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> None: + batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") + + coef_count = train_data.shape[1] + w = self.w + if self.w is None: + w = initialize_param(coef_count, **self.init_param) + self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) + self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 + is_centralized = len(ctx.hosts) > 1 + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): + self.optimizer.set_iters(i) + logger.info(f"self.optimizer set epoch {i}") + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + if is_centralized: + g = self.centralized_compute_gradient(batch_ctx, encryptor, w, X) + else: + g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X) + g = self.optimizer.add_regular_to_grad(g, w, False) + batch_ctx.arbiter.put("g_enc", g) + g = batch_ctx.arbiter.get("g") + + w = self.optimizer.update_weights(w, g, False, self.lr_scheduler.lr) + logger.info(f"w={w}") + self.is_converged = iter_ctx.arbiter("converge_flag").get() + if self.is_converged: + self.end_epoch = i + break + if i < self.epochs - 1: + self.lr_scheduler.step() + if not self.is_converged: + self.end_epoch = self.epochs + self.w = w + logger.debug(f"Finish training at {self.end_epoch}th epoch.") + + def predict(self, ctx, test_data): + X = test_data.values.as_tensor() + output = torch.matmul(X, self.w) + ctx.guest.put("h_pred", output) + + def get_model(self): + """return { + "w": self.w.tolist(), + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged + }""" + param = serialize_param(self.w, False) + return { + "param": param, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + } + + def restore(self, model): + # self.w = torch.tensor(model["w"]) + self.w = 
deserialize_param(model["param"], False) + self.optimizer = Optimizer() + self.lr_scheduler = LRScheduler() + self.optimizer.load_state_dict(model["optimizer"]) + self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] diff --git a/python/fate/components/entrypoint/clean_cli.py b/python/fate/ml/glm/hetero/coordinated_lr/__init__.py similarity index 82% rename from python/fate/components/entrypoint/clean_cli.py rename to python/fate/ml/glm/hetero/coordinated_lr/__init__.py index 981006ad28..bf577d797a 100644 --- a/python/fate/components/entrypoint/clean_cli.py +++ b/python/fate/ml/glm/hetero/coordinated_lr/__init__.py @@ -12,13 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import click - -@click.command() -def clean(): - """ - clean task resources - """ - # TODO: implement - print("cleaned") +from .arbiter import CoordinatedLRModuleArbiter +from .guest import CoordinatedLRModuleGuest +from .host import CoordinatedLRModuleHost diff --git a/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py new file mode 100644 index 0000000000..8ca8c5341e --- /dev/null +++ b/python/fate/ml/glm/hetero/coordinated_lr/arbiter.py @@ -0,0 +1,281 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
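On the arbiter side (the LinR arbiter above and the LR arbiter defined next), the decrypted guest and host gradient blocks are stacked into one vector, the optimizer step is taken on that vector, and the result is split back by party sizes. A rough stand-alone sketch, where `split_by_size` is a hypothetical stand-in for the `separate` helper used in those files:

```python
import torch

def split_by_size(vec: torch.Tensor, size_list):
    # hypothetical stand-in for the `separate` helper imported from _optimizer
    out, start = [], 0
    for size in size_list:
        out.append(vec[start:start + size])
        start += size
    return out

g_guest = torch.tensor([[0.1], [0.2], [0.3]])   # decrypted guest gradient (3 coefficients)
g_host = torch.tensor([[0.4], [0.5]])           # decrypted gradient of a single host (2 coefficients)
size_list = [g_guest.size()[0], g_host.size()[0]]
g_total = torch.hstack((g_guest.squeeze(), g_host.squeeze()))   # shape (5,)

delta_g = 0.1 * g_total.unsqueeze(1)            # stand-in for optimizer.get_delta_gradients()
parts = split_by_size(delta_g, size_list)
# parts[0] would be sent back to the guest, parts[1] to the host
```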
+import logging + +import torch + +from fate.arch import Context +from fate.arch.dataframe import DataLoader +from fate.ml.abc.module import HeteroModule +from fate.ml.utils._convergence import converge_func_factory +from fate.ml.utils._optimizer import LRScheduler, Optimizer, separate + +logger = logging.getLogger(__name__) + + +class CoordinatedLRModuleArbiter(HeteroModule): + def __init__(self, epochs, early_stop, tol, batch_size, optimizer_param, learning_rate_param): + self.epochs = epochs + self.batch_size = batch_size + self.early_stop = early_stop + self.tol = tol + self.learning_rate_param = learning_rate_param + self.optimizer_param = optimizer_param + self.lr_param = learning_rate_param + + self.estimator = None + self.ovr = False + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + + def fit(self, ctx: Context) -> None: + kit = ctx.cipher.phe.setup() + encryptor = kit.get_tensor_encryptor() + decryptor = kit.get_tensor_decryptor() + ctx.hosts("encryptor").put(encryptor) + ctx.guest("encryptor").put(encryptor) + label_count = ctx.guest("label_count").get() + if label_count > 2 or self.ovr: + self.ovr = True + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorArbiter( + epochs=self.epochs, + early_stop=self.early_stop, + tol=self.tol, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + single_estimator.fit_single_model(class_ctx, decryptor) + self.estimator[i] = single_estimator + else: + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorArbiter( + epochs=self.epochs, + early_stop=self.early_stop, + tol=self.tol, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + single_estimator.fit_single_model(ctx, decryptor) + self.estimator = single_estimator + + def get_model(self): + all_estimator = {} + if self.ovr: + for label, estimator in self.estimator.items(): + all_estimator[label] = estimator.get_model() + else: + all_estimator = self.estimator.get_model() + return { + "data": {"estimator": 
all_estimator},
+            "meta": {
+                "epochs": self.epochs,
+                "ovr": self.ovr,
+                "early_stop": self.early_stop,
+                "tol": self.tol,
+                "batch_size": self.batch_size,
+                "learning_rate_param": self.learning_rate_param,
+                "optimizer_param": self.optimizer_param,
+            },
+        }
+
+    @classmethod
+    def from_model(cls, model) -> "CoordinatedLRModuleArbiter":
+        lr = CoordinatedLRModuleArbiter(
+            epochs=model["meta"]["epochs"],
+            early_stop=model["meta"]["early_stop"],
+            tol=model["meta"]["tol"],
+            batch_size=model["meta"]["batch_size"],
+            optimizer_param=model["meta"]["optimizer_param"],
+            learning_rate_param=model["meta"]["learning_rate_param"],
+        )
+        all_estimator = model["data"]["estimator"]
+        lr.estimator = {}
+        if lr.ovr:
+            for label, d in all_estimator.items():
+                estimator = CoordinatedLREstimatorArbiter(
+                    epochs=model["meta"]["epochs"], batch_size=model["meta"]["batch_size"]
+                )
+                estimator.restore(d)
+                lr.estimator[int(label)] = estimator
+        else:
+            estimator = CoordinatedLREstimatorArbiter()
+            estimator.restore(all_estimator)
+            lr.estimator = estimator
+        return lr
+
+
+class CoordinatedLREstimatorArbiter(HeteroModule):
+    def __init__(
+        self, epochs=None, early_stop=None, tol=None, batch_size=None, optimizer=None, learning_rate_scheduler=None
+    ):
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.early_stop = early_stop
+        self.tol = tol
+        self.optimizer = optimizer
+        self.lr_scheduler = learning_rate_scheduler
+
+        if early_stop is not None:
+            self.converge_func = converge_func_factory(early_stop, tol)
+        self.start_epoch = 0
+        self.end_epoch = -1
+        self.is_converged = False
+
+    def fit_single_model(self, ctx: Context, decryptor):
+        batch_loader = DataLoader(
+            dataset=None, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="arbiter", sync_arbiter=True
+        )
+        logger.info(f"batch_num={batch_loader.batch_num}")
+        if self.optimizer.optimizer is None:
+            optimizer_ready = False
+        else:
+            optimizer_ready = True
+        # self.start_epoch = self.end_epoch + 1
+        for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs):
+            iter_loss = None
+            iter_g = None
+            self.optimizer.set_iters(i)
+            logger.info(f"self.optimizer set epoch {i}")
+            for batch_ctx, _ in iter_ctx.on_batches.ctxs_zip(batch_loader):
+                g_guest_enc = batch_ctx.guest.get("g_enc")
+                g_guest = decryptor.decrypt_tensor(g_guest_enc)
+                size_list = [g_guest.size()[0]]
+                g_total = g_guest.squeeze()  # get torch tensor
+
+                host_g = batch_ctx.hosts.get("g_enc")
+                for host_idx, g_host_enc in enumerate(host_g):
+                    g = decryptor.decrypt_tensor(g_host_enc)
+                    size_list.append(g.size()[0])
+                    g_total = torch.hstack((g_total, g.squeeze()))
+                if not optimizer_ready:
+                    self.optimizer.init_optimizer(model_parameter_length=sum(size_list), dtype=g_total.dtype)
+                    self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer)
+                    optimizer_ready = True
+                self.optimizer.step(g_total.unsqueeze(1))
+                delta_g = self.optimizer.get_delta_gradients()
+                delta_g_list = separate(delta_g, size_list)
+
+                delta_g_list_squeezed = []
+                batch_ctx.guest.put("g", delta_g_list[0])
+                delta_g_list_squeezed.append(delta_g_list[0].squeeze())
+                for host_idx, g_host in enumerate(delta_g_list[1:]):
+                    batch_ctx.hosts[host_idx].put("g", g_host)
+                    delta_g_list_squeezed.append(g_host.squeeze())
+                if iter_g is None:
+                    iter_g = torch.hstack(delta_g_list_squeezed)
+                else:
+                    iter_g += torch.hstack(delta_g_list_squeezed)
+
+                if len(host_g) == 1:
+                    loss = decryptor.decrypt_tensor(batch_ctx.guest.get("loss"))
+                    iter_loss = 0 if iter_loss is None else iter_loss
+                    iter_loss = iter_loss + loss
+                else:
+
logger.info("Multiple hosts exist, do not compute loss.") + + if iter_loss is not None: + iter_ctx.metrics.log_loss("lr_loss", iter_loss.tolist()[0]) + if self.early_stop == "weight_diff": + self.is_converged = self.converge_func.is_converge(iter_g) + else: + if iter_loss is None: + raise ValueError( + "Multiple host situation, loss early stop function is not available." + "You should use 'weight_diff' instead" + ) + self.is_converged = self.converge_func.is_converge(iter_loss) + + iter_ctx.hosts.put("converge_flag", self.is_converged) + iter_ctx.guest.put("converge_flag", self.is_converged) + if self.is_converged: + self.end_epoch = i + break + if i < self.epochs - 1: + self.lr_scheduler.step() + if not self.is_converged: + self.end_epoch = self.epochs + logger.debug(f"Finish training at {self.end_epoch}th epoch.") + + def get_model(self): + return { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + "tol": self.tol, + "early_stop": self.early_stop, + } + + def restore(self, model): + self.optimizer = Optimizer() + self.lr_scheduler = LRScheduler() + self.optimizer.load_state_dict(model["optimizer"]), + self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] + self.tol = model["tol"] + self.early_stop = model["early_stop"] + self.converge_func = converge_func_factory(self.early_stop, self.tol) + # self.start_epoch = model["end_epoch"] + 1 diff --git a/python/fate/ml/glm/hetero/coordinated_lr/guest.py b/python/fate/ml/glm/hetero/coordinated_lr/guest.py new file mode 100644 index 0000000000..325b590b03 --- /dev/null +++ b/python/fate/ml/glm/hetero/coordinated_lr/guest.py @@ -0,0 +1,413 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
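The coordinated LR modules handle multiclass data one-vs-rest: the label column is binarized and one binary estimator is trained per class, as in the arbiter above and the guest defined next. A small illustration of that bookkeeping, using plain pandas `get_dummies` purely as a stand-in for the FATE DataFrame API (column names and data are made up):

```python
import pandas as pd

labels = pd.Series([0, 2, 1, 2, 0], name="y")               # toy labels with three classes
binarized = pd.get_dummies(labels, prefix="y")              # columns y_0, y_1, y_2
class_names = sorted(int(col.split("_")[1]) for col in binarized.columns)
print(class_names)                                          # [0, 1, 2]

for i, col in enumerate(binarized.columns):
    y_i = binarized[col].astype(float)                      # 1.0 for class i, 0.0 otherwise
    # one binary estimator would be fitted on (X, y_i) for this class
```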
+ +import logging + +import numpy as np +import torch + +from fate.arch import Context, dataframe +from fate.ml.abc.module import HeteroModule +from fate.ml.utils import predict_tools +from fate.ml.utils._model_param import ( + check_overflow, + deserialize_param, + initialize_param, + serialize_param, +) +from fate.ml.utils._optimizer import LRScheduler, Optimizer + +logger = logging.getLogger(__name__) + + +class CoordinatedLRModuleGuest(HeteroModule): + def __init__( + self, + epochs=None, + batch_size=None, + optimizer_param=None, + learning_rate_param=None, + init_param=None, + threshold=0.5, + floating_point_precision=23 + ): + self.epochs = epochs + self.batch_size = batch_size + self.learning_rate_param = learning_rate_param + self.optimizer_param = optimizer_param + self.init_param = init_param + self.threshold = threshold + self.floating_point_precision = floating_point_precision + + self.estimator = None + self.ovr = False + self.labels = None + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + # original_label = train_data.label + train_data_binarized_label = train_data.label.get_dummies() + label_count = train_data_binarized_label.shape[1] + ctx.arbiter.put("label_count", label_count) + ctx.hosts.put("label_count", label_count) + encryptor = ctx.arbiter("encryptor").get() + labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns] + if self.labels is None: + self.labels = sorted(labels) + if label_count > 2 or self.ovr: + logger.info(f"OVR data provided, will train OVR models.") + self.ovr = True + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): + logger.info(f"start train for {i}th class") + # optimizer = copy.deepcopy(self.optimizer) + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorGuest( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + floating_point_precision=self.floating_point_precision + ) + else: + # warm start + logger.info("estimator is not none, will train with warm start") + # single_estimator = self.estimator[self.labels.index(labels[i])] + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + class_train_data = train_data.copy() + class_validate_data = validate_data + if validate_data: + class_validate_data = validate_data.copy() + class_train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] + single_estimator.fit_single_model(class_ctx, encryptor, class_train_data, class_validate_data) + self.estimator[i] = single_estimator + + else: + if self.estimator is None: + optimizer = Optimizer( + 
self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorGuest( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + floating_point_precision=self.floating_point_precision + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + train_data_fit = train_data.copy() + validate_data_fit = validate_data + if validate_data: + validate_data_fit = validate_data.copy() + single_estimator.fit_single_model(ctx, encryptor, train_data_fit, validate_data_fit) + self.estimator = single_estimator + + def predict(self, ctx, test_data): + pred_df = test_data.create_frame(with_label=True, with_weight=False) + if self.ovr: + pred_score = test_data.create_frame(with_label=False, with_weight=False) + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): + estimator = self.estimator[i] + pred = estimator.predict(class_ctx, test_data) + pred_score[str(self.labels[i])] = pred + pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)]) + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.MULTI, classes=self.labels + ) + else: + predict_score = self.estimator.predict(ctx, test_data) + pred_df[predict_tools.PREDICT_SCORE] = predict_score + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.BINARY, classes=self.labels, threshold=self.threshold + ) + + return predict_result + + def get_model(self): + all_estimator = {} + if self.ovr: + for label, estimator in self.estimator.items(): + all_estimator[label] = estimator.get_model() + else: + all_estimator = self.estimator.get_model() + return { + "data": {"estimator": all_estimator}, + "meta": { + "epochs": self.epochs, + "batch_size": self.batch_size, + "learning_rate_param": self.learning_rate_param, + "init_param": self.init_param, + "optimizer_param": self.optimizer_param, + "labels": self.labels, + "ovr": self.ovr, + "threshold": self.threshold, + "floating_point_precision": self.floating_point_precision, + }, + } + + @classmethod + def from_model(cls, model) -> "CoordinatedLRModuleGuest": + lr = CoordinatedLRModuleGuest( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + learning_rate_param=model["meta"]["learning_rate_param"], + optimizer_param=model["meta"]["optimizer_param"], + threshold=model["meta"]["threshold"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + lr.ovr = model["meta"]["ovr"] + lr.labels = model["meta"]["labels"] + + all_estimator = model["data"]["estimator"] + lr.estimator = {} + if lr.ovr: + for label, d in all_estimator.items(): + estimator = CoordinatedLREstimatorGuest( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + estimator.restore(d) + lr.estimator[int(label)] = estimator + else: + estimator = CoordinatedLREstimatorGuest( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + 
init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + estimator.restore(all_estimator) + lr.estimator = estimator + + return lr + + +class CoordinatedLREstimatorGuest(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate_scheduler=None, init_param=None, + floating_point_precision=23): + self.epochs = epochs + self.batch_size = batch_size + self.optimizer = optimizer + self.lr_scheduler = learning_rate_scheduler + self.init_param = init_param + self.floating_point_precision = floating_point_precision + self._fixpoint_precision = 2 ** floating_point_precision + + self.w = None + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + + def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X, Y, weight): + h = X.shape[0] + # logger.info(f"h: {h}") + Xw = torch.matmul(X, w.detach()) + half_d = 0.25 * Xw - 0.5 * Y + if weight: + half_d = half_d * weight + batch_ctx.hosts.put("half_d", encryptor.encrypt_tensor(half_d, obfuscate=True)) + half_g = torch.matmul(X.T, half_d) + + Xw_h = batch_ctx.hosts.get("Xw_h")[0] + if weight: + Xw_h = Xw_h * weight + + if self.floating_point_precision: + host_half_g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), Xw_h) + host_half_g = 1 / self._fixpoint_precision * host_half_g + else: + host_half_g = torch.matmul(X.T, Xw_h) + + loss = np.log(2) - 1 + 0.125 / h * torch.matmul(Xw.T, Xw) - 2 / h * torch.matmul(half_d.T, Y) + + if self.optimizer.l1_penalty or self.optimizer.l2_penalty: + loss_norm = self.optimizer.loss_norm(w) + loss += loss_norm + + loss += torch.matmul((1 / h * Xw).T, Xw_h) - torch.matmul((2 / h * Y).T, Xw_h) + + for Xw2_h in batch_ctx.hosts.get("Xw2_h"): + loss += 0.125 / h * Xw2_h + h_loss_list = batch_ctx.hosts.get("h_loss") + for h_loss in h_loss_list: + if h_loss is not None: + loss += h_loss + + batch_ctx.arbiter.put(loss=loss) + # gradient + g = 1 / h * (half_g + host_half_g) + return g + + def centralized_compute_gradient(self, batch_ctx, w, X, Y, weight): + h = X.shape[0] + # logger.info(f"h: {h}") + Xw = torch.matmul(X, w.detach()) + d = 0.25 * Xw - 0.5 * Y + + Xw_h_all = batch_ctx.hosts.get("Xw_h") + + for Xw_h in Xw_h_all: + d += Xw_h + + if weight: + # logger.info(f"weight: {weight.tolist()}") + d = d * weight + batch_ctx.hosts.put("d", d) + + # gradient + if self.floating_point_precision: + g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), d) + g = 1 / (h * self._fixpoint_precision) * g + else: + g = 1 / h * torch.matmul(X.T, d) + return g + + def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=None): + """ + l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2) + ∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x = 1/h * Σdx + where d = 0.25(xw - 2y) + loss = log2 - (1/N)*0.5*∑ywx + (1/N)*0.125*[∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)] + """ + coef_count = train_data.shape[1] + if self.init_param.get("fit_intercept"): + train_data["intercept"] = 1.0 + + w = self.w + if w is None: + w = initialize_param(coef_count, **self.init_param) + + self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) + self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) + + train_data.label = train_data.label.apply_row( + lambda x: [1.0] if abs(x[0] - 1) < 1e-8 else [-1.0], with_label=True + ) + + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True + 
) + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 + + is_centralized = len(ctx.hosts) > 1 + + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): + self.optimizer.set_iters(i) + logger.info(f"self.optimizer set epoch {i}") + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + Y = batch_data.label + weight = batch_data.weight + if is_centralized: + g = self.centralized_compute_gradient(batch_ctx, w, X, Y, weight) + else: + g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X, Y, weight) + + g = self.optimizer.add_regular_to_grad(g, w, self.init_param.get("fit_intercept")) + batch_ctx.arbiter.put("g_enc", g) + g = batch_ctx.arbiter.get("g") + + w = self.optimizer.update_weights(w, g, self.init_param.get("fit_intercept"), self.lr_scheduler.lr) + # logger.info(f"w={w}") + check_overflow(w) + + self.is_converged = iter_ctx.arbiter("converge_flag").get() + if self.is_converged: + self.end_epoch = i + break + if i < self.epochs - 1: + logger.info(f"lr step at {i}th epoch") + self.lr_scheduler.step() + if not self.is_converged: + self.end_epoch = self.epochs + self.w = w + logger.debug(f"Finish training at {self.end_epoch}th epoch.") + + def predict(self, ctx, test_data): + if self.init_param.get("fit_intercept"): + test_data["intercept"] = 1.0 + X = test_data.values.as_tensor() + # logger.info(f"in predict, w: {self.w}") + pred = torch.matmul(X, self.w.detach()) + for h_pred in ctx.hosts.get("h_pred"): + pred += h_pred + pred = torch.sigmoid(pred) + return pred + + def get_model(self): + param = serialize_param(self.w, self.init_param.get("fit_intercept")) + return { + # "w": w, + # "intercept": intercept, + "param": param, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + "fit_intercept": self.init_param.get("fit_intercept"), + } + + def restore(self, model): + self.w = deserialize_param(model["param"], model["fit_intercept"]) + self.optimizer = Optimizer() + self.lr_scheduler = LRScheduler() + self.optimizer.load_state_dict(model["optimizer"]) + self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] diff --git a/python/fate/ml/glm/hetero/coordinated_lr/host.py b/python/fate/ml/glm/hetero/coordinated_lr/host.py new file mode 100644 index 0000000000..8d720779b8 --- /dev/null +++ b/python/fate/ml/glm/hetero/coordinated_lr/host.py @@ -0,0 +1,319 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
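The guest's `fit_single_model` docstring above relies on the second-order Taylor expansion of the logistic loss around zero, with labels mapped to {-1, +1}. A quick numeric sanity check of that approximation and the matching gradient, on made-up data with numpy (not part of the FATE code):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 4))
y = np.where(rng.random(16) > 0.5, 1.0, -1.0)   # labels mapped to {-1, +1} as in fit_single_model
w = 0.01 * rng.normal(size=4)                   # small weights, where the expansion is accurate

xw = X @ w
exact = np.mean(np.log(1.0 + np.exp(-y * xw)))                      # true logistic loss
approx = np.mean(np.log(2.0) - 0.5 * y * xw + 0.125 * xw ** 2)      # second-order Taylor expansion
grad = X.T @ (0.25 * xw - 0.5 * y) / len(y)                         # matches d = 0.25 * Xw - 0.5 * y above
print(round(float(exact), 6), round(float(approx), 6))              # the two losses agree closely
```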
+import logging + +import torch + +from fate.arch import Context +from fate.arch.dataframe import DataLoader +from fate.ml.abc.module import HeteroModule +from fate.ml.utils._model_param import ( + check_overflow, + deserialize_param, + initialize_param, + serialize_param, +) +from fate.ml.utils._optimizer import LRScheduler, Optimizer + +logger = logging.getLogger(__name__) + + +class CoordinatedLRModuleHost(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer_param=None, learning_rate_param=None, init_param=None, + floating_point_precision=23): + self.epochs = epochs + self.learning_rate_param = learning_rate_param + self.optimizer_param = optimizer_param + self.batch_size = batch_size + self.init_param = init_param + self.floating_point_precision = floating_point_precision + + # host never has fit intercept + self.init_param["fit_intercept"] = False + + self.estimator = None + self.ovr = False + self.label_count = False + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + if self.ovr: + for estimator in self.estimator.values(): + estimator.batch_size = batch_size + else: + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + if self.ovr: + for estimator in self.estimator.values(): + estimator.epochs = epochs + else: + self.estimator.epochs = epochs + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + encryptor = ctx.arbiter("encryptor").get() + label_count = ctx.guest("label_count").get() + if label_count > 2 or self.ovr: + self.ovr = True + self.label_count = label_count + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): + # optimizer = copy.deepcopy(self.optimizer) + # lr_scheduler = copy.deepcopy(self.lr_scheduler) + if not warm_start: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorHost( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + floating_point_precision=self.floating_point_precision, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + single_estimator.fit_single_model(class_ctx, encryptor, train_data, validate_data) + self.estimator[i] = single_estimator + else: + if self.estimator is None: + optimizer = Optimizer( + self.optimizer_param["method"], + self.optimizer_param["penalty"], + self.optimizer_param["alpha"], + self.optimizer_param["optimizer_params"], + ) + lr_scheduler = LRScheduler( + self.learning_rate_param["method"], self.learning_rate_param["scheduler_params"] + ) + single_estimator = CoordinatedLREstimatorHost( + epochs=self.epochs, + batch_size=self.batch_size, + optimizer=optimizer, + learning_rate_scheduler=lr_scheduler, + init_param=self.init_param, + floating_point_precision=self.floating_point_precision, + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + 
single_estimator.fit_single_model(ctx, encryptor, train_data, validate_data) + self.estimator = single_estimator + + def predict(self, ctx, test_data): + if self.ovr: + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(self.label_count): + estimator = self.estimator[i] + estimator.predict(class_ctx, test_data) + else: + self.estimator.predict(ctx, test_data) + + def get_model(self): + all_estimator = {} + if self.ovr: + for label_idx, estimator in self.estimator.items(): + all_estimator[label_idx] = estimator.get_model() + else: + all_estimator = self.estimator.get_model() + return { + "data": {"estimator": all_estimator}, + "meta": { + "label_count": self.label_count, + "ovr": self.ovr, + "epochs": self.epochs, + "batch_size": self.batch_size, + "learning_rate_param": self.learning_rate_param, + "optimizer_param": self.optimizer_param, + "init_param": self.init_param, + "floating_point_precision": self.floating_point_precision, + }, + } + + @classmethod + def from_model(cls, model) -> "CoordinatedLRModuleHost": + lr = CoordinatedLRModuleHost( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + learning_rate_param=model["meta"]["learning_rate_param"], + optimizer_param=model["meta"]["optimizer_param"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + lr.label_count = model["meta"]["label_count"] + lr.ovr = model["meta"]["ovr"] + + all_estimator = model["data"]["estimator"] + lr.estimator = {} + + if lr.ovr: + lr.estimator = {} + for label, d in all_estimator.items(): + estimator = CoordinatedLREstimatorHost( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + estimator.restore(d) + lr.estimator[int(label)] = estimator + else: + estimator = CoordinatedLREstimatorHost( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + floating_point_precision=model["meta"]["floating_point_precision"], + ) + estimator.restore(all_estimator) + lr.estimator = estimator + logger.info(f"finish from model") + + return lr + + +class CoordinatedLREstimatorHost(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate_scheduler=None, init_param=None, + floating_point_precision=23): + self.epochs = epochs + self.optimizer = optimizer + self.lr_scheduler = learning_rate_scheduler + self.batch_size = batch_size + self.init_param = init_param + self.floating_point_precision = floating_point_precision + self._fixpoint_precision = 2 ** floating_point_precision + + self.w = None + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + + def asynchronous_compute_gradient(self, batch_ctx, encryptor, w, X): + h = X.shape[0] + Xw_h = 0.25 * torch.matmul(X, w.detach()) + batch_ctx.guest.put("Xw_h", encryptor.encrypt_tensor(Xw_h, obfuscate=True)) + + half_g = torch.matmul(X.T, Xw_h) + + guest_half_d = batch_ctx.guest.get("half_d") + logger.info(f"guest half d received") + if self.floating_point_precision: + guest_half_g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), guest_half_d) + guest_half_g = 1 / self._fixpoint_precision * guest_half_g + else: + guest_half_g = torch.matmul(X.T, guest_half_d) + logger.info(f"guest half g obtained") + + batch_ctx.guest.put("Xw2_h", encryptor.encrypt_tensor(torch.matmul(Xw_h.T, Xw_h))) + loss_norm = 
self.optimizer.loss_norm(w) + + if loss_norm is not None: + batch_ctx.guest.put("h_loss", encryptor.encrypt_tensor(loss_norm)) + else: + batch_ctx.guest.put("h_loss", loss_norm) + + g = 1 / h * (half_g + guest_half_g) + return g + + def centralized_compute_gradient(self, batch_ctx, encryptor, w, X): + h = X.shape[0] + Xw_h = 0.25 * torch.matmul(X, w.detach()) + batch_ctx.guest.put("Xw_h", encryptor.encrypt_tensor(Xw_h, obfuscate=True)) + + d = batch_ctx.guest.get("d") + if self.floating_point_precision: + g = torch.matmul(torch.encode_as_int_f(X.T, self.floating_point_precision), d) + g = 1 / (h * self._fixpoint_precision) * g + else: + g = 1 / h * torch.matmul(X.T, d) + return g + + def fit_single_model(self, ctx: Context, encryptor, train_data, validate_data=None) -> None: + coef_count = train_data.shape[1] + w = self.w + if self.w is None: + w = initialize_param(coef_count, **self.init_param) + self.optimizer.init_optimizer(model_parameter_length=w.size()[0]) + self.lr_scheduler.init_scheduler(optimizer=self.optimizer.optimizer) + batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") + # if self.end_epoch >= 0: + # self.start_epoch = self.end_epoch + 1 + is_centralized = len(ctx.hosts) > 1 + for i, iter_ctx in ctx.on_iterations.ctxs_range(self.epochs): + self.optimizer.set_iters(i) + logger.info(f"self.optimizer set epoch{i}") + for batch_ctx, batch_data in iter_ctx.on_batches.ctxs_zip(batch_loader): + X = batch_data.x + if is_centralized: + g = self.centralized_compute_gradient(batch_ctx, encryptor, w, X) + else: + g = self.asynchronous_compute_gradient(batch_ctx, encryptor, w, X) + + g = self.optimizer.add_regular_to_grad(g, w, False) + batch_ctx.arbiter.put("g_enc", g) + g = batch_ctx.arbiter.get("g") + + w = self.optimizer.update_weights(w, g, False, self.lr_scheduler.lr) + check_overflow(w) + + self.is_converged = iter_ctx.arbiter("converge_flag").get() + if self.is_converged: + self.end_epoch = i + break + if i < self.epochs - 1: + self.lr_scheduler.step() + if not self.is_converged: + self.end_epoch = self.epochs + self.w = w + logger.debug(f"Finish training at {self.end_epoch}th epoch.") + + def predict(self, ctx, test_data): + X = test_data.values.as_tensor() + output = torch.matmul(X, self.w) + ctx.guest.put("h_pred", output) + + def get_model(self): + param = serialize_param(self.w, False) + return { + "param": param, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + } + + def restore(self, model): + # self.w = torch.tensor(model["w"]) + self.w = deserialize_param(model["param"], False) + self.optimizer = Optimizer() + self.lr_scheduler = LRScheduler() + self.optimizer.load_state_dict(model["optimizer"]) + self.lr_scheduler.load_state_dict(model["lr_scheduler"], self.optimizer.optimizer) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] diff --git a/python/fate/arch/tensor/storage/local/_types.py b/python/fate/ml/glm/homo/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/local/_types.py rename to python/fate/ml/glm/homo/__init__.py diff --git a/python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py b/python/fate/ml/glm/homo/lr/__init__.py similarity index 100% rename from python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py rename to python/fate/ml/glm/homo/lr/__init__.py diff --git a/python/fate/ml/glm/homo/lr/client.py 
b/python/fate/ml/glm/homo/lr/client.py new file mode 100644 index 0000000000..2511615759 --- /dev/null +++ b/python/fate/ml/glm/homo/lr/client.py @@ -0,0 +1,492 @@ +import torch.nn as nn +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.ml.abc.module import HomoModule +from fate.ml.utils.model_io import ModelIO +from fate.arch import Context +import logging +import torch as t +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, TrainingArguments, FedAVGArguments +from transformers import default_data_collator +import functools +import tempfile +from fate.ml.utils.predict_tools import array_to_predict_df +from fate.ml.utils.predict_tools import MULTI, BINARY +from fate.ml.nn.dataset.table import TableDataset +from fate.ml.utils._optimizer import optimizer_factory, lr_scheduler_factory + + +logger = logging.getLogger(__name__) + + +def homo_lr_loss(pred, labels, dim=1): + """ + The function assumes that pred has shape (n, num_classes) where each class has its own linear model. + labels have shape (n,) and the values are integers denoting the class. + """ + + # initialize the loss + loss = 0.0 + loss_fn = t.nn.BCELoss() + if dim <= 2: + return loss_fn(pred[:, 0], labels) + + for c in range(dim): + # get binary labels for this class + binary_labels = (labels == c).float().flatten() + bin_pred = pred[:, c].flatten() + # compute binary cross-entropy loss + loss = loss_fn(bin_pred, binary_labels) + # normalize loss by the number of classes + loss /= dim + + return loss + + +class HomoLRModel(t.nn.Module): + + def __init__(self, feature_num, label_num=2, l1=0, bias=True) -> None: + super().__init__() + assert feature_num >= 2 and isinstance( + feature_num, int), "feature_num must be int greater than 2" + assert label_num >= 1 and isinstance( + label_num, int), "label_num must be int greater than 1" + self.models = t.nn.ModuleList() + + if 2 >= label_num > 0: + self.models.append( + t.nn.Linear(feature_num, 1, bias=bias) + ) + else: + # OVR Setting + for i in range(label_num): + self.models.append( + t.nn.Linear(feature_num, 1, bias=bias) + ) + self.sigmoid = t.nn.Sigmoid() + self.softmax = t.nn.Softmax(dim=1) + self.l1 = l1 + + def forward(self, x, labels=None): + + if len(self.models) == 1: + linear_out = self.models[0](x) + else: + linear_out = t.cat([model(x) for model in self.models], dim=1) + + ret_dict = {} + linear_out = self.sigmoid(linear_out).reshape((-1, len(self.models))) + + if not self.training: + if len(self.models) > 1: + linear_out = self.softmax(linear_out) + + ret_dict['pred'] = linear_out + + if labels is not None: + loss = homo_lr_loss(linear_out, labels, dim=len(self.models)) + if self.l1 != 0: + l1_regularization = t.tensor(0.) 
+ for param in self.models.parameters(): + l1_regularization += t.norm(param, 1) + loss += self.l1 * l1_regularization + ret_dict['loss'] = loss + + return ret_dict + + def to_dict(self): + model_dict = { + "feature_num": self.models[0].in_features, + "label_num": len(self.models), + # convert tensor to list + "state_dict": {k: v.tolist() for k, v in self.state_dict().items()} + } + return model_dict + + @classmethod + def from_dict(cls, model_dict): + model = cls(model_dict["feature_num"], model_dict["label_num"]) + model_state_dict = { + k: t.tensor(v) for k, + v in model_dict["state_dict"].items()} # convert list back to tensor + model.load_state_dict(model_state_dict) + return model + + +def init_model(model, method='random', fill_val=1.0): + if method == 'zeros': + init_fn = nn.init.zeros_ + elif method == 'ones': + init_fn = nn.init.ones_ + elif method == 'consts': + def init_fn(x): return nn.init.constant_(x, fill_val) + elif method == 'random': + init_fn = nn.init.normal_ + else: + raise ValueError( + "Invalid method. Options are: 'zeros', 'ones', 'consts', 'random'") + + for name, param in model.named_parameters(): + if 'bias' in name: + # usually it's good practice to initialize biases to zero + nn.init.zeros_(param) + else: + init_fn(param) + + +# read model from model bytes +def recover_torch_bytes(model_bytes): + + with tempfile.TemporaryFile() as f: + f.write(model_bytes) + f.seek(0) + model_dict = t.load(f) + + return model_dict + + +def get_torch_bytes(model_dict): + + with tempfile.TemporaryFile() as f: + t.save(model_dict, f) + f.seek(0) + model_saved_bytes = f.read() + + return model_saved_bytes + + +def update_params(new_params, default, name='optimizer'): + import copy + params = copy.deepcopy(default) + if not isinstance(new_params, dict): + raise ValueError( + "{} param dict must be a dict but got {}".format( + name, new_params)) + + def _update(default, new): + for key in new.keys(): + if key in default: + default[key] = new[key] + + _update(params, new_params) + + return params + + +DEFAULT_OPT_PARAM = { + 'method': 'sgd', + 'penalty': 'l2', + 'alpha': 0.0, + 'optimizer_params': { + 'lr': 0.01, + 'weight_decay': 0}} +DEFAULT_INIT_PARAM = { + "method": "random", + "fill_val": 1.0, + "fit_intercept": True} +DEFAULT_LR_SCHEDULER_PARAM = { + 'method': 'constant', + 'scheduler_params': { + 'factor': 1.0}} + + +class HomoLRClient(HomoModule): + + def __init__( + self, + epochs: int = 5, + batch_size: int = None, + optimizer_param={ + 'method': 'sgd', + 'optimizer_params': { + 'lr': 0.01, + 'weight_decay': 0}}, + learning_rate_scheduler={ + 'method': 'constant', + 'scheduler_params': { + 'factor': 1.0}}, + init_param={ + "method": "random", + "fill_val": 1.0, + "fit_intercept": True}, + threshold: float = 0.5, + ovr=False, + label_num=None, + ) -> None: + + super().__init__() + self.df_schema = None + self.train_set = None + self.validate_set = None + self.predict_set = None + + # set vars + self.max_iter = epochs + self.batch_size = batch_size + self.optimizer_param = update_params( + optimizer_param, DEFAULT_OPT_PARAM, name='optimizer') + self.learning_rate_param = update_params( + learning_rate_scheduler, + DEFAULT_LR_SCHEDULER_PARAM, + name='learning_rate_scheduler') + self.init_param = update_params( + init_param, DEFAULT_INIT_PARAM, name='init_param') + self.threshold = threshold + self.run_ovr = False + self.train_feature_num = None + self.validate_feature_num = None + self.ovr = ovr + self.label_num = label_num + + if self.ovr: + if self.label_num is None or 
self.label_num < 2: + raise ValueError( + "label_num must be greater than 2 when ovr is True, but got {}".format( + self.label_num)) + + # models & optimizer & schduler + self.model = None + self.optimizer = None + self.scheduler = None + self.optimizer_state_dict = None + self.trainer = None + + # loaded meta + self.loaded_meta = None + + # reg + self.l1 = 0 + self.l2 = 0 + + # for testing + self.local_mode = False + + # checkping param + assert self.max_iter > 0 and isinstance( + self.max_iter, int), "max_iter must be int greater than 0" + if self.batch_size is not None: + assert self.batch_size > 0 and isinstance( + self.batch_size, int), "batch_size must be int greater than 0 or None" + assert self.threshold > 0 and self.threshold < 1, "threshold must be float between 0 and 1" + + def _make_dataset(self, data) -> TableDataset: + ds = TableDataset(return_dict=True, to_tensor=True) + ds.load(data) + return ds + + def _make_output_df( + self, + ctx, + predict_rs, + data: TableDataset, + threshold: float): + classes = [i for i in range(len(self.model.models))] + if len(classes) == 1: # binary: + classes = [0, 1] + task_type = BINARY if len(classes) == 2 else MULTI + + out_df = array_to_predict_df( + ctx, + task_type, + predict_rs.predictions, + match_ids=data.get_match_ids(), + sample_ids=data.get_sample_ids(), + match_id_name=data.get_match_id_name(), + sample_id_name=data.get_sample_id_name(), + label=predict_rs.label_ids, + threshold=threshold, + classes=classes + ) + + return out_df + + def _check_labels(self, label_set, has_validate=False): + + dataset_descrb = 'train dataset' if not has_validate else 'train and validate dataset' + if not self.ovr and len(label_set) > 2: + raise ValueError( + "please set ovr=True to enable multi-label classification, multiple labels found in {}: {}".format( + dataset_descrb, label_set)) + if not self.ovr and len(label_set) == 2: + # 0, 1 is required + if 0 not in label_set or 1 not in label_set: + # ask for label 0, 1 when running binary classification + raise ValueError( + "when doing binary classification, lables must be 0, 1, but found in {}'s label set is {}".format( + label_set, dataset_descrb)) + if self.ovr: + if max(label_set) > self.label_num - 1: + # make sure labels start from 0 and not the label indices not + # exceed the label num parameter + raise ValueError( + "when doing multi-label classification, labels must start from 0 and not exceed the label num parameter, \ + but {}'s label set is {}, while label num is {}".format( + label_set, dataset_descrb, self.label_num)) + + def fit(self, ctx: Context, train_data: DataFrame, + validate_data: DataFrame = None) -> None: + + # check data, must be fate Dataframe + assert isinstance( + train_data, DataFrame), "train_data must be a fate DataFrame" + if validate_data is not None: + assert isinstance( + validate_data, DataFrame), "validate_data must be a fate DataFrame" + + self.train_set = self._make_dataset(train_data) + if not self.train_set.has_label(): + raise RuntimeError("train data must have label column") + self.train_feature_num = self.train_set.features.shape[1] + unique_label_set = set(self.train_set.get_classes()) + + if validate_data is not None: + self.validate_set = self._make_dataset(validate_data) + if not self.validate_set.has_label(): + raise RuntimeError("validate data must have label column") + self.validate_feature_num = self.validate_set.features.shape[1] + assert self.train_feature_num == self.validate_feature_num, "train and validate feature num not match: {} vs 
{}".format( + self.train_feature_num, self.validate_feature_num) + unique_label_set = unique_label_set.union( + set(self.validate_set.get_classes())) + + self._check_labels(unique_label_set, validate_data is not None) + + if self.batch_size is None: + self.batch_size = len(self.train_set) + + # prepare loss function + loss_fn = functools.partial(homo_lr_loss, dim=len(unique_label_set)) + optimizer_params = self.optimizer_param['optimizer_params'] + opt_method = self.optimizer_param['method'] + if self.optimizer_param['penalty'] == 'l2': + self.l2 = self.optimizer_param['alpha'] + optimizer_params['weight_decay'] = self.l2 + elif self.optimizer_param['penalty'] == 'l1': + self.l1 = self.optimizer_param['alpha'] + + # initialize model + if self.model is None: + fit_intercept = self.init_param["fit_intercept"] + self.model = HomoLRModel( + self.train_feature_num, + label_num=len(unique_label_set), + l1=self.l1, + bias=fit_intercept) + # init model here + init_model( + self.model, + method=self.init_param["method"], + fill_val=self.init_param["fill_val"]) + logger.info('model initialized') + logger.info( + 'model parameters are {}'.format( + list( + self.model.parameters()))) + else: + logger.info('model is loaded, warm start training') + logger.info('model structure is {}'.format(self.model)) + + self.optimizer = optimizer_factory( + self.model.parameters(), opt_method, optimizer_params) + self.lr_scheduler = lr_scheduler_factory( + self.optimizer, + self.learning_rate_param['method'], + self.learning_rate_param['scheduler_params']) + + if self.optimizer_state_dict is not None: + optimizer_state_dict = { + "state": { + k: t.tensor(v) for k, + v in self.optimizer_state_dict['state'].items()}, + "param_groups": self.optimizer_state_dict['param_groups'], + } + self.optimizer.load_state_dict(optimizer_state_dict) + logger.info('load warmstart optimizer state dict') + + # training + fed_arg = FedAVGArguments() + train_arg = TrainingArguments( + num_train_epochs=self.max_iter, + per_device_train_batch_size=self.batch_size, + per_device_eval_batch_size=self.batch_size) + self.trainer = FedAVGCLient( + ctx, + model=self.model, + loss_fn=loss_fn, + optimizer=self.optimizer, + train_set=self.train_set, + val_set=self.validate_set, + training_args=train_arg, + fed_args=fed_arg, + data_collator=default_data_collator, + scheduler=self.lr_scheduler) + if self.local_mode: # for debugging + self.trainer.set_local_mode() + self.trainer.train() + + logger.info('homo lr fit done') + + def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: + + if self.model is None: + raise ValueError("model is not initialized") + self.predict_set = self._make_dataset(predict_data) + if self.trainer is None: + batch_size = len( + self.predict_set) if self.batch_size is None else self.batch_size + train_arg = TrainingArguments( + num_train_epochs=self.max_iter, + per_device_eval_batch_size=batch_size) + trainer = FedAVGCLient( + ctx, + train_set=self.predict_set, + model=self.model, + training_args=train_arg, + fed_args=FedAVGArguments(), + data_collator=default_data_collator) + trainer.set_local_mode() + else: + trainer = self.trainer + predict_rs = trainer.predict(self.predict_set) + predict_out_df = self._make_output_df( + ctx, predict_rs, self.predict_set, self.threshold) + return predict_out_df + + def get_model(self) -> ModelIO: + param = {} + if self.model is not None: + param['model'] = self.model.to_dict() + if self.optimizer is not None: + param['optimizer'] = str( + get_torch_bytes( + 
self.optimizer.state_dict())) + + meta = { + 'batch_size': self.batch_size, + 'max_iter': self.max_iter, + 'threshold': self.threshold, + 'optimizer_param': self.optimizer_param, + 'learning_rate_param': self.learning_rate_param, + 'init_param': self.init_param, + 'ovr': self.ovr, + 'label_num': self.label_num} + + return {'param': param, 'meta': meta} + + def from_model(self, model: dict): + + if 'param' not in model: + raise ('key "data" is not found in the input model dict') + + model_param = model['param'] + if 'model' not in model_param: + raise ValueError( + "param dict must have key 'model' that contains the model parameter and structure info") + self.model = HomoLRModel.from_dict(model_param['model']) + if self.ovr: + assert len(self.model.models) == self.label_num, '' + self.model.l1 = self.l1 + if hasattr(model_param, 'optimizer'): + self.optimizer_state_dict = recover_torch_bytes( + bytes(model_param['optimizer'], 'utf-8')) + self.loaded_meta = model['meta'] diff --git a/python/fate/ml/glm/homo/lr/server.py b/python/fate/ml/glm/homo/lr/server.py new file mode 100644 index 0000000000..e143fd21fc --- /dev/null +++ b/python/fate/ml/glm/homo/lr/server.py @@ -0,0 +1,29 @@ +from fate.ml.abc.module import HomoModule +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.arch import Context +import logging +from fate.ml.nn.algo.homo.fedavg import FedAVGServer + + +logger = logging.getLogger(__name__) + + +class HomoLRServer(HomoModule): + + def __init__(self) -> None: + pass + + def fit(self, ctx: Context, data: DataFrame = None) -> None: + + server = FedAVGServer(ctx=ctx) + logger.info('server class init done, start fed training') + server.train() + logger.info('homo lr fit done') + + def predict( + self, + ctx: Context, + predict_data: DataFrame = None) -> DataFrame: + + logger.info('kkip prediction stage') diff --git a/python/fate/ml/glm/homo/lr/test/test_fed_lr.py b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py new file mode 100644 index 0000000000..c926daf357 --- /dev/null +++ b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py @@ -0,0 +1,86 @@ +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedArguments, TrainingArguments, FedAVGServer +import torch as t +import pandas as pd +import sys +from fate.arch.dataframe import PandasReader +from fate.ml.glm.homo.lr.client import HomoLRClient +from fate.ml.glm.homo.lr.server import HomoLRServer + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context(computing=computing, + federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])) + + + +if __name__ == "__main__": + + if sys.argv[1] == "guest": + + ctx = create_ctx(guest) + df = pd.read_csv( + '../../../../../../../examples/data/breast_homo_guest.csv') + df['sample_id'] = [i for i in range(len(df))] + + reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + label_name="y", + dtype="object") + + data = 
reader.to_frame(ctx, df) + client = HomoLRClient( + 50, 800, optimizer_param={ + 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': { + 'lr': 0.1}}, init_param={ + 'method': 'random', 'fill_val': 1.0}) + + client.fit(ctx, data) + + elif sys.argv[1] == "host": + + ctx = create_ctx(host) + df = pd.read_csv( + '../../../../../../../examples/data/breast_homo_host.csv') + df['sample_id'] = [i for i in range(len(df))] + + reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + label_name="y", + dtype="object") + + data = reader.to_frame(ctx, df) + client = HomoLRClient( + 50, 800, optimizer_param={ + 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': { + 'lr': 0.1}}, init_param={ + 'method': 'random', 'fill_val': 1.0}) + + client.fit(ctx, data) + else: + + ctx = create_ctx(arbiter) + server = HomoLRServer() + server.fit(ctx) \ No newline at end of file diff --git a/python/fate/ml/glm/homo/lr/test/test_local.py b/python/fate/ml/glm/homo/lr/test/test_local.py new file mode 100644 index 0000000000..70a60d576e --- /dev/null +++ b/python/fate/ml/glm/homo/lr/test/test_local.py @@ -0,0 +1,59 @@ +from fate.arch import Context +from fate.arch.computing.standalone import CSession +from fate.arch.context import Context +from fate.arch.federation.standalone import StandaloneFederation +import pandas as pd +from fate.arch.dataframe import PandasReader +from fate.ml.nn.dataset.table import TableDataset +from fate.ml.glm.homo.lr.client import HomoLRClient +import logging + +# Get the root logger +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) + + +computing = CSession() +ctx = Context("guest", computing=computing, federation=StandaloneFederation( + computing, "fed", ("guest", 10000), [("guest", 10000), ("host", 9999)]) ) + +df = pd.read_csv( + '../../../../../../../examples/data/breast_homo_guest.csv') +df['sample_id'] = [i for i in range(len(df))] + +reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + label_name="y", + dtype="object") +reader_2 = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + dtype="object") +data = reader.to_frame(ctx, df) + +# df = data.as_pd_df() +data_2 = reader_2.to_frame(ctx, df.drop(columns=['y'])) +ds = TableDataset(return_dict=True, to_tensor=True) +ds.load(data) + + +client = HomoLRClient( + 50, 800, optimizer_param={ + 'method': 'adam', 'penalty': 'l1', 'aplha': 0.1, 'optimizer_para': { + 'lr': 0.1}}, init_param={ + 'method': 'random', 'fill_val': 1.0}, learning_rate_scheduler={ + 'method': 'linear', 'scheduler_params': {'start_factor'}}) +client.l2 = 0.01 +client.l1 = 0.01 +client.local_mode = True +client.fit(ctx, data, validate_data=data) +export_model = client.get_model() +pred = client.predict(ctx, data) +# pred_2 = client.predict(ctx, data_2) diff --git a/python/fate/ml/intersection/__init__.py b/python/fate/ml/intersection/__init__.py deleted file mode 100644 index e6bfe57182..0000000000 --- a/python/fate/ml/intersection/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from .raw_intersection import RawIntersectionGuest, RawIntersectionHost diff --git a/python/fate/ml/intersection/raw_intersection.py b/python/fate/ml/intersection/raw_intersection.py deleted file mode 100644 index a037470871..0000000000 --- a/python/fate/ml/intersection/raw_intersection.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from fate.interface import Context - -from ..abc.module import HeteroModule - -logger = logging.getLogger(__name__) - - -class RawIntersectionGuest(HeteroModule): - def __init__(self): - ... - - def fit(self, ctx: Context, train_data, validate_data=None): - # ctx.hosts.put("raw_index", train_data.index.tolist()) - ctx.hosts.put("raw_index", train_data.index.values) - intersect_indexes = ctx.hosts.get("intersect_index") - intersect_data = train_data - for intersect_index in intersect_indexes: - intersect_data = intersect_data.loc(intersect_index) - - intersect_count = intersect_data.count() - ctx.hosts.put("intersect_count", intersect_count) - - logger.info(f"intersect count={intersect_count}") - return intersect_data - - -class RawIntersectionHost(HeteroModule): - def __init__(self): - ... - - def fit(self, ctx: Context, train_data, validate_data=None): - guest_index = ctx.guest.get("raw_index") - intersect_data = train_data.loc(guest_index) - # ctx.guest.put("intersect_index", intersect_data.index.tolist()) - ctx.guest.put("intersect_index", intersect_data.index.values) - - intersect_count = ctx.guest.get("intersect_count") - logger.info(f"intersect count={intersect_count}") - return intersect_data diff --git a/python/fate/ml/lr/arbiter.py b/python/fate/ml/lr/arbiter.py deleted file mode 100644 index 196b75b5f2..0000000000 --- a/python/fate/ml/lr/arbiter.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -from fate.arch.dataframe import DataLoader -from fate.interface import Context - -from ..abc.module import HeteroModule - -logger = logging.getLogger(__name__) - - -class LrModuleArbiter(HeteroModule): - def __init__( - self, - batch_size, - max_iter=100, - ): - self.max_iter = max_iter - self.batch_size = batch_size - - def fit(self, ctx: Context) -> None: - encryptor, decryptor = ctx.cipher.phe.keygen(options=dict(key_length=2048)) - # ctx.guest("encryptor").put(encryptor) # ctx.guest.put("encryptor", encryptor) - ctx.hosts("encryptor").put(encryptor) - # num_batch = ctx.guest.get("num_batch") - batch_loader = DataLoader( - dataset=None, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="arbiter", sync_arbiter=True - ) - logger.info(f"batch_num={batch_loader.batch_num}") - step = 0 - for _, iter_ctx in ctx.range(self.max_iter): - for batch_ctx, _ in iter_ctx.iter(batch_loader): - g_guest_enc = batch_ctx.guest.get("g_enc") - g = decryptor.decrypt(g_guest_enc) - batch_ctx.guest.put("g", g) - for i, g_host_enc in enumerate(batch_ctx.hosts.get("g_enc")): - g = decryptor.decrypt(g_host_enc) - batch_ctx.hosts[i].put("g", g) - loss = decryptor.decrypt(batch_ctx.guest.get("loss")) - iter_ctx.metrics.log_loss("lr_loss", loss.tolist(), step=step) - logger.info(f"loss={loss}") - step += 1 diff --git a/python/fate/ml/lr/guest.py b/python/fate/ml/lr/guest.py deleted file mode 100644 index 6a455618be..0000000000 --- a/python/fate/ml/lr/guest.py +++ /dev/null @@ -1,116 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -from fate.arch import dataframe, tensor -from fate.arch.tensor import dtype -from fate.interface import Context - -from ..abc.module import HeteroModule - -logger = logging.getLogger(__name__) - - -class LrModuleGuest(HeteroModule): - def __init__( - self, - max_iter, - batch_size, - learning_rate=0.01, - alpha=1.0, - ): - self.max_iter = max_iter - self.batch_size = batch_size - self.learning_rate = learning_rate - self.alpha = alpha - - self.w = None - - def fit(self, ctx: Context, train_data, validate_data=None) -> None: - """ - l(w) = 1/h * Σ(log(2) - 0.5 * y * xw + 0.125 * (wx)^2) - ∇l(w) = 1/h * Σ(0.25 * xw - 0.5 * y)x = 1/h * Σdx - where d = 0.25(xw - 2y) - loss = log2 - (1/N)*0.5*∑ywx + (1/N)*0.125*[∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)] - """ - # mock data - batch_loader = dataframe.DataLoader( - train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=True - ) - # # get encryptor - # ctx.arbiter("encryptor").get() - - w = tensor.randn((train_data.num_features, 1), dtype=dtype.float32) - for i, iter_ctx in ctx.range(self.max_iter): - logger.info(f"start iter {i}") - j = 0 - for batch_ctx, (X, Y) in iter_ctx.iter(batch_loader): - h = X.shape[0] - - # d - Xw = tensor.matmul(X, w) - d = 0.25 * Xw - 0.5 * Y - loss = 0.125 / h * tensor.matmul(Xw.T, Xw) - 0.5 / h * tensor.matmul(Xw.T, Y) - for Xw_h in batch_ctx.hosts.get("Xw_h"): - d += Xw_h - loss -= 0.5 / h * tensor.matmul(Y.T, Xw_h) - loss += 0.25 / h * tensor.matmul(Xw.T, Xw_h) - for Xw2_h in batch_ctx.hosts.get("Xw2_h"): - loss += 0.125 / h * Xw2_h - batch_ctx.hosts.put(d=d) - batch_ctx.arbiter.put(loss=loss) - - # gradient - batch_ctx.arbiter.put("g_enc", X.T @ d) - g: tensor.Tensor = batch_ctx.arbiter.get("g") - # apply l2 penalty - g = g / h + self.alpha * w - w -= self.learning_rate * g - logger.info(f"w={w}") - j += 1 - self.w = w - - def predict(self, ctx, test_data): - batch_loader = dataframe.DataLoader( - test_data, - ctx=ctx, - batch_size=-1, - mode="hetero", - role="guest", - sync_arbiter=False, - ) - for X, y in batch_loader: - output = tensor.matmul(X, self.w) - - return output - - def get_model(self): - return { - "w": self.w.to_local()._storage.data.tolist(), - "metadata": { - "max_iter": self.max_iter, - "batch_size": self.batch_size, - "learning_rate": self.learning_rate, - "alpha": self.alpha, - }, - } - - @classmethod - def from_model(cls, model) -> "LrModuleGuest": - lr = LrModuleGuest(**model["metadata"]) - import torch - - lr.w = tensor.tensor(torch.tensor(model["w"])) - return lr diff --git a/python/fate/ml/lr/host.py b/python/fate/ml/lr/host.py deleted file mode 100644 index 1d66a608f9..0000000000 --- a/python/fate/ml/lr/host.py +++ /dev/null @@ -1,97 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
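For context, the docstring of the removed LrModuleGuest above quotes the second-order Taylor approximation of the logistic loss that also appears to underlie the new coordinated LR exchange (the host's Xw_h = 0.25 * X @ w term). Written out cleanly, and assuming the usual labels y in {-1, +1} for this approximation:

\ell(w) = \frac{1}{h}\sum_i \Big(\log 2 - \tfrac{1}{2}\, y_i\, x_i^\top w + \tfrac{1}{8}\, (x_i^\top w)^2\Big),
\qquad
\nabla \ell(w) = \frac{1}{h}\sum_i \Big(\tfrac{1}{4}\, x_i^\top w - \tfrac{1}{2}\, y_i\Big)\, x_i .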
-import logging - -import torch -from fate.arch import tensor -from fate.arch.dataframe import DataLoader -from fate.interface import Context - -from ..abc.module import HeteroModule - -logger = logging.getLogger(__name__) - - -class LrModuleHost(HeteroModule): - def __init__( - self, - max_iter, - batch_size=None, - learning_rate=0.01, - alpha=1.0, - ): - self.max_iter = max_iter - self.learning_rate = learning_rate - self.alpha = alpha - self.batch_size = batch_size - - self.w = None - - def fit(self, ctx: Context, train_data, validate_data=None) -> None: - batch_loader = DataLoader(train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") - # get encryptor - encryptor = ctx.arbiter("encryptor").get() - - w = tensor.tensor(torch.randn((train_data.num_features, 1), dtype=torch.float32)) - for i, iter_ctx in ctx.range(self.max_iter): - logger.info(f"start iter {i}") - j = 0 - for batch_ctx, X in iter_ctx.iter(batch_loader): - h = X.shape[0] - logger.info(f"start batch {j}") - Xw_h = 0.25 * tensor.matmul(X, w) - encryptor.encrypt(Xw_h).to(batch_ctx.guest, "Xw_h") - encryptor.encrypt(tensor.matmul(Xw_h.T, Xw_h)).to(batch_ctx.guest, "Xw2_h") - d = batch_ctx.guest.get("d") - tensor.matmul(X.T, d).to(batch_ctx.arbiter, "g_enc") - g = batch_ctx.arbiter.get("g") - g = g / h + self.alpha * w - w -= self.learning_rate * g - logger.info(f"w={w}") - j += 1 - - self.w = w - - def get_model(self): - return { - "w": self.w.to_local()._storage.data.tolist(), - "metadata": { - "max_iter": self.max_iter, - "batch_size": self.batch_size, - "learning_rate": self.learning_rate, - "alpha": self.alpha, - }, - } - - def predict(self, ctx, test_data): - batch_loader = DataLoader( - test_data, - ctx=ctx, - batch_size=-1, - mode="hetero", - role="host", - sync_arbiter=False, - ) - for X in batch_loader: - output = tensor.matmul(X, self.w) - print(output) - - @classmethod - def from_model(cls, model) -> "LrModuleHost": - lr = LrModuleHost(**model["metadata"]) - import torch - - lr.w = tensor.tensor(torch.tensor(model["w"])) - return lr diff --git a/python/fate/ml/model_selection/__init__.py b/python/fate/ml/model_selection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/ml/model_selection/data_split.py b/python/fate/ml/model_selection/data_split.py new file mode 100644 index 0000000000..fdea5edb6c --- /dev/null +++ b/python/fate/ml/model_selection/data_split.py @@ -0,0 +1,221 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
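The data_split module added below resolves train/validate/test sizes (given either as fractions or absolute counts) on the guest and, when hetero_sync is enabled, ships the resulting sample ids to the hosts so every party splits identically. A simplified, standalone sketch of the fraction-to-count resolution only; the real get_split_data_size below also fills in whichever size is left unset:

def resolve_split_counts(train_frac, validate_frac, test_frac, data_count):
    # all three fractions are assumed given and to sum to 1.0
    assert abs(train_frac + validate_frac + test_frac - 1.0) < 1e-6
    train_n = round(train_frac * data_count)
    validate_n = round(validate_frac * data_count)
    test_n = data_count - train_n - validate_n  # remainder goes to test
    return train_n, validate_n, test_n

print(resolve_split_counts(0.8, 0.2, 0.0, 569))  # -> (455, 114, 0)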
+ + +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from ..abc.module import Module + +logger = logging.getLogger(__name__) + + +class DataSplitModuleGuest(Module): + def __init__( + self, + train_size=0.8, + validate_size=0.2, + test_size=0.0, + stratified=False, + random_state=None, + hetero_sync=True + ): + self.train_size = train_size + self.validate_size = validate_size + self.test_size = test_size + self.stratified = stratified + self.random_state = random_state + self.hetero_sync = hetero_sync + + def fit(self, ctx: Context, train_data, validate_data=None): + data_count = train_data.shape[0] + train_size, validate_size, test_size = get_split_data_size(self.train_size, + self.validate_size, + self.test_size, + data_count) + if self.stratified: + train_data_set = sample_per_label(train_data, sample_count=train_size, random_state=self.random_state) + else: + train_data_set = sample_data(df=train_data, n=train_size, random_state=self.random_state) + if train_data_set is not None: + train_sid = train_data_set.get_indexer(target="sample_id") + validate_test_data_set = train_data.drop(train_data_set) + else: + train_sid = None + validate_test_data_set = train_data + + if self.stratified: + validate_data_set = sample_per_label(validate_test_data_set, sample_count=validate_size, + random_state=self.random_state) + else: + validate_data_set = sample_data(df=validate_test_data_set, n=validate_size, random_state=self.random_state) + if validate_data_set is not None: + validate_sid = validate_data_set.get_indexer(target="sample_id") + test_data_set = validate_test_data_set.drop(validate_data_set) + if test_data_set.shape[0] == 0: + test_sid = None + test_data_set = None + else: + test_sid = test_data_set.get_indexer(target="sample_id") + else: + validate_sid = None + if validate_test_data_set.shape[0] == 0: + test_data_set = None + test_sid = None + else: + test_data_set = validate_test_data_set + test_sid = validate_test_data_set.get_indexer(target="sample_id") + + if self.hetero_sync: + ctx.hosts.put("train_data_sid", train_sid) + ctx.hosts.put("validate_data_sid", validate_sid) + ctx.hosts.put("test_data_sid", test_sid) + + return train_data_set, validate_data_set, test_data_set + + +class DataSplitModuleHost(Module): + def __init__( + self, + train_size=0.8, + validate_size=0.2, + test_size=0.0, + stratified=False, + random_state=None, + hetero_sync=True + ): + self.train_size = train_size + self.validate_size = validate_size + self.test_size = test_size + self.stratified = stratified + self.random_state = random_state + self.hetero_sync = hetero_sync + + def fit(self, ctx: Context, train_data, validate_data=None): + if self.hetero_sync: + train_data_sid = ctx.guest.get("train_data_sid") + validate_data_sid = ctx.guest.get("validate_data_sid") + test_data_sid = ctx.guest.get("test_data_sid") + train_data_set, validate_data_set, test_data_set = None, None, None + if train_data_sid: + train_data_set = train_data.loc(train_data_sid, preserve_order=True) + if validate_data_sid: + validate_data_set = train_data.loc(validate_data_sid, preserve_order=True) + if test_data_sid: + test_data_set = train_data.loc(test_data_sid, preserve_order=True) + else: + data_count = train_data.shape[0] + train_size, validate_size, test_size = get_split_data_size(self.train_size, + self.validate_size, + self.test_size, + data_count) + + if self.stratified: + train_data_set = sample_per_label(train_data, sample_count=train_size, random_state=self.random_state) + else: + 
train_data_set = sample_data(df=train_data, n=train_size, random_state=self.random_state) + if train_data_set is not None: + # train_sid = train_data_set.get_indexer(target="sample_id") + validate_test_data_set = train_data.drop(train_data_set) + else: + validate_test_data_set = train_data + + if self.stratified: + validate_data_set = sample_per_label(validate_test_data_set, sample_count=validate_size, + random_state=self.random_state) + else: + validate_data_set = sample_data(df=validate_test_data_set, n=validate_size, + random_state=self.random_state) + if validate_data_set is not None: + # validate_sid = validate_data_set.get_indexer(target="sample_id") + test_data_set = validate_test_data_set.drop(validate_data_set) + if test_data_set.shape[0] == 0: + test_data_set = None + else: + if validate_test_data_set.shape[0] == 0: + test_data_set = None + else: + test_data_set = validate_test_data_set + + return train_data_set, validate_data_set, test_data_set + + +def sample_data(df, n, random_state): + if n == 0: + return + else: + return df.sample(n=n, random_state=random_state) + + +def sample_per_label(train_data, sample_count=None, random_state=None): + train_data_binarized_label = train_data.label.get_dummies() + labels = [label_name.split("_")[1] for label_name in train_data_binarized_label.columns] + sampled_data_df = [] + sampled_n = 0 + data_n = train_data.shape[0] + for i, label in enumerate(labels): + label_data = train_data.iloc(train_data.label == int(label)) + if i == len(labels) - 1: + # last label: + to_sample_n = sample_count - sampled_n + else: + to_sample_n = round(label_data.shape[0] / data_n * sample_count) + label_sampled_data = sample_data(df=label_data, n=to_sample_n, random_state=random_state) + if label_sampled_data is not None: + sampled_data_df.append(label_sampled_data) + sampled_n += label_sampled_data.shape[0] + sampled_data = None + if sampled_data_df: + sampled_data = DataFrame.vstack(sampled_data_df) + return sampled_data + + +def get_split_data_size(train_size, validate_size, test_size, data_count): + """ + Validate & transform param inputs into all int + """ + # check & transform data set sizes + if isinstance(test_size, float) or isinstance(train_size, float) or isinstance(validate_size, float): + total_size = 1.0 + else: + total_size = data_count + if train_size is None: + if validate_size is None: + train_size = total_size - test_size + validate_size = total_size - (test_size + train_size) + else: + if test_size is None: + test_size = 0 + train_size = total_size - (validate_size + test_size) + elif test_size is None: + if validate_size is None: + test_size = total_size - train_size + validate_size = total_size - (test_size + train_size) + else: + test_size = total_size - (validate_size + train_size) + elif validate_size is None: + if train_size is None: + train_size = total_size - test_size + validate_size = total_size - (test_size + train_size) + + if abs((abs(train_size) + abs(test_size) + abs(validate_size)) - total_size) > 1e-6: + raise ValueError(f"train_size, test_size, validate_size should sum up to 1.0 or data count") + + if isinstance(train_size, float): + train_size = round(train_size * data_count) + validate_size = round(validate_size * data_count) + test_size = total_size - train_size - validate_size + return train_size, validate_size, test_size diff --git a/python/fate/ml/model_selection/sample.py b/python/fate/ml/model_selection/sample.py new file mode 100644 index 0000000000..0fd7e6bf34 --- /dev/null +++ 
b/python/fate/ml/model_selection/sample.py @@ -0,0 +1,114 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from fate.arch import Context +from fate.arch.dataframe import utils +from ..abc.module import Module + +logger = logging.getLogger(__name__) + + +class SampleModuleGuest(Module): + def __init__( + self, + replace=False, + frac=1.0, + n=None, + random_state=None, + hetero_sync=True + ): + self.replace = replace + self.frac = frac + self.n = n + self.random_state = random_state + self.hetero_sync = hetero_sync + + self._sample_obj = None + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + logger.info(f"enter sample fit") + if self.hetero_sync: + logger.info(f"hetero sync") + logger.info(f"role: {ctx.local.role}") + + sampled_data = utils.federated_sample(ctx, + train_data, + n=self.n, + frac=self.frac, + replace=self.replace, + role=ctx.local.role, + random_state=self.random_state) + else: + logger.info(f"local sample") + # local sample + sampled_data = utils.local_sample(ctx, + train_data, + n=self.n, + frac=self.frac, + replace=self.replace, + random_state=self.random_state) + + return sampled_data + + +class SampleModuleHost(Module): + def __init__( + self, + replace=False, + frac=1.0, + n=None, + random_state=None, + hetero_sync=True + ): + self.replace = replace + self.frac = frac + self.n = n + self.random_state = random_state + self.hetero_sync = hetero_sync + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + logger.info(f"enter sample fit") + if self.hetero_sync: + logger.info(f"hetero sync") + logger.info(f"role: {ctx.local.role}") + + sampled_data = utils.federated_sample(ctx, + train_data, + role=ctx.local.role) + else: + # local sample + logger.info(f"local sample") + sampled_data = utils.local_sample(ctx, + train_data, + n=self.n, + frac=self.frac, + replace=self.replace, + random_state=self.random_state) + """elif self.mode == "weight": + if self.n is not None: + sampled_data = train_data.sample(n=self.n, + replace=self.replace, + weight=train_data.weight, + random_state=self.random_state) + else: + sampled_data = train_data.sample(frac=self.frac, + relace=self.replace, + weight=train_data.weight, + random_state=self.random_state)""" + + return sampled_data diff --git a/python/fate/ml/nn/__init__.py b/python/fate/ml/nn/__init__.py new file mode 100644 index 0000000000..ae946a49c4 --- /dev/null +++ b/python/fate/ml/nn/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/ml/nn/algo/__init__.py b/python/fate/ml/nn/algo/__init__.py new file mode 100644 index 0000000000..ae946a49c4 --- /dev/null +++ b/python/fate/ml/nn/algo/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/ml/nn/algo/homo/__init__.py b/python/fate/ml/nn/algo/homo/__init__.py new file mode 100644 index 0000000000..ae946a49c4 --- /dev/null +++ b/python/fate/ml/nn/algo/homo/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/ml/nn/algo/homo/fedavg.py b/python/fate/ml/nn/algo/homo/fedavg.py new file mode 100644 index 0000000000..0248301d33 --- /dev/null +++ b/python/fate/ml/nn/algo/homo/fedavg.py @@ -0,0 +1,125 @@ +from transformers.training_args import TrainingArguments +from fate.ml.nn.trainer.trainer_base import FedTrainerClient, FedTrainerServer, TrainingArguments +from fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments +from dataclasses import field +from dataclasses import dataclass, field +from dataclasses import dataclass +from typing import List, Optional, Tuple, Callable, Union +from fate.arch import Context +from torch.optim import Optimizer +from torch.utils.data import Dataset +from torch.optim.lr_scheduler import _LRScheduler +from transformers.trainer_callback import TrainerCallback +from torch.nn import Module +from torch import nn +from torch.utils.data import DataLoader +from fate.ml.aggregator import PlainTextAggregatorClient, SecureAggregatorClient +from fate.ml.aggregator import PlainTextAggregatorServer, SecureAggregatorServer +from transformers import TrainerState, TrainerControl, PreTrainedTokenizer +from fate.ml.aggregator import AggregatorType, aggregator_map +import logging + + +logger = logging.getLogger(__name__) + + +@dataclass +class FedAVGArguments(FedArguments): + pass + + +class FedAVGCLient(FedTrainerClient): + def __init__( + self, + ctx: Context, + model: Module, + training_args: TrainingArguments, + fed_args: FedArguments, + train_set: Dataset, + val_set: Dataset = None, + loss_fn: Module = None, + optimizer: Optimizer = None, + scheduler: _LRScheduler = None, + callbacks: List[TrainerCallback] = [], + data_collator: Callable = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + 
use_hf_default_behavior: bool = False, + compute_metrics: Callable = None, + local_mode: bool = False, + ): + super().__init__( + ctx, + model, + training_args, + fed_args, + train_set, + val_set, + loss_fn, + optimizer, + data_collator, + scheduler, + tokenizer, + callbacks, + use_hf_default_behavior, + compute_metrics=compute_metrics, + local_mode=local_mode, + ) + + def init_aggregator(self, ctx: Context, fed_args: FedArguments): + aggregate_type = "weighted_mean" + aggregator_name = "fedavg" + aggregator = fed_args.aggregator + assert aggregator in { + item.value for item in AggregatorType + }, f"aggregator should be one of {{item.value for item in AggregatorType}}, but got {aggregator}" + client_class = aggregator_map[aggregator][0] + logger.info(f"Using {aggregator} aggregator") + sample_num = len(self.train_dataset) + ctx.arbiter.put("agg_type", aggregator) + aggregator = client_class( + ctx, aggregate_type=aggregate_type, aggregator_name=aggregator_name, sample_num=sample_num + ) + + return aggregator + + def on_federation( + self, + ctx: Context, + aggregator: Union[PlainTextAggregatorClient, SecureAggregatorClient], + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + aggregator.model_aggregation(ctx, model) + + +class FedAVGServer(FedTrainerServer): + def __init__(self, ctx: Context, local_mode: bool = False) -> None: + super().__init__(ctx, local_mode) + + def init_aggregator(self, ctx): + aggregator = [ctx.guest.get("agg_type")] + aggregator.extend(ctx.hosts.get("agg_type")) + aggregator = set(aggregator) + if len(aggregator) > 1: + raise ValueError("Aggregator type should be the same between clients, but got {}".format(aggregator)) + aggregator = aggregator.pop() + aggregator_name = "fedavg" + aggregator_server = aggregator_map[aggregator][1] + logger.info(f"Using {aggregator} aggregator") + aggregator = aggregator_server(ctx, aggregator_name=aggregator_name) + return aggregator + + def on_federation(self, ctx: Context, aggregator: Union[SecureAggregatorServer, PlainTextAggregatorServer]): + aggregator.model_aggregation(ctx) + + +class FedAVG(object): + client = FedAVGCLient + server = FedAVGServer diff --git a/python/fate/ml/nn/dataset/__init__.py b/python/fate/ml/nn/dataset/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/ml/nn/dataset/base.py b/python/fate/ml/nn/dataset/base.py new file mode 100644 index 0000000000..c1b5ae8c28 --- /dev/null +++ b/python/fate/ml/nn/dataset/base.py @@ -0,0 +1,39 @@ +from torch.utils.data import Dataset as Dataset_ +import abc +import pandas as pd + + +class Dataset(Dataset_): + def __init__(self, **kwargs): + super(Dataset, self).__init__() + + # Function to implemented + @abc.abstractmethod + def load(self, data_or_path): + raise NotImplementedError( + "You must implement load function so that Client can pass file-path to this " "class" + ) + + def __getitem__(self, item): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + def has_label(self) -> bool: + pass + + def get_classes(self) -> list: + pass + + def get_match_ids(self) -> pd.DataFrame: + pass + + def get_sample_ids(self) -> pd.DataFrame: + pass + + def get_sample_id_name(self) -> str: + pass + + def get_match_id_name(self) -> str: + pass 
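The abstract Dataset above fixes the contract a custom dataset must satisfy before it can be handed to the trainer: implement load, __getitem__ and __len__, plus the optional id/label accessors. A minimal hypothetical CSV-backed subclass might look like the sketch below (it assumes the package path introduced in this patch; the class name and column layout are illustrative, and TableDataset in the next file is the real table implementation):

import pandas as pd
import torch as t
from fate.ml.nn.dataset.base import Dataset


class CsvDataset(Dataset):
    """Toy dataset: last column is the label, the rest are features."""

    def load(self, data_or_path):
        df = pd.read_csv(data_or_path)
        self.labels = t.tensor(df.iloc[:, -1].values, dtype=t.float32)
        self.feats = t.tensor(df.iloc[:, :-1].values, dtype=t.float32)

    def __getitem__(self, item):
        return {"x": self.feats[item], "label": self.labels[item]}

    def __len__(self):
        return len(self.feats)

    def has_label(self) -> bool:
        return True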
diff --git a/python/fate/ml/nn/dataset/table.py b/python/fate/ml/nn/dataset/table.py new file mode 100644 index 0000000000..b01a7972c4 --- /dev/null +++ b/python/fate/ml/nn/dataset/table.py @@ -0,0 +1,225 @@ +import numpy as np +import pandas as pd +from fate.arch.dataframe import DataFrame +from fate.ml.nn.dataset.base import Dataset +import logging +import torch as t + + +logger = logging.getLogger(__name__) + + +class TableDataset(Dataset): + + """ + A Table Dataset, load data from a give csv path, or transform FATE DTable + + Parameters + ---------- + label_col str, name of label column in csv, if None, will automatically take 'y' or 'label' or 'target' as label + match_id_col str, name of match id column in csv, if None, will automatically take 'id' or 'sid' as match id + sample_id_col str, name of sample id column in csv, if None, will automatically generate sample id + feature_dtype str, dtype of features, available: 'long', 'int', 'float', 'double' + label_dtype str, dtype of label, available: 'long', 'int', 'float', 'double' + label_shape tuple or list, shape of label, if None, will automatically infer from data + flatten_label bool, whether to flatten label, if True, will flatten label to 1-d array + to_tensor bool, whether to transform data to pytorch tensor, if True, will transform data to tensor + return_dict bool, whether to return a dict in the format of {'x': xxx, 'label': xxx} if True, will return a dict, else will return a tuple + """ + + def __init__( + self, + label_col=None, + match_id_col=None, + sample_id_col=None, + feature_dtype="float", + label_dtype="float", + label_shape=None, + flatten_label=False, + to_tensor=True, + return_dict=False, + ): + super(TableDataset, self).__init__() + self.features: np.ndarray = None + self.label: np.ndarray = None + self.sample_weights: np.ndarray = None + self.origin_table: pd.DataFrame = pd.DataFrame() + self.label_col = label_col + self.match_id_col = match_id_col + self.sample_id_col = sample_id_col + self.f_dtype = self.check_dtype(feature_dtype) + self.l_dtype = self.check_dtype(label_dtype) + self.to_tensor = to_tensor + self.return_dict = return_dict + if label_shape is not None: + assert isinstance(label_shape, tuple) or isinstance(label_shape, list), "label shape is {}".format( + label_shape + ) + self.label_shape = label_shape + self.flatten_label = flatten_label + + # sample ids, match ids + self.sample_ids = None + self.match_ids = None + + if self.label_col is not None: + assert isinstance(self.label_col, str) or isinstance( + self.label_col, int + ), "label columns parameter must be a str or an int" + + @staticmethod + def check_dtype(dtype): + if dtype is not None: + avail = ["long", "int", "float", "double"] + assert dtype in avail, "available dtype is {}, but got {}".format(avail, dtype) + if dtype == "long": + return np.int64 + if dtype == "int": + return np.int32 + if dtype == "float": + return np.float32 + if dtype == "double": + return np.float64 + return dtype + + def __getitem__(self, item): + if self.label is not None: + feat = self.features[item] + label = self.label[item] + if self.to_tensor: + feat = t.tensor(feat) + label = t.tensor(label) + if self.return_dict: + return {"x": feat, "label": label} + else: + return feat, label + else: + feat = self.features[item] + if self.to_tensor: + feat = t.tensor(feat) + if self.return_dict: + return {"x": feat} + else: + return feat + + def __len__(self): + return len(self.features) + + def load(self, data_or_path): + if isinstance(data_or_path, str): + 
self.origin_table = pd.read_csv(data_or_path) + # if is FATE DTable, collect data and transform to array format + label_col_candidates = ["y", "label", "target"] + # automatically set id columns + if self.match_id_col is not None: + if self.match_id_col not in self.origin_table: + raise ValueError("match id column {} not found".format(self.match_id_col)) + else: + self.match_ids = self.origin_table[[self.match_id_col]] + self.origin_table = self.origin_table.drop(columns=[self.match_id_col]) + else: + match_id_col_cadidaites = ["id", "sid"] + for id_col in match_id_col_cadidaites: + if id_col in self.origin_table: + self.match_ids = self.origin_table[[id_col]] + self.origin_table = self.origin_table.drop(columns=[id_col]) + break + if self.match_ids is None: + logger.info("match id column not found, no match id will be set") + + # generate sample ids + if self.sample_id_col is not None: + if self.sample_id_col not in self.origin_table: + raise ValueError("sample id column {} not found".format(self.sample_id_col)) + self.sample_ids = self.origin_table[[self.sample_id_col]] + self.origin_table = self.origin_table.drop(columns=[self.sample_id_col]) + else: + self.sample_ids = pd.DataFrame() + self.sample_ids["sample_id"] = range(len(self.origin_table)) + logger.info( + "sample id column not found, generate sample id from 0 to {}".format(len(self.origin_table)) + ) + + # infer column name + label = self.label_col + if label is None: + for i in label_col_candidates: + if i in self.origin_table: + label = i + logger.info('use "{}" as label column'.format(label)) + break + if label is None: + logger.info('found no "y"/"label"/"target" in input table, no label will be set') + else: + if label not in self.origin_table: + raise ValueError("label column {} not found in input table".format(label)) + + self.label = self.origin_table[[label]].values + self.origin_table = self.origin_table.drop(columns=[label]) + self.features = self.origin_table.values + + elif isinstance(data_or_path, DataFrame): + schema = data_or_path.schema + sample_id = schema.sample_id_name + match_id = schema.match_id_name + label = schema.label_name + if label is None: + logger.info("label column is None, not provided in the uploaded data") + pd_df = data_or_path.as_pd_df() + if label is None: + labels = None + features = pd_df.drop([sample_id, match_id], axis=1) + else: + labels = pd_df[[label]] + features = pd_df.drop([sample_id, match_id, label], axis=1) + self.label = labels.values + sample_ids = pd_df[[sample_id]] + match_ids = pd_df[[match_id]] + self.sample_ids = sample_ids + self.match_ids = match_ids + self.features = features.values + + if self.label is not None: + if self.l_dtype: + self.label = self.label.astype(self.l_dtype) + + if self.label_shape: + self.label = self.label.reshape(self.label_shape) + else: + self.label = self.label.reshape((len(self.features), -1)) + + if self.flatten_label: + self.label = self.label.flatten() + + else: + self.label = None + + if self.f_dtype: + self.features = self.features.astype(self.f_dtype) + + def get_classes(self): + if self.label is not None: + return np.unique(self.label).tolist() + else: + raise ValueError("no label found, please check if self.label is set") + + def get_sample_ids(self) -> np.ndarray: + return self.sample_ids.values + + def get_match_ids(self) -> np.ndarray: + return self.match_ids.values + + def get_sample_id_name(self) -> str: + if self.sample_ids is not None and isinstance(self.sample_ids, pd.DataFrame): + return self.sample_ids.columns[0] + else: + 
raise ValueError("Cannot get sample id name") + + def get_match_id_name(self) -> str: + if self.match_ids is not None and isinstance(self.match_ids, pd.DataFrame): + return self.match_ids.columns[0] + else: + raise ValueError("Cannot get match id name") + + def has_label(self) -> bool: + return self.label is not None diff --git a/python/fate/ml/nn/model_zoo/__init__.py b/python/fate/ml/nn/model_zoo/__init__.py new file mode 100644 index 0000000000..ae946a49c4 --- /dev/null +++ b/python/fate/ml/nn/model_zoo/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/ml/nn/model_zoo/multi_model.py b/python/fate/ml/nn/model_zoo/multi_model.py new file mode 100644 index 0000000000..c6e68813e5 --- /dev/null +++ b/python/fate/ml/nn/model_zoo/multi_model.py @@ -0,0 +1,14 @@ +from torch import nn + + +class Multi(nn.Module): + def __init__(self, feat=18, class_num=4) -> None: + super().__init__() + self.class_num = class_num + self.model = nn.Sequential(nn.Linear(feat, 10), nn.ReLU(), nn.Linear(10, class_num)) + + def forward(self, x): + if self.training: + return self.model(x) + else: + return nn.Softmax(dim=-1)(self.model(x)) diff --git a/python/fate/ml/nn/trainer/__init__.py b/python/fate/ml/nn/trainer/__init__.py new file mode 100644 index 0000000000..ae946a49c4 --- /dev/null +++ b/python/fate/ml/nn/trainer/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/fate/ml/nn/trainer/test/test_trainer.py b/python/fate/ml/nn/trainer/test/test_trainer.py new file mode 100644 index 0000000000..f068e385d5 --- /dev/null +++ b/python/fate/ml/nn/trainer/test/test_trainer.py @@ -0,0 +1,78 @@ +from fate.ml.nn.algo.homo.fedavg import FedAVGCLient, FedArguments, TrainingArguments, FedAVGServer +import torch as t +import pandas as pd +from fate.ml.nn.dataset.table import TableDataset +import sys + + +arbiter = ("arbiter", 10000) +guest = ("guest", 10000) +host = ("host", 9999) +name = "fed" + + +def create_ctx(local): + from fate.arch import Context + from fate.arch.computing.standalone import CSession + from fate.arch.federation.standalone import StandaloneFederation + import logging + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + computing = CSession() + return Context( + computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]) + ) + + +if __name__ == "__main__": + epoch = 10 + model = t.nn.Sequential(t.nn.Linear(30, 1), t.nn.Sigmoid()) + + ds = TableDataset(return_dict=False, to_tensor=True) + ds.load("../../../../../../examples/data/breast_homo_guest.csv") + + if sys.argv[1] == "guest": + ctx = create_ctx(guest) + fed_args = FedArguments(aggregate_strategy="epochs", aggregate_freq=1, aggregator="secure_aggregate") + args = TrainingArguments( + num_train_epochs=5, per_device_train_batch_size=16, logging_strategy="steps", logging_steps=5 + ) + trainer = FedAVGCLient( + ctx=ctx, + model=model, + fed_args=fed_args, + training_args=args, + loss_fn=t.nn.BCELoss(), + optimizer=t.optim.SGD(model.parameters(), lr=0.01), + train_set=ds, + ) + trainer.train() + + elif sys.argv[1] == "host": + ctx = create_ctx(host) + fed_args = FedArguments(aggregate_strategy="epochs", aggregate_freq=1, aggregator="secure_aggregate") + args = TrainingArguments(num_train_epochs=5, per_device_train_batch_size=16) + trainer = FedAVGCLient( + ctx=ctx, + model=model, + fed_args=fed_args, + training_args=args, + loss_fn=t.nn.BCELoss(), + optimizer=t.optim.SGD(model.parameters(), lr=0.01), + train_set=ds, + ) + trainer.train() + + else: + ctx = create_ctx(arbiter) + trainer = FedAVGServer(ctx) + trainer.train() diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py new file mode 100644 index 0000000000..466451c372 --- /dev/null +++ b/python/fate/ml/nn/trainer/trainer_base.py @@ -0,0 +1,991 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
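The standalone test above picks its role from `sys.argv[1]`, so a complete run needs three processes (guest, host and arbiter) sharing the same federation name. A hypothetical launcher sketch, not part of the repository:

```python
# Hypothetical launcher for the three-party FedAVG demo above (not part of the repo).
import subprocess
import sys

procs = [
    subprocess.Popen([sys.executable, "test_trainer.py", role])
    for role in ("guest", "host", "arbiter")
]
print("exit codes:", [p.wait() for p in procs])
```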
+# + +import os +import re +import torch +import math +import sys +from torch import nn +import numpy as np +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from enum import Enum +from transformers.training_args import TrainingArguments +from fate.arch import Context +from torch.optim import Optimizer +from torch.utils.data import DataLoader, Dataset +from transformers import TrainingArguments as _hf_TrainingArguments, PreTrainedTokenizer +from transformers import Trainer, TrainerState, TrainerControl, EvalPrediction +from transformers.trainer_utils import has_length +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR +from torch.utils.data import _utils +from fate.ml.aggregator.base import Aggregator +import logging +from transformers import logging as transformers_logging +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from typing import Optional +from dataclasses import dataclass, field, fields +from transformers.trainer_callback import PrinterCallback +from fate.ml.aggregator import AggregatorType + + +# Reset the logger to redirect logs output +transformers_logging.disable_default_handler() +transformers_logging.enable_propagation() +logger = logging.getLogger(__name__) + + +def get_ith_checkpoint(directory, i): + # List all files in the directory + files = os.listdir(directory) + + # Filter for checkpoint directories + checkpoint_dirs = [f for f in files if f.startswith("checkpoint-")] + + # Extract the numbers from the checkpoint directory names + checkpoint_numbers = [int(re.search(r"\d+", dir).group()) for dir in checkpoint_dirs] + + # Pair the checkpoint directories with their numbers and sort by the + # numbers + sorted_checkpoints = sorted(zip(checkpoint_dirs, checkpoint_numbers), key=lambda x: x[1]) + + if i < 0: + raise ValueError(f"checkpoint idx i must be greater than or equal to 0, got {i}") + if i > len(sorted_checkpoints) - 1: + raise ValueError(f"checkpoint number is {len(sorted_checkpoints)}, but got {i}") + # Return the name of the ith checkpoint directory + return sorted_checkpoints[i][0] + + +""" +Fed Arguments +""" + + +class AggregateStrategy(Enum): + EPOCH = "epochs" + STEP = "steps" + + +@dataclass +class FedArguments(object): + """ + The argument for Fed algorithm + """ + + aggregate_strategy: AggregateStrategy = field(default=AggregateStrategy.EPOCH.value) + aggregate_freq: int = field(default=1) + aggregator: str = field(default=AggregatorType.SECURE_AGGREGATE.value) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates + the token values by removing their value. 
+ """ + # filter out fields that are defined as field(init=False) + d = dict((field.name, getattr(self, field.name)) for field in fields(self) if field.init) + + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] + if k.endswith("_token"): + d[k] = f"<{k.upper()}>" + return d + + +@dataclass +class TrainingArguments(_hf_TrainingArguments): + # in fate-2.0, we will control the output dir when using pipeline + output_dir: str = field(default="./") + disable_tqdm: bool = field(default=True) + save_strategy: str = field(default="no") + logging_strategy: str = field(default="epoch") + evaluation_strategy: str = field(default="no") + logging_dir: str = field(default=None) + checkpoint_idx: int = field(default=None) + # by default we use constant learning rate, the same as FATE-1.X + lr_scheduler_type: str = field(default="constant") + + def __post_init__(self): + # Always use default values for hub-related attributes + self.push_to_hub = False + self.hub_model_id = None + self.hub_strategy = "every_save" + self.hub_token = None + self.hub_private_repo = False + self.push_to_hub_model_id = None + self.push_to_hub_organization = None + self.push_to_hub_token = None + + super().__post_init__() + + def to_dict(self): + # Call the superclass's to_dict method + all_args = super().to_dict() + + # Get a dict with default values for all fields + default_args = _hf_TrainingArguments(output_dir="./").to_dict() + + # Filter out args that are equal to their default values + set_args = {name: value for name, value in all_args.items() if value != default_args.get(name)} + + return set_args + + +""" +Fed Callback Related Classes +""" + + +class ShortcutCallBackInterFace(object): + def __init__(self) -> None: + pass + + def on_init_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_train_begin( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_train_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_epoch_begin( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_epoch_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: 
FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_step_begin( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def on_step_end( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + +class FedCallbackInterface(object): + def on_federation( + self, + ctx: Context, + aggregator: Aggregator, + fed_args: FedArguments, + args: TrainingArguments, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[_LRScheduler] = None, + dataloader: Optional[Tuple[DataLoader]] = None, + control: Optional[TrainerControl] = None, + state: Optional[TrainerState] = None, + **kwargs, + ): + pass + + def init_aggregator(self, fed_arg: FedArguments): + raise NotImplementedError("init_aggregator() must be implemented in subclass, init aggregator here") + + +# I dont like huggingface logging +class LogSuppressFilter(logging.Filter): + def filter(self, record): + suppress_list = set( + ["\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n"] + ) + if record.getMessage() in suppress_list: + return False + return True + + +def compute_max_aggregation( + fed_args: FedArguments, max_epoch: int, max_steps: int, epochs_trained: int, steps_trained: int +) -> int: + assert ( + max_epoch > epochs_trained and max_epoch > 0 + ), "max_epoch must be greater than epochs_trained: {} and greater than 0".format(epochs_trained) + assert ( + max_steps > steps_trained and max_steps > 0 + ), "max_steps must be greater than steps_trained: {} and greater than 0".format(steps_trained) + + if isinstance(fed_args.aggregate_freq, float) and fed_args.aggregate_freq < 1 and fed_args.aggregate_freq > 0: + if fed_args.aggregate_strategy == AggregateStrategy.EPOCH.value: + aggregate_freq = int(max_epoch / int(1 / fed_args.aggregate_freq)) + elif fed_args.aggregate_strategy == AggregateStrategy.STEP.value: + aggregate_freq = int(max_steps / int(1 / fed_args.aggregate_freq)) + + elif isinstance(fed_args.aggregate_freq, int) and fed_args.aggregate_freq > 0: + aggregate_freq = fed_args.aggregate_freq + else: + raise ValueError("aggregate_freq must be a positive integer or a float between 0 and 1") + + if fed_args.aggregate_strategy == AggregateStrategy.EPOCH.value: + max_aggregation = int((max_epoch - epochs_trained) / aggregate_freq) + elif fed_args.aggregate_strategy == AggregateStrategy.STEP.value: + max_aggregation = int((max_steps - steps_trained) / aggregate_freq) + else: + raise ValueError('aggregate_strategy must be either "epochs" or "steps"') + + return max_aggregation, aggregate_freq + + +class AggregationChecker: + def __init__( + self, + fed_args, + max_aggregation, + aggregate_freq, + max_epoch: int, + max_steps: int, + epochs_trained: int, + steps_trained: int, + ): + self.fed_args = fed_args + self.max_epoch = max_epoch + self.max_steps = max_steps + self.epochs_trained = epochs_trained + self.steps_trained = steps_trained + self.aggregation_count = 0 + self.aggregate_freq = aggregate_freq + self.max_aggregation = max_aggregation + + def report(self): + logger.info(f"Aggregation count: {self.aggregation_count} / {self.max_aggregation}") + + def should_aggregate(self, state: TrainerState) -> bool: + cur_epoch = int(state.epoch) + cur_step = int(state.global_step) + + if self.aggregation_count >= self.max_aggregation: + return False + + if cur_epoch > self.max_epoch: + return False + + strategy = self.fed_args.aggregate_strategy + + if strategy == AggregateStrategy.EPOCH.value: + if cur_epoch > self.epochs_trained and (cur_epoch - self.epochs_trained) % self.aggregate_freq == 0: + return True + elif strategy == AggregateStrategy.STEP.value: + if cur_step > self.steps_trained and (cur_step - self.steps_trained) % self.aggregate_freq == 0: + return True + + return False + + def inc_aggregation_count(self): + self.aggregation_count += 1 + self.report() + + +class FedParameterAlignCallback(TrainerCallback): + def __init__( + self, + trainer_class, + ctx: Context, + training_args: TrainingArguments, + fed_args: FedArguments, + is_server: bool = False, + ) -> None: + super().__init__() + self.trainer_class = trainer_class + self.ctx = ctx + self.is_server = is_server + self.training_args = training_args + self.fed_args = fed_args + self._suffix = "fed_para" + self._send_count = 0 + self._parameters = None + self._aggregation_checker = None + + def get_aggregation_checker(self): + return self._aggregation_checker + + def _client_send_parameters(self, state: TrainerState, args, 
train_dataloader): + # client need to compute: epochs, max_steps, num_step_per_epoch, trained_epoch, trained_steps + # and sync with server + + # compute num_train_epochs, max_steps + len_dataloader = None + + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as + # necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + + # warm start related variables + epochs_trained = state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + max_aggregation, aggregate_freq = compute_max_aggregation( + self.fed_args, num_train_epochs, max_steps, epochs_trained, state.global_step + ) + logger.info("computed max_aggregation is {}".format(max_aggregation)) + + # send parameters + parameters = { + "num_train_epochs": num_train_epochs, + "max_steps": max_steps, + "num_update_steps_per_epoch": num_update_steps_per_epoch, + "epochs_trained": epochs_trained, + "steps_trained_in_current_epoch": steps_trained_in_current_epoch, + "max_aggregation": max_aggregation, + "aggregate_freq": aggregate_freq, + "aggregation_strategy": self.fed_args.aggregate_strategy, + } + + logger.info("parameters is {}".format(parameters)) + + self.ctx.arbiter.put(self._suffix + "_" + str(self._send_count), parameters) + self._send_count += 1 + self._parameters = parameters + self.trainer_class.aggregation_checker = AggregationChecker( + self.fed_args, + max_aggregation, + aggregate_freq, + num_train_epochs, + max_steps, + epochs_trained, + state.global_step, + ) + + def get_parameters(self): + return self._parameters + + def _startegy_type(self, strategy): + # by step or by epoch + by_step = set([AggregateStrategy.STEP.value]) + by_epoch = set([AggregateStrategy.EPOCH.value]) + if strategy in by_step: + return "by_step" + elif strategy in by_epoch: + return "by_epoch" + else: + raise ValueError("strategy {} not supported".format(strategy)) + + def _check_fed_strategy(self, parameters): + # check the fed strategy, assert all clients has the same startegy + all_cilent_strategy = [p["aggregation_strategy"] for p in parameters] + logger.info("all client strategies are {}".format(all_cilent_strategy)) + strategy_flag = self._startegy_type(all_cilent_strategy[0]) + for p in all_cilent_strategy[1:]: + if self._startegy_type(p) != strategy_flag: + raise ValueError( + "fed strategy not match, all clients has to have the same strategy: by epoch(epoch) or by step(step, progress_percentage),\n \ + please check: {}".format( + all_cilent_strategy + ) + ) + + return strategy_flag + + def _check_federation_round(self, parameters): + agg_round = [p["max_aggregation"] for p in parameters] + if len(set(agg_round)) != 1: 
+ raise ValueError( + "federation round not match, all clients has to have the same aggregation round,\n \ + please check: {}".format( + agg_round + ) + ) + return agg_round[0] + + def _server_check_parameters(self): + # check if all clients parameters of aggregation match + para_1 = self.ctx.hosts.get(self._suffix + "_" + str(self._send_count)) + para_2 = self.ctx.guest.get(self._suffix + "_" + str(self._send_count)) + self._send_count += 1 + para_1.append(para_2) + para = para_1 + # strategy = self._check_fed_strategy(para) + agg_round = self._check_federation_round(para) + self._parameters = {"max_aggregation": agg_round} + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.trainer_class.local_mode: + logger.info("FedParameterAlignCallback: local model, skipping federated parameter checking") + return + else: + if self.is_server: + self._server_check_parameters() + else: + train_dataloader = kwargs["train_dataloader"] + self._client_send_parameters(state, args, train_dataloader) + + +class FatePrinterCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_local_process_zero: + _ = logs.pop("total_flos", None) + logger.info(str(logs)) + + +class CallbackWrapper(TrainerCallback): + def __init__(self, ctx: Context, wrapped_trainer: "StdFedTrainerMixin"): + self.ctx = ctx + self.wrapped_trainer = wrapped_trainer + self.fed_arg = self.wrapped_trainer._fed_args + + def _call_wrapped(self, ctx, aggregator, fed_arg, event_name: str, **kwargs): + event = getattr(self.wrapped_trainer, event_name) + kwargs["scheduler"] = kwargs.pop("lr_scheduler", None) + + train_dataloader = kwargs.pop("train_dataloader", None) + eval_dataloader = kwargs.pop("eval_dataloader", None) + dataloaders = tuple(filter(None, (train_dataloader, eval_dataloader))) + kwargs["dataloader"] = dataloaders + return event(ctx, aggregator, fed_arg, **kwargs) + + +class WrappedFedCallback(CallbackWrapper): + def __init__(self, ctx: Context, wrapped_trainer: "StdFedTrainerMixin"): + super().__init__(ctx, wrapped_trainer) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # initialize aggregator + # doesnot call wrapper here, make sure aggregator is not called before + # it is initialized + if self.wrapped_trainer.local_mode: + logger.info("local mode, skip federation aggregator initialization, aggregator will be None") + else: + self.wrapped_trainer.aggregator = self.wrapped_trainer.init_aggregator(self.ctx, self.fed_arg) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.wrapped_trainer.local_mode: + return + if self.fed_arg.aggregate_strategy == AggregateStrategy.EPOCH.value: + if self.wrapped_trainer.aggregation_checker.should_aggregate(state): + logger.info("aggregation on epoch end") + agg_round = self.wrapped_trainer.aggregation_checker.aggregation_count + sub_ctx = self.ctx.sub_ctx("aggregation").indexed_ctx(agg_round) + ret = self._call_wrapped( + sub_ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_federation", + args=args, + state=state, + control=control, + **kwargs, + ) + self.wrapped_trainer.aggregation_checker.inc_aggregation_count() + return ret + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.wrapped_trainer.local_mode: + return + if self.fed_arg.aggregate_strategy == 
AggregateStrategy.STEP.value: + if self.wrapped_trainer.aggregation_checker.should_aggregate(state): + logger.info("state is {}".format(state)) + logger.info("aggregation on step end") + agg_round = self.wrapped_trainer.aggregation_checker.aggregation_count + sub_ctx = self.ctx.sub_ctx("aggregation").indexed_ctx(agg_round) + ret = self._call_wrapped( + sub_ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_federation", + args=args, + state=state, + control=control, + **kwargs, + ) + self.wrapped_trainer.aggregation_checker.inc_aggregation_count() + return ret + + +class WrappedShortcutCallback(CallbackWrapper): + def __init__(self, ctx: Context, wrapped_trainer: "StdFedTrainerMixin"): + super().__init__(ctx, wrapped_trainer) + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_init_end", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_train_begin", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_train_end", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_epoch_begin", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_epoch_end", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_step_begin", + args=args, + state=state, + control=control, + **kwargs, + ) + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + return self._call_wrapped( + self.ctx, + self.wrapped_trainer.aggregator, + self.fed_arg, + "on_step_end", + args=args, + state=state, + control=control, + **kwargs, + ) + + +logger.addFilter(LogSuppressFilter()) + + +""" +Mixin Class For Federation Trainer +""" + + +class StdFedTrainerMixin(FedCallbackInterface, ShortcutCallBackInterFace): + def __init__( + self, + ctx: Context, + model: nn.Module, + training_args: TrainingArguments, + fed_args: FedArguments, + train_set: Dataset, + val_set: Dataset = None, + loss_fn: nn.Module = None, + optimizer: torch.optim.Optimizer = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + callbacks: Optional[List[TrainerCallback]] = [], + use_hf_default_behavior: bool = False, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + local_mode: bool = False, + ): + assert isinstance(callbacks, list), "callback must be a list containing Callback objects, but 
got {}".format( + callbacks + ) + + self.ctx: Context = ctx + self.local_mode = local_mode + self._callbacks = callbacks + self._args = training_args + self._fed_args = fed_args + self._user_compute_metric_func = compute_metrics + self.train_dataset = train_set + self.eval_dataset = val_set + self.loss_func = loss_fn + self._use_hf_default_behavior = use_hf_default_behavior + self._aggregator = None + + # for callback class to check if aggregation is needed + self.aggregation_checker: AggregationChecker = None + + def _compute_metrics_warp_func(self, *args, **kwargs): + if self._user_compute_metric_func is None: + return {} + else: + eval_result = self._user_compute_metric_func(*args, **kwargs) + # Do some FATEBoard Callback here + return eval_result + + def _handle_callback(self, callback_handler, new_callbacks): + # remove default logger.infoer callback, need to use our logging + # strategy + new_callback_list = [] + for i in callback_handler.callbacks: + # if not isinstance(i, logger.infoerCallback): + new_callback_list.append(i) + new_callback_list += new_callbacks + callback_handler.callbacks = new_callback_list + + def _add_fate_callback(self, callback_handler): + # the callback handler is Trainer.callback_handler + # call order: + # fed callback aggregator init(once), parameter check(once), + # on federation of fedcallback + # callbacks of shortcutcallback + new_callback_list = [] + for i in callback_handler.callbacks: + if isinstance(i, PrinterCallback): + continue + else: + new_callback_list.append(i) + new_callback_list.append(FatePrinterCallback()) + callback_handler.callbacks = new_callback_list + callback_handler.callbacks.append(WrappedFedCallback(self.ctx, self)) + callback_handler.callbacks.append( + FedParameterAlignCallback( + self, self.ctx, fed_args=self._fed_args, training_args=self._args, is_server=False + ) + ) + + callback_handler.callbacks.append(WrappedShortcutCallback(self.ctx, self)) + + def _remove_fed_callback(self, callback_class): + self.callback_handler.callbacks = [ + c for c in self.callback_handler.callbacks if not isinstance(c, callback_class) + ] + + def set_local_mode(self): + self.local_mode = True + logger.info("trainer set to local mode") + + def set_fed_mode(self): + self.local_mode = False + logger.info("trainer set to federated mode") + + @property + def aggregator(self): + return self._aggregator + + @aggregator.setter + def aggregator(self, value): + self._aggregator = value + + +""" +Base Classes of Client/Sever Trainer +""" + + +class FedTrainerClient(Trainer, StdFedTrainerMixin): + + """ + FedTrainerClient is designed to handle diverse federated training tasks. + + By extending the transformers.Trainer class, this class allows customization of the federated training, + evaluation, and prediction processes to meet the needs of specific federateion training tasks. Users can + override relevant methods to implement custom functionality. 
+ """ + + def __init__( + self, + ctx: Context, + model: nn.Module, + training_args: TrainingArguments, + fed_args: FedArguments, + train_set: Dataset, + val_set: Dataset = None, + loss_fn: nn.Module = None, + optimizer: torch.optim.Optimizer = None, + data_collator: Callable = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + callbacks: Optional[List[TrainerCallback]] = [], + use_hf_default_behavior: bool = False, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + local_mode: bool = False, + ): + # in case you forget to set evaluation_strategy + if val_set is not None and training_args.evaluation_strategy == "no": + training_args.evaluation_strategy = "epoch" + + StdFedTrainerMixin.__init__( + self, + ctx=ctx, + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + training_args=training_args, + fed_args=fed_args, + train_set=train_set, + val_set=val_set, + scheduler=scheduler, + callbacks=callbacks, + use_hf_default_behavior=use_hf_default_behavior, + compute_metrics=compute_metrics, + local_mode=local_mode, + ) + + if data_collator is None: + data_collator = _utils.collate.default_collate + + # concat checkpoint path if checkpoint idx is set + if self._args.checkpoint_idx is not None: + checkpoint_path = self._args.resume_from_checkpoint + if checkpoint_path is not None and os.path.exists(checkpoint_path): + checkpoint_folder = get_ith_checkpoint(checkpoint_path, self._args.checkpoint_idx) + self._args.resume_from_checkpoint = os.path.join(checkpoint_path, checkpoint_folder) + + Trainer.__init__( + self, + model=model, + args=self._args, + train_dataset=train_set, + eval_dataset=val_set, + data_collator=data_collator, + optimizers=[optimizer, scheduler], + tokenizer=tokenizer, + compute_metrics=self._compute_metrics_warp_func, + ) + + self._add_fate_callback(self.callback_handler) + + def init_aggregator(self, ctx: Context, fed_arg: FedArguments): + return None + + def compute_loss(self, model, inputs, **kwargs): + if self._use_hf_default_behavior: + return super().compute_loss(model, inputs, **kwargs) + else: + # (features, labels), this format is used in FATE-1.x + if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + feats, labels = inputs + output = model(feats) + loss = self.loss_func(output, labels) + return loss + else: + return super().compute_loss(model, inputs, **kwargs) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if self._use_hf_default_behavior: + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + else: + # (features, labels), this format is used in FATE-1.x + # now the model is in eval status + if isinstance(inputs, tuple) or isinstance(inputs, list) and len(inputs) == 2: + with torch.no_grad(): + feats, labels = inputs + logits = model(feats) + return (None, logits, labels) + else: + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + + +class FedTrainerServer(object): + def __init__(self, ctx: Context, local_mode: bool = False) -> None: + self.ctx = ctx + self.local_mode = local_mode + self._max_steps = None + self._parameter_check_callback = FedParameterAlignCallback(self, self.ctx, None, None, is_server=True) + self._max_aggregation = None + + def set_fed_context(self, ctx: Context): + assert isinstance(ctx, Context), "ctx must be a Context object, 
but got {}".format(ctx) + self.ctx = ctx + + def set_local_mode(self): + self.local_mode = True + logger.info("trainer set to local mode") + + def set_fed_mode(self): + self.local_mode = False + logger.info("trainer set to federated mode") + + def init_aggregator(self, ctx: Context): + return None + + def on_train_end(self, ctx: Context, aggregator: Aggregator): + pass + + def on_train_begin(self, ctx: Context, aggregator: Aggregator): + pass + + def on_init_end(self, ctx: Context, aggregator: Aggregator): + pass + + def on_federation(self, ctx: Context, aggregator: Aggregator): + pass + + def train(self): + if self.local_mode: + logger.info("Local model is set, skip initializing fed setting & aggregator") + return + + self.aggregator: Aggregator = self.init_aggregator(self.ctx) + logger.info("Initialized aggregator Done: {}".format(self.aggregator)) + self._parameter_check_callback.on_train_begin(None, None, None) # only get parameters from clients and align + parameters = self._parameter_check_callback.get_parameters() + self._max_aggregation = parameters["max_aggregation"] + logger.info("checked parameters are {}".format(parameters)) + + self.on_init_end(self.ctx, aggregator=self.aggregator) + self.on_train_begin(self.ctx, aggregator=self.aggregator) + + ctx = self.ctx + for i in range(self._max_aggregation): + sub_ctx = ctx.sub_ctx("aggregation").indexed_ctx(i) + self.on_federation(sub_ctx, aggregator=self.aggregator) + + self.on_train_end(self.ctx, aggregator=self.aggregator) + + def predict(self): + pass diff --git a/python/fate/ml/feature_scale/__init__.py b/python/fate/ml/preprocessing/__init__.py similarity index 96% rename from python/fate/ml/feature_scale/__init__.py rename to python/fate/ml/preprocessing/__init__.py index 415b1f10aa..77cdb94f60 100644 --- a/python/fate/ml/feature_scale/__init__.py +++ b/python/fate/ml/preprocessing/__init__.py @@ -12,4 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from .feature_scale import FeatureScale +from .union import Union diff --git a/python/fate/ml/preprocessing/feature_scale.py b/python/fate/ml/preprocessing/feature_scale.py new file mode 100644 index 0000000000..391b510617 --- /dev/null +++ b/python/fate/ml/preprocessing/feature_scale.py @@ -0,0 +1,165 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +import pandas as pd + +from fate.arch import Context +from ..abc.module import Module + +logger = logging.getLogger(__name__) + + +class FeatureScale(Module): + def __init__(self, method="standard", scale_col=None, feature_range=None, strict_range=True): + self.method = method + self._scaler = None + if self.method == "standard": + self._scaler = StandardScaler(scale_col) + elif self.method == "min_max": + self._scaler = MinMaxScaler(scale_col, feature_range, strict_range) + else: + raise ValueError(f"Unknown scale method {self.method} given. Please check.") + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + self._scaler.fit(ctx, train_data) + + def transform(self, ctx: Context, test_data): + return self._scaler.transform(ctx, test_data) + + def get_model(self): + scaler_info = self._scaler.to_model() + model_data = dict(scaler_info=scaler_info) + return {"data": model_data, "meta": {"method": self.method, + "model_type": "feature_scale"}} + + def restore(self, model): + self._scaler.from_model(model) + + @classmethod + def from_model(cls, model) -> "FeatureScale": + scaler = FeatureScale(model["meta"]["method"]) + scaler.restore(model["data"]["scaler_info"]) + return scaler + + +class StandardScaler(Module): + def __init__(self, select_col): + self._mean = None + self._std = None + self.select_col = select_col + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + if self.select_col is None: + self.select_col = train_data.schema.columns.to_list() + train_data_select = train_data[self.select_col] + self._mean = train_data_select.mean() + self._std = train_data_select.std() + + def transform(self, ctx: Context, test_data): + test_data_select = test_data[self.select_col] + test_data[self.select_col] = (test_data_select - self._mean) / self._std + return test_data + + def to_model(self): + return dict( + mean=self._mean.to_dict(), + mean_dtype=self._mean.dtype.name, + std=self._std.to_dict(), + std_dtype=self._std.dtype.name, + select_col=self.select_col, + ) + + def from_model(self, model): + self._mean = pd.Series(model["mean"], dtype=model["mean_dtype"]) + self._std = pd.Series(model["std"], dtype=model["std_dtype"]) + self.select_col = model["select_col"] + + +class MinMaxScaler(Module): + """ + Transform data by scaling features to given feature range. 
+ Note that if `strict_range` is set, transformed values will always be within given range, + regardless whether transform data exceeds training data value range (as in 1.x ver) + + The transformation is given by:: + X_scaled = (X - X.min()) / (X.max() - X.min()) * feature_range + feature_range_min + = (X - X.min()) * (feature_range / (X.max() - X.min()) + feature_range_min + """ + + def __init__(self, select_col, feature_range, strict_range): + self.feature_range = feature_range + self.select_col = select_col + self.strict_range = strict_range + self._scale = None + self._scale_min = None + self._range_min = None + self._range_max = None + + def fit(self, ctx: Context, train_data, validate_data=None) -> None: + if self.select_col is None: + self.select_col = train_data.schema.columns.to_list() + train_data_select = train_data[self.select_col] + data_max = train_data_select.max() + data_min = train_data_select.min() + + # select_col has same keys as feature_range + self._range_min = pd.Series({col: self.feature_range[col][0] for col in self.select_col}) + self._range_max = pd.Series({col: self.feature_range[col][1] for col in self.select_col}) + + data_range = data_max - data_min + # for safe division + data_range[data_range < 1e-6] = 1.0 + self._scale = (self._range_max - self._range_min) / data_range + self._scale_min = data_min * self._scale + + def transform(self, ctx: Context, test_data): + """ + Transformation is given by: + X_scaled = (X * scale - scale_min) + feature_range_min + where scale = feature_range / (X_train.max() - X_train.min()) and scale_min = X_train.min() * scale + + """ + test_data_select = test_data[self.select_col] + + data_scaled = test_data_select * self._scale - (self._scale_min + self._range_min) + if self.strict_range: + # restrict feature output within given feature value range + data_scaled = data_scaled[data_scaled >= self._range_min].fillna(self._range_min) + data_scaled = data_scaled[data_scaled <= self._range_max].fillna(self._range_max) + test_data[self.select_col] = data_scaled + return test_data + + def to_model(self): + return dict( + scale=self._scale.to_dict(), + scale_dtype=self._scale.dtype.name, + scale_min=self._scale_min.to_dict(), + scale_min_dtype=self._scale_min.dtype.name, + range_min=self._range_min.to_dict(), + range_min_dtype=self._range_min.dtype.name, + range_max=self._range_max.to_dict(), + range_max_dtype=self._range_max.dtype.name, + strict_range=self.strict_range, + select_col=self.select_col, + ) + + def from_model(self, model): + self._scale = pd.Series(model["scale"], dtype=model["scale_dtype"]) + self._scale_min = pd.Series(model["scale_min"], dtype=model["scale_min_dtype"]) + self._range_min = pd.Series(model["range_min"], dtype=model["range_min_dtype"]) + self._range_max = pd.Series(model["range_max"], dtype=model["range_max_dtype"]) + self.strict_range = model["strict_range"] + self.select_col = model["select_col"] diff --git a/python/fate/ml/preprocessing/union.py b/python/fate/ml/preprocessing/union.py new file mode 100644 index 0000000000..8ea1075b6f --- /dev/null +++ b/python/fate/ml/preprocessing/union.py @@ -0,0 +1,48 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from ..abc.module import Module + +logger = logging.getLogger(__name__) + + +class Union(Module): + def __init__(self, axis=0): + self.axis = axis + + def fit(self, ctx: Context, train_data_list): + sample_id_name_list = [data.schema.sample_id_name for data in train_data_list] + if sum([name != sample_id_name_list[0] for name in sample_id_name_list]): + raise ValueError(f"Data sets should all have the same sample_id_name for union.") + + match_id_name_list = [data.schema.match_id_name for data in train_data_list] + if sum([name != match_id_name_list[0] for name in match_id_name_list]): + raise ValueError(f"Data sets should all have the same match_id_name for union.") + + if self.axis == 0: + label_name_list = [data.schema.label_name for data in train_data_list] + if sum([name != label_name_list[0] for name in label_name_list]): + raise ValueError(f"Data sets should all have the same label_name for union.") + + column_name_list = [set(data.schema.columns) for data in train_data_list] + if sum([col_names != column_name_list[0] for col_names in column_name_list]): + raise ValueError(f"Data sets should all have the same columns for union on 0 axis.") + result_data = DataFrame.vstack(train_data_list) + return result_data + else: + raise ValueError(f"axis must be 0, but got {self.axis}") diff --git a/python/fate/ml/statistics/__init__.py b/python/fate/ml/statistics/__init__.py new file mode 100644 index 0000000000..4c38e88545 --- /dev/null +++ b/python/fate/ml/statistics/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .statistics import FeatureStatistics diff --git a/python/fate/ml/statistics/statistics.py b/python/fate/ml/statistics/statistics.py new file mode 100644 index 0000000000..232d8d6ba2 --- /dev/null +++ b/python/fate/ml/statistics/statistics.py @@ -0,0 +1,137 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
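At fit time, the `MinMaxScaler` defined in `feature_scale.py` above stores a per-column `scale` and `scale_min`, then applies an affine map at transform time; `feature_range` is supplied per column as `{col: (min, max)}`. A numeric sketch of the same arithmetic with plain pandas, using an assumed single column and the common `(0, 1)` range (it sidesteps the FATE Context/DataFrame plumbing and is for illustration only):

```python
# Numeric sketch of the min-max scaling arithmetic (plain pandas, illustrative only).
import pandas as pd

train = pd.DataFrame({"x": [2.0, 4.0, 6.0]})
range_min, range_max = 0.0, 1.0         # assumed feature range for column "x"

data_min = train.min()
data_range = train.max() - data_min
data_range[data_range < 1e-6] = 1.0     # guard against zero range, as in fit()

scale = (range_max - range_min) / data_range   # 0.25 for "x"
scale_min = data_min * scale                   # 0.50 for "x"

test = pd.DataFrame({"x": [3.0, 8.0]})
scaled = test * scale - scale_min + range_min
print(scaled)   # 0.25 and 1.5; with strict_range=True the 1.5 would be clipped to 1.0
```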
+ +import logging +import re +from typing import List + +import pandas as pd + +from fate.arch import Context +from ..abc.module import Module + +logger = logging.getLogger(__name__) + + +class FeatureStatistics(Module): + def __init__(self, metrics: List[str] = None, ddof=1, bias=True, relative_error=1e-3): + self.metrics = metrics + self.summary = StatisticsSummary(ddof, bias, relative_error) + + def fit(self, ctx: Context, input_data, validate_data=None) -> None: + self.summary.compute_metrics(input_data, self.metrics) + + def get_model(self): + model = self.summary.to_model() + output_model = {"data": model, + "meta": {"metrics": self.metrics, + "model_type": "statistics"}} + return output_model + + def restore(self, model): + self.summary.restore(model) + + def from_model(cls, model) -> "FeatureStatistics": + stat = FeatureStatistics(model["meta"]["metrics"]) + stat.restore(model["data"]) + return stat + + +class StatisticsSummary(Module): + def __init__(self, ddof=1, bias=True, relative_error=1e-3): + """if metrics is not None: + if len(metrics) == 1 and metrics[0] == "describe": + self.inner_metric_names = ['count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'] + else: + self.inner_metric_names = metrics""" + self.ddof = ddof + self.bias = bias + self.relative_error = relative_error + self.inner_metric_names = [] + self.metrics_summary = None + self._count = None + self._nan_count = None + self._mean = None + self._describe = None + self._quantile = None + self._q_pts = None + + def get_from_describe(self, data, metric): + if self._describe is None: + self._describe = data.describe(ddof=self.ddof, unbiased=~self.bias) + return self._describe[metric] + + def get_from_quantile_summary(self, data, metric): + query_q = int(metric[:-1]) / 100 + if self._quantile is None: + self._quantile = data.quantile(q=self._q_pts, relative_error=self.relative_error) + return self._quantile.loc[query_q] + + def compute_metrics(self, data, metrics): + res = pd.DataFrame(columns=data.schema.columns) + q_metrics = [metric for metric in metrics if re.match(r"^(100|\d{1,2})%$", metric)] + self._q_pts = [int(metric[:-1]) / 100 for metric in q_metrics] + for metric in metrics: + metric_val = None + """if metric == "describe": + res = data.describe(ddof=self.ddof, unbiased=~self.bias) + self.metrics_summary = res + self.inner_metric_names = list(res.index) + return""" + if metric in ["sum", "min", "max", "mean", "std", "var"]: + metric_val = self.get_from_describe(data, metric) + if metric in q_metrics: + metric_val = self.get_from_quantile_summary(data, metric) + elif metric == "count": + if self._count is None: + self._count = data.count() + metric_val = self._count + elif metric == "median": + metric_val = data.quantile(q=0.5, relative_error=self.relative_error) + metric_val = metric_val.loc[0.5] + elif metric == "coefficient_of_variation": + metric_val = self.get_from_describe(data, "variation") + elif metric == "missing_count": + if self._nan_count is None: + self._nan_count = self.get_from_describe(data, "na_count") + metric_val = self._nan_count + elif metric == "missing_ratio": + if self._nan_count is None: + self._nan_count = self.get_from_describe(data, "na_count") + if self._count is None: + self._count = data.count() + metric_val = self._nan_count / self._count + elif metric == "skewness": + metric_val = self.get_from_describe(data, "skew") + elif metric == "kurtosis": + metric_val = self.get_from_describe(data, "kurt") + + res.loc[metric] = metric_val + + has_nan = res.isnull().any() + if 
has_nan.any(): + nan_cols = res.columns[has_nan].to_list() + logger.warning( + f"NaN value(s) found in statistics over columns: {nan_cols}; " f"this may lead to unexpected behavior." + ) + self.metrics_summary = res + self.inner_metric_names = list(res.index) + + def to_model(self): + return {"inner_metric_names": self.inner_metric_names, "metrics_summary": self.metrics_summary.to_dict()} + + def restore(self, model): + self.inner_metric_names = model["inner_metric_names"] + self.metrics_summary = pd.DataFrame.from_dict(model["metrics_summary"]) diff --git a/python/fate/ml/utils/_convergence.py b/python/fate/ml/utils/_convergence.py new file mode 100644 index 0000000000..05498ecf52 --- /dev/null +++ b/python/fate/ml/utils/_convergence.py @@ -0,0 +1,108 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class _ConvergeFunction: + def __init__(self, eps): + self.eps = eps + + def is_converge(self, loss): pass + + +class _DiffConverge(_ConvergeFunction): + """ + Judge convergence by the difference between two iterations. + If the difference is smaller than eps, converge flag will be provided. + """ + + def __init__(self, eps): + super().__init__(eps=eps) + self.pre_loss = None + + def is_converge(self, loss): + logger.debug( + "In diff converge function, pre_loss: {}, current_loss: {}".format( + self.pre_loss, loss)) + + converge_flag = False + if self.pre_loss is None: + pass + elif abs(self.pre_loss - loss) < self.eps: + converge_flag = True + self.pre_loss = loss + return converge_flag + + +class _AbsConverge(_ConvergeFunction): + """ + Judge converge by absolute loss value. When loss value smaller than eps, converge flag + will be provided. + """ + + def is_converge(self, loss): + if loss <= self.eps: + converge_flag = True + else: + converge_flag = False + return converge_flag + + +class _WeightDiffConverge(_ConvergeFunction): + """ + Use 2-norm of gradient to judge whether converge or not. 
+ """ + + def __init__(self, eps): + super().__init__(eps=eps) + self.pre_weight = None + + def is_converge(self, delta_weight, weight=None): + weight_diff = torch.linalg.norm(delta_weight, 2) + if weight is None: + # avoid tensor[bool] + if weight_diff < self.eps: + return True + return False + if self.pre_weight is None: + self.pre_weight = weight + return False + if weight_diff < self.eps * max([torch.linalg.norm(weight, 2), 1]): + return True + return False + + +def converge_func_factory(early_stop, tol): + # try: + # converge_func = param.converge_func + # eps = param.eps + # except AttributeError: + # raise AttributeError("Converge Function parameters has not been totally set") + + if early_stop == 'diff': + return _DiffConverge(tol) + elif early_stop == 'weight_diff': + return _WeightDiffConverge(tol) + elif early_stop == 'abs': + return _AbsConverge(tol) + else: + raise NotImplementedError( + "Converge Function method cannot be recognized: {}".format(early_stop)) diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py new file mode 100644 index 0000000000..4aa7f42084 --- /dev/null +++ b/python/fate/ml/utils/_model_param.py @@ -0,0 +1,69 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
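# ----------------------------------------------------------------------------
# Editor's sketch (not part of the diff): driving the convergence helpers from
# python/fate/ml/utils/_convergence.py above in a plain training loop. The loss
# values are made up for illustration.
# ----------------------------------------------------------------------------
from fate.ml.utils._convergence import converge_func_factory

checker = converge_func_factory("diff", 1e-4)  # "abs" and "weight_diff" are also accepted
for it, loss in enumerate([0.693, 0.452, 0.45195, 0.45194]):
    if checker.is_converge(loss):
        print(f"converged at iteration {it}")  # fires once successive losses differ by < 1e-4
        break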
+ +import torch + + +def initialize_param(coef_len, **kwargs): + param_len = coef_len + method = kwargs["method"] + fit_intercept = kwargs["fit_intercept"] + random_state = kwargs.get("random_state", None) + if fit_intercept: + param_len = param_len + 1 + if method == 'zeros': + return torch.zeros((param_len, 1), requires_grad=True) + elif method == 'ones': + return torch.ones((param_len, 1), requires_grad=True) + elif method == 'consts': + return torch.full( + (param_len, 1), float( + kwargs["fill_val"]), requires_grad=True) + elif method == 'random': + if random_state is not None: + generator = torch.Generator().manual_seed(random_state) + return torch.randn((param_len, 1), generator=generator, requires_grad=True) + return torch.randn((param_len, 1), requires_grad=True) + elif method == 'random_uniform': + if random_state is not None: + generator = torch.Generator().manual_seed(random_state) + return torch.rand((param_len, 1), generator=generator, requires_grad=True) + return torch.rand((param_len, 1), requires_grad=True) + else: + raise NotImplementedError(f"Unknown initialization method: {method}") + + +def serialize_param(param, fit_intercept=False): + dtype = str(param.dtype).split(".", -1)[-1] + w = param.tolist() + intercept = None + if fit_intercept: + intercept = w[-1] + w = w[:-1] + return {"coef_": w, "intercept_": intercept, "dtype": dtype} + + +def deserialize_param(param, fit_intercept=False): + w = param["coef_"] + if fit_intercept: + w.append(param["intercept_"]) + dtype = param["dtype"] + w = torch.tensor(w, dtype=getattr(torch, dtype)) + return w + + +def check_overflow(param, threshold=1e8): + if (torch.abs(param) > threshold).any(): + raise ValueError(f"Value(s) greater than {threshold} found in model param, please check.") diff --git a/python/fate/ml/utils/_optimizer.py b/python/fate/ml/utils/_optimizer.py new file mode 100644 index 0000000000..5229b5f14e --- /dev/null +++ b/python/fate/ml/utils/_optimizer.py @@ -0,0 +1,349 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
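# ----------------------------------------------------------------------------
# Editor's sketch (not part of the diff): round-tripping model parameters with
# the helpers from python/fate/ml/utils/_model_param.py above. Shapes and
# values are illustrative.
# ----------------------------------------------------------------------------
import torch
from fate.ml.utils._model_param import (
    initialize_param, serialize_param, deserialize_param, check_overflow,
)

w = initialize_param(4, method="random", fit_intercept=True, random_state=42)  # (5, 1): 4 coefs + intercept
state = serialize_param(w, fit_intercept=True)  # {"coef_": ..., "intercept_": ..., "dtype": "float32"}
w_restored = deserialize_param(state, fit_intercept=True)
assert torch.allclose(w.detach(), w_restored)
check_overflow(w_restored)  # raises ValueError if any |value| exceeds 1e8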
+ +import logging + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +class LRScheduler: + def __init__(self, method=None, lr_params=None, iters=0): + self.method = method + self.lr_params = lr_params + self.iters = iters + self.lr_scheduler = None + + def init_scheduler(self, optimizer): + self.lr_scheduler = lr_scheduler_factory(optimizer, self.method, self.lr_params) + + def step(self): + self.lr_scheduler.step() + self.iters += 1 + + @property + def lr(self): + return self.lr_scheduler.get_last_lr()[0] + + def state_dict(self): + return {"lr_scheduler": self.lr_scheduler.state_dict(), "method": self.method, "lr_params": self.lr_params} + + def load_state_dict(self, dict, optimizer): + self.method = dict["method"] + self.lr_params = dict["lr_params"] + self.init_scheduler(optimizer) + self.lr_scheduler.load_state_dict(dict["lr_scheduler"]) + + def get_last_lr(self): + return self.get_last_lr() + + +class Optimizer(object): + def __init__( + self, + method=None, + penalty=None, + alpha=None, + optim_param: dict = None, + iters: int = 0): + self.method = method + self.optim_param = optim_param + self.iters = iters + self.l2_penalty = True if penalty == "l2" else False + self.l1_penalty = True if penalty == "l1" else False + self.alpha = alpha + + self.model_parameter = None + self.prev_model_parameter = None + self.optimizer = None + + def init_optimizer( + self, + model_parameter_length=None, + model_parameter=None, + dtype=torch.float32): + # allow group of parameter in future + if model_parameter_length is not None: + model_parameter = torch.nn.parameter.Parameter(torch.zeros( + (model_parameter_length, 1), requires_grad=True, dtype=dtype)) + self.model_parameter = model_parameter + self.optimizer = optimizer_factory( + [model_parameter], self.method, self.optim_param) + # for regularization + # self.alpha = self.optimizer.state_dict()['param_groups'][0]['alpha'] + + def step(self, gradient): + # logger.info(f"before copy, model parameter: {self.model_parameter}") + self.prev_model_parameter = self.model_parameter.data.clone() + self.model_parameter.grad = gradient + self.optimizer.step() + # logger.info(f"after step, model parameter: {self.model_parameter}") + + def get_delta_gradients(self): + # logger.info(f"gradient: {self.model_parameter.grad}, prev model parameter: {self.prev_model_parameter}," + # f"delta grad: {self.prev_model_parameter.data - self.model_parameter.data}") + if self.prev_model_parameter is not None: + return self.prev_model_parameter.data - self.model_parameter.data + else: + raise ValueError(f"No optimization history found, please check.") + + def shrinkage_val(self, lr): + this_step_size = lr / np.sqrt(self.iters) + return self.alpha * this_step_size + + def state_dict(self): + optimizer_state_dict = self.optimizer.state_dict() + state_all = optimizer_state_dict["state"].get(0, {}) + for k, v in state_all.items(): + if isinstance(v, torch.Tensor): + state_all[k] = v.tolist() + dtype = str(self.model_parameter.dtype).split(".", -1)[-1] + return { + "l2_penalty": self.l2_penalty, + "l1_penalty": self.l1_penalty, + "alpha": self.alpha, + "optimizer": optimizer_state_dict, + "method": self.method, + "optim_param": self.optim_param, + "model_parameter": self.model_parameter.tolist(), + "model_parameter_dtype": dtype, + } + + def load_state_dict(self, state_dict): + self.l2_penalty = state_dict["l2_penalty"] + self.l1_penalty = state_dict["l1_penalty"] + self.alpha = state_dict["alpha"] + self.method = state_dict["method"] + 
self.optim_param = state_dict["optim_param"] + dtype = state_dict["model_parameter_dtype"] + self.init_optimizer( + model_parameter=torch.nn.parameter.Parameter( + torch.tensor(state_dict["model_parameter"], dtype=getattr(torch, dtype)) + ) + ) + state = state_dict["optimizer"] + state_all = state["state"].get(0, {}) + for k, v in state_all.items(): + if isinstance(v, list): + state_all[k] = torch.tensor(v) + self.optimizer.load_state_dict(state_dict["optimizer"]) + + def set_iters(self, new_iters): + self.iters = new_iters + + def _l1_updator(self, model_weights, gradient, fit_intercept, lr): + if fit_intercept: + gradient_without_intercept = gradient[:-1] + coef_ = model_weights[:-1] + else: + gradient_without_intercept = gradient + coef_ = model_weights + + new_weights = torch.sign(coef_ - gradient_without_intercept) * torch.maximum( + torch.tensor([0]), torch.abs(coef_ - gradient_without_intercept) - self.shrinkage_val(lr) + ) + + if fit_intercept: + new_intercept = model_weights[-1] - gradient[-1] + new_weights = torch.concat((new_weights, new_intercept.reshape((1, 1)))) + + return new_weights + + def add_regular_to_grad(self, grad, model_weights, fit_intercept=False): + if self.l2_penalty: + if fit_intercept: + weights_sum = torch.concat((model_weights[:-1], torch.tensor([[0]]))) + # logger.info(f"grad: {grad}, weights sum: {weights_sum}") + new_grad = grad + self.alpha * weights_sum + else: + new_grad = grad + self.alpha * model_weights + else: + new_grad = grad + + return new_grad + + def regularization_update( + self, + model_weights, + grad, + fit_intercept, + lr, + prev_round_weights=None): + if self.l1_penalty: + model_weights = self._l1_updator( + model_weights, grad, fit_intercept, lr) + else: + model_weights = model_weights - grad + """elif self.l2_penalty: + model_weights = self._l2_updator(model_weights, grad) + """ + """if prev_round_weights is not None: # additional proximal term for homo + coef_ = model_weights.unboxed + + if model_weights.fit_intercept: + coef_without_intercept = coef_[: -1] + else: + coef_without_intercept = coef_ + + coef_without_intercept -= self.mu * (model_weights.coef_ - prev_round_weights.coef_) + + if model_weights.fit_intercept: + new_coef_ = np.append(coef_without_intercept, coef_[-1]) + else: + new_coef_ = coef_without_intercept + + model_weights = LinearModelWeights(new_coef_, + model_weights.fit_intercept, + model_weights.raise_overflow_error)""" + return model_weights + + def __l1_loss_norm(self, model_weights): + loss_norm = torch.sum(self.alpha * model_weights) + return loss_norm.reshape((1, 1)) + + def __l2_loss_norm(self, model_weights): + loss_norm = 0.5 * self.alpha * \ + torch.matmul(model_weights.T, model_weights) + return loss_norm + + """def __add_proximal(self, model_weights, prev_round_weights): + prev_round_coef_ = prev_round_weights.coef_ + coef_ = model_weights.coef_ + diff = coef_ - prev_round_coef_ + loss_norm = self.mu * 0.5 * np.dot(diff, diff) + return loss_norm + """ + + def loss_norm(self, model_weights, prev_round_weights=None): + """ + proximal_term = None + if prev_round_weights is not None: + proximal_term = self.__add_proximal(model_weights, prev_round_weights) + """ + + if self.l1_penalty: + loss_norm_value = self.__l1_loss_norm(model_weights) + elif self.l2_penalty: + loss_norm_value = self.__l2_loss_norm(model_weights) + else: + loss_norm_value = None + + """# additional proximal term + if loss_norm_value is None: + loss_norm_value = proximal_term + elif proximal_term is not None: + loss_norm_value += 
proximal_term""" + return loss_norm_value + + """def hess_vector_norm(self, delta_s: LinearModelWeights): + if self.penalty == consts.L1_PENALTY: + return LinearModelWeights(np.zeros_like(delta_s.unboxed), + fit_intercept=delta_s.fit_intercept, + raise_overflow_error=delta_s.raise_overflow_error) + elif self.penalty == consts.L2_PENALTY: + return LinearModelWeights(self.alpha * np.array(delta_s.unboxed), + fit_intercept=delta_s.fit_intercept, + raise_overflow_error=delta_s.raise_overflow_error) + else: + return LinearModelWeights(np.zeros_like(delta_s.unboxed), + fit_intercept=delta_s.fit_intercept, + raise_overflow_error=delta_s.raise_overflow_error) + """ + + def update_weights( + self, + model_weights, + grad, + fit_intercept, + lr, + prev_round_weights=None, + has_applied=True): + """if not has_applied: + grad = self.add_regular_to_grad(grad, model_weights) + delta_grad = self.apply_gradients(grad) + else:""" + # logger.info( + # f"before update, model weights: {model_weights}, delta_grad: {grad}") + delta_grad = grad + model_weights = self.regularization_update( + model_weights, delta_grad, fit_intercept, lr, prev_round_weights) + # (f"after update, model weights: {model_weights}") + + return model_weights + + +def separate(value, size_list): + """ + Separate value in order to several set according size_list + Parameters + ---------- + value: 2d-tensor, input data + size_list: list, each set size + Returns + ---------- + list + separated 2d-tensors of sizes given in size_list + """ + separate_res = [] + cur = 0 + for size in size_list: + separate_res.append(value[cur : cur + size, :]) + cur += size + return separate_res + + +def optimizer_factory(model_parameter, optimizer_type, optim_params): + optimizer_params = optim_params + + if optimizer_type == "adadelta": + return torch.optim.Adadelta(model_parameter, **optimizer_params) + elif optimizer_type == "adagrad": + return torch.optim.Adagrad(model_parameter, **optimizer_params) + elif optimizer_type == "adam": + return torch.optim.Adam(model_parameter, **optimizer_params) + elif optimizer_type == "adamw": + return torch.optim.AdamW(model_parameter, **optimizer_params) + elif optimizer_type == "adamax": + return torch.optim.Adamax(model_parameter, **optimizer_params) + elif optimizer_type == "asgd": + return torch.optim.ASGD(model_parameter, **optimizer_params) + elif optimizer_type == "nadam": + return torch.optim.NAdam(model_parameter, **optimizer_params) + elif optimizer_type == "radam": + return torch.optim.RAdam(model_parameter, **optimizer_params) + elif optimizer_type == "rmsprop": + return torch.optim.RMSprop(model_parameter, **optimizer_params) + elif optimizer_type == "rprop": + return torch.optim.Rprop(model_parameter, **optimizer_params) + elif optimizer_type == "sgd": + return torch.optim.SGD(model_parameter, **optimizer_params) + else: + raise NotImplementedError( + "Optimize method cannot be recognized: {}".format(optimizer_type)) + + +def lr_scheduler_factory(optimizer, method, scheduler_param): + scheduler_method = method + if scheduler_method == "constant": + return torch.optim.lr_scheduler.ConstantLR(optimizer, **scheduler_param) + elif scheduler_method == "step": + return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_param) + elif scheduler_method == "linear": + return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_param) + else: + raise NotImplementedError( + f"Learning rate method cannot be recognized: {scheduler_method}") diff --git a/python/fate/ml/utils/callbacks.py 
b/python/fate/ml/utils/callbacks.py new file mode 100644 index 0000000000..b659a73f83 --- /dev/null +++ b/python/fate/ml/utils/callbacks.py @@ -0,0 +1,75 @@ +from fate.ml.evaluation.tool import get_metric_names, get_specified_metrics +from fate.ml.abc.module import Module + + +class CallbackParam(object): + + def __init__(self, + callback_types: list, + metrics: list, + evaluation_freq: int = None, + early_stopping_rounds: int = None, + checkpoint_freq: int = None, + use_first_metric: bool = False) -> None: + + if not isinstance(callback_types, list) or len(callback_types) == 0: + raise ValueError("callback_types must be a list with at least one type.") + + if not isinstance(metrics, list) or len(metrics) == 0: + raise ValueError("metrics must be a list with at least one metric.") + + for param, param_name in [(evaluation_freq, "evaluation_freq"), + (early_stopping_rounds, "early_stopping_rounds"), + (checkpoint_freq, "checkpoint_freq")]: + if param is not None and (not isinstance(param, int) or param <= 0): + raise ValueError(f"{param_name} must be a positive integer or None.") + + if not isinstance(use_first_metric, bool): + raise ValueError("use_first_metric must be a boolean.") + + self.callback_types = callback_types + self.metrics = metrics + self.evaluation_freq = evaluation_freq + self.early_stopping_rounds = early_stopping_rounds + self.checkpoint_freq = checkpoint_freq + self.use_first_metric = use_first_metric + + def __str__(self) -> str: + return (f'Callback types: {self.callback_types}, ' + f'Metrics: {self.metrics}, ' + f'Evaluation frequency: {self.evaluation_freq}, ' + f'Early stopping rounds: {self.early_stopping_rounds}, ' + f'Use first metric for early stopping: {self.use_first_metric}, ' + f'Checkpoint frequency: {self.checkpoint_freq}') + + + + +class Callbacks(object): + + def __init__(self, model: Module, callback_params) -> None: + pass + + def on_train_begin(self, ctx): + pass + + def on_train_end(self, ctx): + pass + + def on_epoch_begin(self, ctx, epoch): + pass + + def on_epoch_end(self, ctx, epoch): + pass + + def on_batch_begin(self, ctx, batch_index): + pass + + def on_batch_end(self, ctx, batch_index): + pass + + def need_stop(self, ctx): + pass + + def get_best_model(self): + pass \ No newline at end of file diff --git a/python/fate/ml/utils/label_alignment.py b/python/fate/ml/utils/label_alignment.py new file mode 100644 index 0000000000..d9414b8291 --- /dev/null +++ b/python/fate/ml/utils/label_alignment.py @@ -0,0 +1,7 @@ +from fate.arch import Context + + +class LabelAilignment(object): + + def __init__(self) -> None: + pass \ No newline at end of file diff --git a/python/fate/ml/utils/model_io.py b/python/fate/ml/utils/model_io.py new file mode 100644 index 0000000000..6d26291cb8 --- /dev/null +++ b/python/fate/ml/utils/model_io.py @@ -0,0 +1,28 @@ +from typing import Optional + + +class ModelIO: + _META = "meta" + _DATA = "data" + + def __init__(self, data: dict, meta: Optional[dict] = None): + self.data = data + self.meta = meta + + def dict(self): + return { + self._DATA: self.data, + self._META: self.meta if self.meta is not None else {}, + } + + @classmethod + def from_dict(cls, d: dict): + data = d[cls._DATA] + if cls._META in d: + meta = d[cls._META] + else: + meta = None + return cls(data, meta) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data={self.data}, meta={self.meta})" diff --git a/python/fate/ml/utils/model_serdes.py b/python/fate/ml/utils/model_serdes.py index 46a5b8413e..a7f75b14a3 100644 --- 
a/python/fate/ml/utils/model_serdes.py +++ b/python/fate/ml/utils/model_serdes.py @@ -23,7 +23,8 @@ def serialize_models(models): for model_name, buffer_object in models.items(): serialized_string = buffer_object.SerializeToString() pb_name = type(buffer_object).__name__ - json_format_dict = json_format.MessageToDict(buffer_object, including_default_value_fields=True) + json_format_dict = json_format.MessageToDict( + buffer_object, including_default_value_fields=True) serialized_models[model_name] = ( pb_name, diff --git a/python/fate/ml/utils/predict_tools.py b/python/fate/ml/utils/predict_tools.py new file mode 100644 index 0000000000..f73cf48ae6 --- /dev/null +++ b/python/fate/ml/utils/predict_tools.py @@ -0,0 +1,144 @@ +import json +from typing import Literal + +import numpy as np +import pandas as pd + +from fate.arch.dataframe import DataFrame +from fate.arch.dataframe import PandasReader + +# DATA SET COLUMNS +TRAIN_SET = 'train_set' +VALIDATE_SET = 'validate_set' +TEST_SET = 'test_set' + +# PREDICT RESULT COLUMNS +PREDICT_RESULT = "predict_result" +PREDICT_SCORE = "predict_score" +PREDICT_DETAIL = "predict_detail" +LABEL = "label" + +# TASK TYPE +BINARY = 'binary' +MULTI = 'multi' +REGRESSION = 'regression' +OTHER = 'other' + + +def predict_detail_dict_to_str(result_dict): + return "\"" + json.dumps(result_dict).replace("\"", "\'") + "\"" + + +def add_ids(df: pd.DataFrame, match_id: pd.DataFrame, sample_id: pd.DataFrame): + df = pd.concat([df, match_id, sample_id], axis=1) + return df + + +def to_dist_df(ctx, sample_id_name, match_id_name, result_df: pd.DataFrame): + + if LABEL in result_df: + reader = PandasReader( + sample_id_name=sample_id_name, + match_id_name=match_id_name, + label_name=LABEL, + dtype="object") + else: + reader = PandasReader( + sample_id_name=sample_id_name, + match_id_name=match_id_name, + dtype="object") + data = reader.to_frame(ctx, result_df) + return data + + +def compute_predict_details(dataframe: DataFrame, task_type: Literal['binary', 'multi', 'regression'], classes: list = None, threshold=0.5): + + assert task_type in [BINARY, MULTI, REGRESSION, + OTHER], 'task_type must be one of {} as a std task, but got {}'.format( + [BINARY, MULTI, REGRESSION, OTHER], task_type) + + assert threshold >= 0 and threshold <= 1, 'threshold must be float in [0, 1], but got {}'.format(threshold) + + if not isinstance(dataframe, DataFrame): + raise ValueError('dataframe must be a fate DataFrame, but got {}'.format(type(dataframe))) + if dataframe.schema.label_name is not None and dataframe.schema.label_name != LABEL: + dataframe.rename(label_name=LABEL) + assert PREDICT_SCORE in dataframe.schema.columns, 'column {} is not found in input dataframe'.format(PREDICT_SCORE) + + if task_type == BINARY and task_type == MULTI: + if classes is None or (not isinstance(classes, list) and len(classes) < 2): + raise ValueError('task_type is binary or multi, but classes is None, or classes length is less than 2') + + if task_type == BINARY: + if len(classes) == 2: + neg_class, pos_class = classes[0], classes[1] + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row( \ + lambda v: [int(v[PREDICT_SCORE] > threshold), + predict_detail_dict_to_str({neg_class: 1 - float(v[PREDICT_SCORE]), + pos_class: float(v[PREDICT_SCORE])})], + enable_type_align_checking=False) + else: + raise ValueError( + 'task_type is binary, but classes length is not 2: {}'.format(classes)) + + elif task_type == REGRESSION: + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row( \ + lambda 
v: [v[PREDICT_SCORE], predict_detail_dict_to_str({PREDICT_SCORE: float(v[PREDICT_SCORE])})], + enable_type_align_checking=False) + + elif task_type == MULTI: + + def handle_multi(v: pd.Series): + predict_result = np.argmax(v[PREDICT_SCORE]) + assert len(v[PREDICT_SCORE]) == len(classes), 'predict score length is not equal to classes length,\ + predict score is {}, but classes are {}, please check the data you provided'.format(v[PREDICT_SCORE], classes) + predict_details = {classes[j]: float(v[PREDICT_SCORE][j]) for j in range(len(classes))} + return [predict_result, predict_detail_dict_to_str(predict_details)] + + dataframe[[PREDICT_RESULT, PREDICT_DETAIL]] = dataframe.apply_row(handle_multi, enable_type_align_checking=False) + predict_score = dataframe[PREDICT_SCORE].apply_row(lambda v: max(v[PREDICT_SCORE])) + dataframe[PREDICT_SCORE] = predict_score + + return dataframe + + +def array_to_predict_df( + ctx, + task_type: Literal['binary', 'multi', 'regression'], + pred: np.ndarray, + match_ids: np.ndarray, + sample_ids: np.ndarray, + match_id_name: str, + sample_id_name: str, + label: np.array = None, + threshold=0.5, + classes: list = None): + + df = pd.DataFrame() + if len(pred.shape) == 1: + df[PREDICT_SCORE] = np.array(pred) + elif len(pred.shape) == 2: + if pred.shape[1] == 1: + df[PREDICT_SCORE] = np.array(pred).flatten() + else: + df[PREDICT_SCORE] = np.array(pred).tolist() + else: + raise ValueError( + 'This is not a FATE std task, pred scores shape are {}'.format( + pred.shape)) + + if label is not None: + if len(label.shape) == 1: + label = label.flatten() + elif len(label.shape) == 2 and label.shape[1] == 1: + label = label.flatten() + else: + label = label.tolist() + df[LABEL] = label + + df[sample_id_name] = sample_ids.flatten() + df[match_id_name] = match_ids.flatten() + fate_df = to_dist_df(ctx, sample_id_name, match_id_name, df) + predict_result = compute_predict_details(fate_df, task_type, classes, threshold) + + return predict_result diff --git a/python/fate/ml/utils/test/test_predict_format.py b/python/fate/ml/utils/test/test_predict_format.py new file mode 100644 index 0000000000..8281388ad2 --- /dev/null +++ b/python/fate/ml/utils/test/test_predict_format.py @@ -0,0 +1,49 @@ +from fate.arch import Context +from fate.arch.computing.standalone import CSession +from fate.arch.context import Context +from fate.arch.federation.standalone import StandaloneFederation +import pandas as pd +from fate.ml.utils.predict_tools import compute_predict_details, PREDICT_SCORE, LABEL, BINARY, REGRESSION, MULTI +from fate.arch.dataframe import PandasReader +import numpy as np + + +computing = CSession() +ctx = Context("guest", computing=computing, federation=StandaloneFederation( + computing, "fed", ("guest", 10000), [("host", 9999)]), ) + +df = pd.DataFrame() +df['id'] = [i for i in range(50)] +df['sample_id'] = [i for i in range(len(df))] +df[PREDICT_SCORE] = [np.random.random(1)[0] for i in range(len(df))] +df[LABEL] = [np.random.randint(0, 2) for i in range(len(df))] + +no_label_df = df.drop([LABEL], axis=1) + + +df_reg = pd.DataFrame() +df_reg['id'] = [i for i in range(50)] +df_reg['sample_id'] = [i for i in range(len(df_reg))] +df_reg[PREDICT_SCORE] = [np.random.random(1)[0] * 10 for i in range(len(df_reg))] +df_reg[LABEL] = [np.random.random(1)[0] * 10 for i in range(len(df_reg))] + +df_multi = pd.DataFrame() +df_multi['id'] = [i for i in range(50)] +df_multi['sample_id'] = [i for i in range(len(df_multi))] +df_multi[PREDICT_SCORE] = [[float(np.random.random(1)[0]) for i in 
range(3)] for i in range(len(df_multi))] +df_multi[LABEL] = [np.random.randint(0, 3) for i in range(len(df_multi))] + +reader = PandasReader( + sample_id_name='sample_id', + match_id_name="id", + dtype="object") +data = reader.to_frame(ctx, df) +data_2 = reader.to_frame(ctx, no_label_df) +data_3 = reader.to_frame(ctx, df_reg) +data_4 = reader.to_frame(ctx, df_multi) + + +rs = compute_predict_details(data, BINARY, classes=[0, 1], threshold=0.8) +rs_2 = compute_predict_details(data_2, BINARY, classes=[0, 1], threshold=0.3) +rs_3 = compute_predict_details(data_3, REGRESSION) +rs_4 = compute_predict_details(data_4, MULTI, classes=[0, 1, 2]) \ No newline at end of file diff --git a/python/fate/test/benchmarks/test_paillier.py b/python/fate/test/benchmarks/test_paillier.py new file mode 100644 index 0000000000..8d3b52049a --- /dev/null +++ b/python/fate/test/benchmarks/test_paillier.py @@ -0,0 +1,11 @@ +import torch +from fate.arch.protocol.phe.paillier import evaluator, keygen + +sk, pk, coder = keygen(2048) +data = torch.rand(1000) +a = pk.encrypt_encoded(coder.encode_tensor(data), True) +b = pk.encrypt_encoded(coder.encode_tensor(data), True) + + +def test_iadd(benchmark): + benchmark(lambda: evaluator.i_add(pk, a, b)) diff --git a/python/fate/test/histogram/indexer.py b/python/fate/test/histogram/indexer.py new file mode 100644 index 0000000000..41df4a8753 --- /dev/null +++ b/python/fate/test/histogram/indexer.py @@ -0,0 +1,102 @@ +import pytest + +# from fate.arch.histogram.indexer import HistogramIndexer, Shuffler +from fate_utils.histogram import HistogramIndexer, Shuffler + + +class TestHistogramIndexer: + @pytest.fixture + def histogram_indexer(self): + node_size = 2 + feature_bin_sizes = [3, 2] + return HistogramIndexer(node_size, feature_bin_sizes) + + def test_get_position(self, histogram_indexer): + assert histogram_indexer.get_position(0, 0, 0) == 0 + assert histogram_indexer.get_position(0, 0, 1) == 1 + assert histogram_indexer.get_position(0, 1, 0) == 3 + assert histogram_indexer.get_position(1, 0, 0) == 5 + assert histogram_indexer.get_position(1, 1, 1) == 9 + + def test_get_reverse_position(self, histogram_indexer): + assert histogram_indexer.get_reverse_position(0) == (0, 0, 0) + assert histogram_indexer.get_reverse_position(1) == (0, 0, 1) + assert histogram_indexer.get_reverse_position(3) == (0, 1, 0) + assert histogram_indexer.get_reverse_position(5) == (1, 0, 0) + assert histogram_indexer.get_reverse_position(9) == (1, 1, 1) + + def test_get_node_intervals(self, histogram_indexer): + assert histogram_indexer.get_node_intervals() == [(0, 5), (5, 10)] + + def test_get_feature_position_ranges(self, histogram_indexer): + assert histogram_indexer.get_feature_position_ranges() == [(0, 3), (3, 5), (5, 8), (8, 10)] + + def test_total_data_size(self, histogram_indexer): + assert histogram_indexer.total_data_size() == 10 + + def test_splits_into_k(self, histogram_indexer): + # Test if the method splits the data correctly into k parts + k = 2 + splits = list(histogram_indexer.splits_into_k(k)) + + assert len(splits) == k + + # Check if the intervals are disjoint and cover the entire range + all_intervals = [interval for _, _, intervals in splits for interval in intervals] + all_intervals.sort(key=lambda x: x[0]) + + assert all_intervals[0][0] == 0 + assert all_intervals[-1][1] == histogram_indexer.total_data_size() + for i in range(len(all_intervals) - 1): + assert all_intervals[i][1] == all_intervals[i + 1][0] + + def test_one_node_data_size(self, histogram_indexer): + # Test if the 
one node data size is correctly calculated + assert histogram_indexer.one_node_data_size() == sum(histogram_indexer.feature_bin_sizes) + + def test_global_flatten_bin_sizes(self, histogram_indexer): + # Test if the global flatten bin sizes is correctly calculated + assert ( + histogram_indexer.global_flatten_bin_sizes() + == histogram_indexer.feature_bin_sizes * histogram_indexer.node_size + ) + + def test_flatten_in_node(self, histogram_indexer): + # Test if the flatten in node method returns a new HistogramIndexer with correct parameters + new_indexer = histogram_indexer.flatten_in_node() + + assert new_indexer.node_size == histogram_indexer.node_size + assert new_indexer.feature_bin_sizes == [histogram_indexer.one_node_data_size()] + + def test_squeeze_bins(self, histogram_indexer): + # Test if the squeeze bins method returns a new HistogramIndexer with correct parameters + new_indexer = histogram_indexer.squeeze_bins() + + assert new_indexer.node_size == histogram_indexer.node_size + assert new_indexer.feature_bin_sizes == [1] * histogram_indexer.feature_size + + def test_reshape(self, histogram_indexer): + # Test if the reshape method returns a new HistogramIndexer with correct parameters + new_feature_bin_sizes = [2, 2, 1] + new_indexer = histogram_indexer.reshape(new_feature_bin_sizes) + + assert new_indexer.node_size == histogram_indexer.node_size + assert new_indexer.feature_bin_sizes == new_feature_bin_sizes + + # def test_get_shuffler(self, histogram_indexer): + # # Test if the get shuffler method returns a Shuffler with correct parameters + # seed = 123 + # shuffler = histogram_indexer.get_shuffler(seed) + # + # assert isinstance(shuffler, Shuffler) + # assert shuffler.num_node == histogram_indexer.node_size + # assert shuffler.node_size == histogram_indexer.one_node_data_size() + + def test_unflatten_indexes(self, histogram_indexer): + # Test if the unflatten indexes method returns a correct nested dictionary of indexes + indexes = histogram_indexer.unflatten_indexes() + + for nid in range(histogram_indexer.node_size): + for fid in range(histogram_indexer.feature_size): + for bid in range(histogram_indexer.feature_bin_sizes[fid]): + assert indexes[nid][fid][bid] == histogram_indexer.get_position(nid, fid, bid) diff --git a/python/fate/test/test_dtensor.py b/python/fate/test/test_dtensor.py new file mode 100644 index 0000000000..410aed83cb --- /dev/null +++ b/python/fate/test/test_dtensor.py @@ -0,0 +1,209 @@ +import pytest +import torch +from fate.arch import Context +from fate.arch.computing.standalone import CSession +from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.tensor import DTensor +from pytest import fixture + + +@fixture +def ctx(): + computing = CSession() + return Context( + "guest", + computing=computing, + federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), + ) + + +@fixture +def t1_i32_sharding(): + return [ + torch.tensor([[1, 2, 3], [4, 5, 6]]), + torch.tensor([[1, 2, 3], [4, 5, 6]]), + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ] + + +@fixture +def t1_f32_sharding(): + return [ + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + ] + + +@fixture +def t2_f32_sharding(): + return [ + torch.tensor([[3, 2, 1], [6, 5, 4]], dtype=torch.float32), + torch.tensor([[3, 2, 1], [6, 5, 4]], dtype=torch.float32), + torch.tensor([[3, 2, 1], [6, 5, 4]], dtype=torch.float32), + ] + + 
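# ----------------------------------------------------------------------------
# Editor's note (sketch, not part of the diff): the invariant these fixtures
# set up. A DTensor built from a sharding list behaves like the row-wise
# concatenation of its shards, so ops can be checked shard by shard, e.g.
#
#     dt1 = DTensor.from_sharding_list(ctx, t1_f32_sharding, num_partitions=3)
#     dt2 = DTensor.from_sharding_list(ctx, t2_f32_sharding, num_partitions=3)
#     assert torch.add(dt1, dt2) == DTensor.from_sharding_list(
#         ctx, [torch.add(a, b) for a, b in zip(t1_f32_sharding, t2_f32_sharding)],
#         num_partitions=3,
#     )
#
# which is what test_binary below asserts for several elementwise ops.
# ----------------------------------------------------------------------------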
+@fixture +def t1_i32(ctx, t1_i32_sharding): + return DTensor.from_sharding_list( + ctx, + t1_i32_sharding, + num_partitions=3, + ) + + +@fixture +def t1_f32(ctx, t1_f32_sharding): + return DTensor.from_sharding_list( + ctx, + t1_f32_sharding, + num_partitions=3, + ) + + +@fixture +def t2_f32(ctx, t2_f32_sharding): + return DTensor.from_sharding_list( + ctx, + t2_f32_sharding, + num_partitions=3, + ) + + +@pytest.mark.parametrize( + "op", + [torch.exp, torch.log, torch.square], +) +def test_unary(ctx, t1_f32, t1_f32_sharding, op): + assert op(t1_f32) == DTensor.from_sharding_list(ctx, [op(s) for s in t1_f32_sharding], num_partitions=3) + + +def test_cipher(ctx, t1_f32): + kit = ctx.cipher.phe.setup({}) + encryptor, decryptor = kit.get_tensor_encryptor(), kit.get_tensor_decryptor() + encrypted = encryptor.encrypt_tensor(t1_f32) + print(torch.to_local_f(decryptor.decrypt_tensor(encrypted))) + + +@pytest.mark.parametrize( + "op", + [torch.add, torch.sub, torch.mul, torch.div, torch.rsub], +) +def test_binary(ctx, t1_f32, t2_f32, t1_f32_sharding, t2_f32_sharding, op): + assert op(t1_f32, t2_f32) == DTensor.from_sharding_list( + ctx, [op(s1, s2) for s1, s2 in zip(t1_f32_sharding, t2_f32_sharding)], num_partitions=3 + ) + + +@pytest.mark.parametrize( + "op", + [torch.sum, torch.mean], +) +def test_sum_mean(ctx, t1_f32, t2_f32, t1_f32_sharding, t2_f32_sharding, op): + assert op(t1_f32) == op(torch.cat(t1_f32_sharding)) + assert torch.allclose(op(t1_f32, dim=0), op(torch.cat(t1_f32_sharding), dim=0)) + assert op(t1_f32, dim=1) == DTensor.from_sharding_list( + ctx, [op(s, dim=1) for s in t1_f32_sharding], num_partitions=3 + ) + assert op(t1_f32, dim=1, keepdim=True) == DTensor.from_sharding_list( + ctx, [op(s, dim=1, keepdim=True) for s in t1_f32_sharding], num_partitions=3 + ) + + +@pytest.mark.parametrize( + "op", + [torch.var, torch.std], +) +def test_var_std(ctx, t1_f32, t2_f32, t1_f32_sharding, t2_f32_sharding, op): + assert torch.isclose(op(t1_f32), op(torch.cat(t1_f32_sharding))) + assert torch.allclose(op(t1_f32, dim=0), op(torch.cat(t1_f32_sharding), dim=0)) + assert torch.allclose(op(t1_f32, dim=0, unbiased=False), op(torch.cat(t1_f32_sharding), dim=0, unbiased=False)) + assert op(t1_f32, dim=1) == DTensor.from_sharding_list( + ctx, [op(s, dim=1) for s in t1_f32_sharding], num_partitions=3 + ) + assert op(t1_f32, dim=1, keepdim=True) == DTensor.from_sharding_list( + ctx, [op(s, dim=1, keepdim=True) for s in t1_f32_sharding], num_partitions=3 + ) + + +@pytest.mark.parametrize( + "op", + [torch.max, torch.min], +) +def test_max_min(ctx, t1_f32, t2_f32, t1_f32_sharding, t2_f32_sharding, op): + assert torch.isclose(op(t1_f32), op(torch.cat(t1_f32_sharding))) + + def _eq(r1, r2): + assert r1.indices.shape == r2.indices.shape + assert r1.values.shape == r2.values.shape + assert torch.allclose(r1.indices, r2.indices) + assert torch.allclose(r1.values, r2.values) + + _eq(op(t1_f32, dim=0), op(torch.cat(t1_f32_sharding), dim=0)) + _eq(op(t1_f32, dim=0, keepdim=True), op(torch.cat(t1_f32_sharding), dim=0, keepdim=True)) + + assert op(t1_f32, dim=1).values == DTensor.from_sharding_list( + ctx, [op(s, dim=1).values for s in t1_f32_sharding], num_partitions=3 + ) + + assert op(t1_f32, dim=1, keepdim=True).values == DTensor.from_sharding_list( + ctx, [op(s, dim=1, keepdim=True).values for s in t1_f32_sharding], num_partitions=3 + ) + + +@pytest.mark.parametrize( + "op", + [torch.add, torch.sub, torch.mul, torch.div, torch.rsub], +) +def test_binary_bc_dtensor(ctx, op): + t1 = [torch.rand((2, 4, 5)) for 
_ in range(3)] + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3) + + t2 = [torch.rand((2, 1, 5)) for _ in range(3)] + dt2 = DTensor.from_sharding_list(ctx, t2, num_partitions=3) + + assert op(dt1, dt2) == DTensor.from_sharding_list(ctx, [op(s1, s2) for s1, s2 in zip(t1, t2)], num_partitions=3) + + t1 = [torch.rand((2, 4, 5)) for _ in range(3)] + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3, axis=1) + + t2 = [torch.rand((4, 5)) for _ in range(3)] + dt2 = DTensor.from_sharding_list(ctx, t2, num_partitions=3, axis=0) + + assert op(dt1, dt2) == DTensor.from_sharding_list(ctx, [op(s1, s2) for s1, s2 in zip(t1, t2)], num_partitions=3) + + +@pytest.mark.parametrize( + "op", + [torch.add, torch.sub, torch.mul, torch.div, torch.rsub], +) +def test_binary_bc_tensor(ctx, op): + t1 = [torch.rand((2, 3, 4, 5)) for _ in range(3)] + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3) + + t2 = torch.rand((4, 5)) + assert op(dt1, t2) == DTensor.from_sharding_list(ctx, [op(s, t2) for s in t1], num_partitions=3) + + t2 = torch.rand((1, 1, 4, 5)) + assert op(dt1, t2) == DTensor.from_sharding_list(ctx, [op(s, t2) for s in t1], num_partitions=3) + + t1 = [torch.rand((2, 3, 4, 5)) for _ in range(3)] + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3, axis=1) + + t2 = torch.rand((4, 5)) + assert op(dt1, t2) == DTensor.from_sharding_list(ctx, [op(s, t2) for s in t1], num_partitions=3) + + +def test_slice(ctx): + t1 = [torch.rand((2, 3, 4, 5)) for _ in range(3)] + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3) + assert torch.allclose(torch.slice_f(dt1, 3), t1[1][1]) + + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3, axis=1) + assert torch.slice_f(dt1, 1) == DTensor.from_sharding_list(ctx, [s[1] for s in t1], num_partitions=3) + + dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3) + # assert torch.allclose(torch.slice_f(dt1, [3,1,2]), torch.cat(t1)[[3,1,2]]) + print(torch.slice_f(dt1, [3, 1, 2]).shape) + print(torch.cat(t1)[[3, 1, 2]].shape) diff --git a/python/fate/test/test_matmul.py b/python/fate/test/test_matmul.py new file mode 100644 index 0000000000..319086c8fb --- /dev/null +++ b/python/fate/test/test_matmul.py @@ -0,0 +1,163 @@ +import torch +from fate.arch.computing.standalone import CSession +from fate.arch.context import Context +from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.tensor import DTensor +from pytest import fixture + + +@fixture +def ctx(): + computing = CSession() + return Context( + "guest", + computing=computing, + federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]), + ) + + +@fixture +def t3(): + return torch.tensor([[1.0], [1.0], [1.0]]) + + +@fixture +def t2(ctx): + return DTensor.from_sharding_list( + ctx, + [ + torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]), + torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]), + torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]), + ], + axis=1, + ) + + +@fixture +def t4(): + return torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + ) + + +def test_local(): + # (2 x 3) @ (3 x 2) -> (2 x 2) + a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + b = torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]) + assert torch.allclose(torch.matmul(a, b), torch.rmatmul_f(b, a)) + + +def test_shape_1_1(ctx): + a = [torch.rand(9), torch.rand(10), torch.rand(11)] + b = [torch.rand(9), torch.rand(10), torch.rand(11)] + da = make_dist(ctx, a) + db = 
make_dist(ctx, b) + la = torch.concat(a) + lb = torch.concat(b) + # distributed + assert torch.allclose(torch.matmul(da, db), torch.matmul(la, lb)) + assert torch.allclose(torch.rmatmul_f(da, db), torch.rmatmul_f(la, lb)) + # local + assert torch.allclose(torch.matmul(da, lb), torch.matmul(la, lb)) + assert torch.allclose(torch.rmatmul_f(da, lb), torch.rmatmul_f(la, lb)) + + +def test_shape_1_x(ctx): + a = [torch.rand(9), torch.rand(10), torch.rand(11)] + b = [torch.rand(9), torch.rand(10), torch.rand(11)] + c = torch.rand(30, 4) + d = torch.rand(4, 30) + e = [torch.rand(9, 5), torch.rand(10, 5), torch.rand(11, 5)] + f = [torch.rand(5, 9), torch.rand(5, 10), torch.rand(5, 11)] + da = make_dist(ctx, a) + db = make_dist(ctx, b) + la = torch.concat(a) + lb = torch.concat(b) + de = make_dist(ctx, e) + df = make_dist(ctx, f, axis=1) + le = torch.concat(e) + lf = torch.concat(f, dim=1) + # distributed + assert torch.allclose(torch.matmul(da, de), torch.matmul(la, le)) + assert torch.allclose(torch.rmatmul_f(da, df), torch.rmatmul_f(la, lf)) + + # local + assert torch.allclose(torch.matmul(da, c), torch.matmul(la, c)) + assert torch.allclose(torch.rmatmul_f(da, d), torch.rmatmul_f(la, d)) + + +def test_shape_x_1(ctx): + a_30 = [torch.rand(9), torch.rand(10), torch.rand(11)] + b_30 = [torch.rand(9), torch.rand(10), torch.rand(11)] + c = torch.rand(30, 4) + d = torch.rand(4, 30) + e_30_5 = [torch.rand(9, 5), torch.rand(10, 5), torch.rand(11, 5)] + f_5_30 = [torch.rand(5, 9), torch.rand(5, 10), torch.rand(5, 11)] + da_30 = make_dist(ctx, a_30) + db_30 = make_dist(ctx, b_30) + la_30 = torch.concat(a_30) + lb_30 = torch.concat(b_30) + de_30_5 = make_dist(ctx, e_30_5) + df_5_30 = make_dist(ctx, f_5_30, axis=1) + le_30_5 = torch.concat(e_30_5) + lf_5_30 = torch.concat(f_5_30, dim=1) + # distributed + assert torch.allclose(torch.matmul(df_5_30, da_30), torch.matmul(lf_5_30, la_30)) + assert torch.allclose(torch.rmatmul_f(de_30_5, db_30), torch.rmatmul_f(le_30_5, lb_30)) + + # local + assert torch.allclose(torch.matmul(df_5_30, la_30), torch.matmul(lf_5_30, la_30)) + assert torch.allclose(torch.rmatmul_f(de_30_5, lb_30), torch.rmatmul_f(le_30_5, lb_30)) + + +def test_shape_x_x_dist_dist_bc_matmul(ctx): + e_30_5_13 = [torch.rand(9, 5, 13), torch.rand(10, 5, 13), torch.rand(11, 5, 13)] + e_30_13_15 = [torch.rand(9, 13, 15), torch.rand(10, 13, 15), torch.rand(11, 13, 15)] + + assert torch.matmul(make_dist(ctx, e_30_5_13), make_dist(ctx, e_30_13_15)) == make_dist( + ctx, [torch.matmul(s1, s2) for s1, s2 in zip(e_30_5_13, e_30_13_15)] + ) + + assert torch.rmatmul_f(make_dist(ctx, e_30_13_15), make_dist(ctx, e_30_5_13)) == make_dist( + ctx, [torch.rmatmul_f(s1, s2) for s1, s2 in zip(e_30_13_15, e_30_5_13)] + ) + + +def test_shape_x_x_dist_dist_matmul(ctx): + e_5_13_30 = [torch.rand(5, 13, 9), torch.rand(5, 13, 10), torch.rand(5, 13, 11)] + e_19_30_17 = [torch.rand(5, 9, 17), torch.rand(5, 10, 17), torch.rand(5, 11, 17)] + + assert torch.allclose( + torch.matmul(make_dist(ctx, e_5_13_30, axis=2), make_dist(ctx, e_19_30_17, axis=1)), + torch.matmul(torch.concat(e_5_13_30, 2), torch.concat(e_19_30_17, 1)), + ) + + assert torch.allclose( + torch.rmatmul_f(make_dist(ctx, e_19_30_17, axis=1), make_dist(ctx, e_5_13_30, axis=2)), + torch.rmatmul_f(torch.concat(e_19_30_17, 1), torch.concat(e_5_13_30, 2)), + ) + + +def test_shape_x_x_dist_local_matmul(ctx): + e_5_30_13 = [torch.rand(5, 9, 13), torch.rand(5, 10, 13), torch.rand(5, 11, 13)] + e_5_13_30 = [torch.rand(5, 13, 9), torch.rand(5, 13, 10), torch.rand(5, 13, 11)] 
+ el_5_13_30 = torch.concat(e_5_13_30, dim=2) + el_5_30_13 = torch.concat(e_5_30_13, dim=1) + + assert torch.matmul(make_dist(ctx, e_5_30_13, axis=1), el_5_13_30) == make_dist( + ctx, [torch.matmul(s1, el_5_13_30) for s1 in e_5_30_13], axis=1 + ) + assert torch.allclose( + torch.matmul(make_dist(ctx, e_5_13_30, axis=2), el_5_30_13), torch.matmul(el_5_13_30, el_5_30_13) + ) + assert torch.rmatmul_f(make_dist(ctx, e_5_13_30, axis=2), el_5_30_13) == make_dist( + ctx, [torch.rmatmul_f(s1, el_5_30_13) for s1 in e_5_13_30], axis=2 + ) + assert torch.allclose( + torch.rmatmul_f(make_dist(ctx, e_5_30_13, axis=1), el_5_13_30), torch.rmatmul_f(el_5_30_13, el_5_13_30) + ) + + +def make_dist(ctx, tensors, axis=0): + return DTensor.from_sharding_list(ctx, tensors, axis=axis) diff --git a/python/fate/test/test_vertor_paillier.py b/python/fate/test/test_vertor_paillier.py new file mode 100644 index 0000000000..3f8b3e1b22 --- /dev/null +++ b/python/fate/test/test_vertor_paillier.py @@ -0,0 +1,63 @@ +import torch +from fate.arch import Context + +ctx = Context() +kit = ctx.cipher.phe.setup({"kind": "paillier", "key_length": 1024}) +pk = kit.get_tensor_encryptor() +sk = kit.get_tensor_decryptor() + + +def test_add(): + encrypted = pk.encrypt_tensor(torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -8.0]])) + double_encrypted = torch.add(encrypted, encrypted) + double_encrypted = torch.add(double_encrypted, 1) + double_encrypted = torch.add(double_encrypted, torch.rand(2, 4)) + double_encrypted = torch.add(double_encrypted, torch.tensor(0.3)) + decrypted = sk.decrypt_tensor(double_encrypted) + print(decrypted) + + +def test_sub(): + encrypted = pk.encrypt_tensor(torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -8.0]])) + double_encrypted = torch.sub(encrypted, encrypted) + double_encrypted = torch.sub(double_encrypted, 1) + double_encrypted = torch.sub(double_encrypted, torch.rand(2, 4)) + double_encrypted = torch.sub(double_encrypted, torch.tensor(0.3)) + decrypted = sk.decrypt_tensor(double_encrypted) + print(decrypted) + + +def test_rsub(): + encrypted = pk.encrypt_tensor(torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -8.0]])) + double_encrypted = torch.rsub(encrypted, encrypted) + double_encrypted = torch.rsub(double_encrypted, 1) + double_encrypted = torch.rsub(double_encrypted, torch.rand(2, 4)) + double_encrypted = torch.rsub(double_encrypted, torch.tensor(0.3)) + decrypted = sk.decrypt_tensor(double_encrypted) + print(decrypted) + + +def test_mul(): + encrypted = pk.encrypt_tensor(torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, -8.0]])) + double_encrypted = torch.mul(encrypted, 2) + double_encrypted = torch.mul(double_encrypted, torch.rand(2, 4)) + decrypted = sk.decrypt_tensor(double_encrypted) + print(decrypted) + + +def test_matmul(): + x = torch.rand(5, 2) + y = torch.rand(2, 4) + enc_x = pk.encrypt_tensor(x) + enc_z = torch.matmul(enc_x, y) + z = sk.decrypt_tensor(enc_z) + assert torch.allclose(z, torch.matmul(x, y)) + + +def test_rmatmul(): + x = torch.rand(2, 5) + y = torch.rand(4, 2) + enc_x = pk.encrypt_tensor(x) + enc_z = torch.rmatmul_f(enc_x, y) + z = sk.decrypt_tensor(enc_z) + assert torch.allclose(z, torch.matmul(y, x)) diff --git a/python/requirements-fate.txt b/python/requirements-fate.txt index def37a4912..5ad38c0545 100644 --- a/python/requirements-fate.txt +++ b/python/requirements-fate.txt @@ -1,19 +1,19 @@ --extra-index-url https://download.pytorch.org/whl/cpu -click>=7.1.2,<8.0.0 -scikit-learn==1.0.1 -pandas==1.1.5 -protobuf==3.19.6 -pydantic -typing-extensions 
-ruamel-yaml==0.16.10 -requests<2.26.0 -cloudpickle==2.1.0 lmdb==1.3.0 -numpy==1.23.1 -torch==1.13.1 -urllib3==1.26.5 -grpcio==1.46.3 -ml_metadata +torch==1.13.1+cpu +fate_utils +pydantic==1.10.12 +cloudpickle==2.1.0 +click +ruamel-yaml==0.16 +numpy +pandas==2.0.3 +transformers==4.30.2 +accelerate==0.20.2 beautifultable -rust_paillier +requests<2.26.0 +grpcio==1.46.3 +protobuf==3.19.6 +scikit-learn==1.2.1; sys_platform == 'darwin' and platform_machine == 'arm64' +scikit-learn==1.0.1; sys_platform != 'darwin' or platform_machine != 'arm64' diff --git a/python/requirements-pulsar.txt b/python/requirements-pulsar.txt index 9e3cf64b06..5b8052350c 100644 --- a/python/requirements-pulsar.txt +++ b/python/requirements-pulsar.txt @@ -1 +1,3 @@ -pulsar-client==2.10.2 +pulsar-client==2.10.1; sys_platform == 'darwin' +pulsar-client==2.10.2; sys_platform != 'darwin' +urllib3==1.26.5 diff --git a/python/setup.py b/python/setup.py index 260af48dd2..553a948b22 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,38 +1,44 @@ import os -import fate from setuptools import find_packages, setup -packages = find_packages(".") +import fate + +# Base requirements install_requires = [ - "scikit-learn", - "pandas", - "protobuf", - "pydantic", + "lmdb==1.3.0", + "torch==1.13.1", + "fate_utils", + "pydantic==1.10.12", + "cloudpickle==2.1.0", "click", - "typing-extensions", - "ruamel.yaml", - "requests", - "cloudpickle", - "lmdb", + "ruamel.yaml==0.16", + "scikit-learn==1.2.1; sys_platform == 'darwin' and platform_machine == 'arm64'", + "scikit-learn==1.0.1; sys_platform != 'darwin' or platform_machine != 'arm64'", "numpy", - "torch", - "rust_paillier", - "urllib3", - "grpcio", - "ml_metadata", + "pandas", + "transformers", + "accelerate", "beautifultable", + "requests", + "grpcio", + "protobuf", ] + +# Extra requirements extras_require = { "rabbitmq": ["pika==1.2.1"], - "pulsar": ["pulsar-client==2.10.2"], + "pulsar": [ + "pulsar-client==2.10.2; sys_platform != 'darwin'", + "pulsar-client==2.10.1; sys_platform == 'darwin'", + "urllib3==1.26.5" + ], "spark": ["pyspark"], "eggroll": [ "grpcio==1.46.3", "grpcio-tools==1.46.3", "numba==0.56.4", "protobuf==3.19.6", - "pyarrow==6.0.1", "mmh3==3.0.0", "cachetools>=3.0.0", "cloudpickle==2.1.0", @@ -41,6 +47,7 @@ "all": ["pyfate[rabbitmq,pulsar,spark,eggroll]"], } +# Long description from README.md readme_path = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, "README.md")) if os.path.exists(readme_path): with open(readme_path, "r") as f: @@ -48,6 +55,7 @@ else: long_description = "fate" +# Setup function setup( name="pyfate", version=fate.__version__, @@ -58,7 +66,7 @@ long_description=long_description, license="Apache-2.0 License", url="https://fate.fedai.org/", - packages=packages, + packages=find_packages("."), install_requires=install_requires, extras_require=extras_require, python_requires=">=3.8", diff --git a/rust/fate_utils/Cargo.toml b/rust/fate_utils/Cargo.toml new file mode 100644 index 0000000000..ab4eb3f34d --- /dev/null +++ b/rust/fate_utils/Cargo.toml @@ -0,0 +1,23 @@ +[workspace] +members = [ + "crates/*" +] + +[workspace.dependencies] +serde = { version = "1.0.137", features = ["derive", "rc"] } +rug = { version = "1.20.1", features = ["serde"] } +rand = { version = "0.8.3", features = ["getrandom"] } +rand_core = "0.5.1" +ndarray = { version = "0.15.4" } +numpy = "0.15.1" +rayon = { version = "1.5.3"} +pyo3 = { version = "0.15.2" } +bincode = "1.3.3" +libsm = "0.5" +curve25519-dalek = "3.2.1" +x25519-dalek = "1.2.0" +sha2 = 
"0.9.9" +iai = "0.1.1" +criterion = { version = "0.3", features = ["html_reports"] } +rand_chacha = "0.3.1" +anyhow = "1.0.75" diff --git a/rust/fate_utils/README.md b/rust/fate_utils/README.md new file mode 100644 index 0000000000..7c1ef42da0 --- /dev/null +++ b/rust/fate_utils/README.md @@ -0,0 +1,3 @@ +## fate_utils + +utils for fate, write in rust diff --git a/rust/tensor/rust_paillier/benches/base_bench.py b/rust/fate_utils/benches/base_bench.py similarity index 99% rename from rust/tensor/rust_paillier/benches/base_bench.py rename to rust/fate_utils/benches/base_bench.py index 977ac90172..0a6f65492d 100644 --- a/rust/tensor/rust_paillier/benches/base_bench.py +++ b/rust/fate_utils/benches/base_bench.py @@ -56,7 +56,7 @@ def __init__(self, a, b, c, d) -> None: self.b = b self.c = c self.d = d - self.pk, self.sk = phe.generate_paillier_keypair(n_length=1024) + self.pk, self.sk = phe.generate_paillier_keypair(n_length=2048) self.ea = np.vectorize(self.pk.encrypt)(self.a) self.eb = np.vectorize(self.pk.encrypt)(self.b) self.ec = np.vectorize(self.pk.encrypt)(self.c) @@ -105,7 +105,7 @@ def __init__(self, a, b, c, d, keygen) -> None: self.b = b self.c = c self.d = d - self.pk, self.sk = keygen(1024) + self.pk, self.sk = keygen(2048) self.ea = self.pk.encrypt_f64(self.a) self.eb = self.pk.encrypt_f64(self.b) self.ec = self.pk.encrypt_f64(self.c) diff --git a/rust/fate_utils/crates/fate_utils/Cargo.toml b/rust/fate_utils/crates/fate_utils/Cargo.toml new file mode 100644 index 0000000000..8ed1e0f82e --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "fate_utils" +version = "0.1.0" +edition = "2021" + +[lib] +name = "fate_utils" +crate-type = ["cdylib", "staticlib", "rlib"] +bench = false + +[dependencies] +numpy = { workspace = true } +pyo3 = { workspace = true } +ndarray = { workspace = true } +serde = { workspace = true } +bincode = { workspace = true } +rand = { workspace = true } +libsm = { workspace = true } +sha2 = { workspace = true } +curve25519-dalek = { workspace = true } +rayon = { workspace = true, optional = true } +rand_chacha = { workspace = true } +x25519-dalek = { workspace = true } +rand_core = { workspace = true } +rug = { workspace = true } +anyhow = { workspace = true } +quantile = { path = "../quantile" } +math = { path = "../math" } +fixedpoint = { path = "../fixedpoint" } +paillier = { path = "../paillier" } +fixedpoint_paillier = { path = "../fixedpoint_paillier" } +fixedpoint_ou = { path = "../fixedpoint_ou" } + + +[features] +default = ["rug", "rayon", "std", "u64_backend", "extension-module"] +rug = [] +rayon = ["dep:rayon", "ndarray/rayon"] +simd_backend = ["curve25519-dalek/simd_backend"] +std = ["curve25519-dalek/std"] +nightly = ["curve25519-dalek/nightly"] +u64_backend = ["curve25519-dalek/u64_backend"] +u32_backend = ["curve25519-dalek/u32_backend"] +fiat_u64_backend = ["curve25519-dalek/fiat_u64_backend"] +fiat_u32_backend = ["curve25519-dalek/fiat_u32_backend"] +extension-module = ["pyo3/extension-module"] + +[package.metadata.docs.rs] +# To build locally use +# RUSTDOCFLAGS="--html-in-header katex-header.html" cargo doc --no-deps --document-private-items --open +rustdoc-args = ["--html-in-header", "docs/katex-header.html"] diff --git a/rust/fate_utils/crates/fate_utils/src/hash/mod.rs b/rust/fate_utils/crates/fate_utils/src/hash/mod.rs new file mode 100644 index 0000000000..ee7e43d809 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/hash/mod.rs @@ -0,0 +1,12 @@ +mod sm3; +use pyo3::prelude::*; + 
+pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule_hash = PyModule::new(py, "hash")?; + sm3::register(py, submodule_hash)?; + m.add_submodule(submodule_hash)?; + py.import("sys")? + .getattr("modules")? + .set_item("fate_utils.hash", submodule_hash)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/hash/sm3.rs b/rust/fate_utils/crates/fate_utils/src/hash/sm3.rs new file mode 100644 index 0000000000..a30c20c51e --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/hash/sm3.rs @@ -0,0 +1,17 @@ +use libsm; +use pyo3::prelude::*; +use pyo3::types::PyByteArray; +use pyo3::wrap_pyfunction; + +/// hash of bytes +#[pyfunction] +fn sm3_hash(py: Python, a: &[u8]) -> PyObject { + let mut hash = libsm::sm3::hash::Sm3Hash::new(a); + let digest: [u8; 32] = hash.get_hash(); + PyByteArray::new(py, &digest).into() +} + +pub(crate) fn register(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sm3_hash, m)?)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/histogram/indexer.rs b/rust/fate_utils/crates/fate_utils/src/histogram/indexer.rs new file mode 100644 index 0000000000..5e67183370 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/histogram/indexer.rs @@ -0,0 +1,254 @@ +use std::collections::HashMap; +use pyo3::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use rand::prelude::SliceRandom; + +type S = usize; + +#[pyclass] +pub struct HistogramIndexer { + node_size: S, + feature_bin_sizes: Vec, + feature_size: S, + feature_axis_stride: Vec, + node_axis_stride: S, +} + + +#[pymethods] +impl HistogramIndexer { + #[new] + fn new(node_size: S, feature_bin_sizes: Vec) -> HistogramIndexer { + let feature_size = feature_bin_sizes.len(); + let mut feature_axis_stride = vec![0]; + feature_axis_stride.extend(feature_bin_sizes.iter().scan(0, |acc, &x| { + *acc += x; + Some(*acc) + })); + let node_axis_stride = feature_bin_sizes.iter().sum(); + + HistogramIndexer { + node_size, + feature_bin_sizes: feature_bin_sizes, + feature_size, + feature_axis_stride, + node_axis_stride, + } + } + + #[inline] + fn get_position(&self, nid: S, fid: S, bid: S) -> S { + nid * self.node_axis_stride + self.feature_axis_stride[fid] + bid + } + + fn get_positions(&self, nids: Vec, bid_vec: Vec>) -> Vec> { + bid_vec.iter().zip(nids.iter()).map(|(bids, &nid)| { + bids.iter().enumerate().map(|(fid, &bid)| self.get_position(nid, fid, bid)).collect() + }).collect() + } + + fn get_reverse_position(&self, position: S) -> (S, S, S) { + let nid = position / self.node_axis_stride; + let bid = position % self.node_axis_stride; + for fid in 0..self.feature_size { + if bid < self.feature_axis_stride[fid + 1] { + return (nid, fid, bid - self.feature_axis_stride[fid]); + } + } + panic!("invalid position: {}", position); + } + + fn get_bin_num(&self, fid: S) -> S { + self.feature_bin_sizes[fid] + } + + fn get_bin_interval(&self, nid: S, fid: S) -> (S, S) { + let node_stride = nid * self.node_axis_stride; + (node_stride + self.feature_axis_stride[fid], node_stride + self.feature_axis_stride[fid + 1]) + } + + fn get_node_intervals(&self) -> Vec<(S, S)> { + (0..self.node_size).map(|nid| { + (nid * self.node_axis_stride, (nid + 1) * self.node_axis_stride) + }).collect() + } + + fn get_feature_position_ranges(&self) -> Vec<(S, S)> { + (0..self.node_size).flat_map(|nid| { + let node_stride = nid * self.node_axis_stride; + (0..self.feature_size).map(move |fid| { + (node_stride + self.feature_axis_stride[fid], node_stride + 
self.feature_axis_stride[fid + 1]) + }) + }).collect() + } + + fn splits_into_k(&self, k: S) -> Vec<(S, (S, S), Vec<(S, S)>)> { + let n = self.node_axis_stride; + let mut split_sizes = vec![n / k; k]; + for i in 0..n % k { + split_sizes[i] += 1; + } + let mut start = 0; + let mut splits = Vec::with_capacity(k); + for (pid, &size) in split_sizes.iter().enumerate() { + let end = start + size; + let shift = self.node_axis_stride; + let mut node_intervals = Vec::with_capacity(self.node_size); + for nid in 0..self.node_size { + node_intervals.push((start + nid * shift, end + nid * shift)); + } + splits.push((pid, (start, end), node_intervals)); + start += size; + } + splits + } + + fn total_data_size(&self) -> S { + self.node_size * self.node_axis_stride + } + + fn one_node_data_size(&self) -> S { + self.node_axis_stride + } + + fn global_flatten_bin_sizes(&self) -> Vec { + // repeat self.feature_bin_sizes for self.node_size times + let mut feature_bin_sizes = Vec::with_capacity(self.node_size * self.feature_size); + for _ in 0..self.node_size { + feature_bin_sizes.extend(self.feature_bin_sizes.iter()); + } + feature_bin_sizes + } + + fn flatten_in_node(&self) -> HistogramIndexer { + HistogramIndexer::new(self.node_size, vec![self.one_node_data_size()]) + } + + fn squeeze_bins(&self) -> HistogramIndexer { + HistogramIndexer::new(self.node_size, vec![1; self.feature_size]) + } + fn reshape(&self, feature_bin_sizes: Vec) -> HistogramIndexer { + HistogramIndexer::new(self.node_size, feature_bin_sizes) + } + fn get_shuffler(&self, seed: u64) -> Shuffler { + Shuffler::new(self.node_size, self.node_axis_stride, seed) + } + #[getter] + fn get_node_size(&self) -> S { + self.node_size + } + #[getter] + fn get_node_axis_stride(&self) -> S { + self.node_axis_stride + } + #[getter] + fn get_feature_size(&self) -> S { + self.feature_size + } + #[getter] + fn get_feature_axis_stride(&self) -> Vec { + self.feature_axis_stride.clone() + } + #[getter] + fn get_feature_bin_sizes(&self) -> Vec { + self.feature_bin_sizes.clone() + } + #[getter] + fn get_num_nodes(&self) -> S { + self.node_size + } + fn unflatten_indexes(&self) -> HashMap>> { + let mut indexes = HashMap::new(); + for nid in 0..self.node_size { + let mut feature_indexes = HashMap::new(); + for fid in 0..self.feature_size { + let (start, end) = self.get_bin_interval(nid, fid); + feature_indexes.insert(fid, (start..end).collect()); + } + indexes.insert(nid, feature_indexes); + } + indexes + } +} + +#[pyclass] +struct Shuffler { + num_node: S, + node_size: S, + perm_indexes: Vec>, +} + +#[pymethods] +impl Shuffler { + #[new] + fn new(num_node: S, node_size: S, seed: u64) -> Shuffler { + let mut perm_indexes = Vec::with_capacity(num_node); + for _ in 0..num_node { + let mut rng = StdRng::seed_from_u64(seed); + let mut perm_index = (0..node_size).collect::>(); + perm_index.shuffle(&mut rng); + perm_indexes.push(perm_index); + } + Shuffler { + num_node, + node_size, + perm_indexes, + } + } + + fn get_global_perm_index(&self) -> Vec { + let mut index = Vec::with_capacity(self.num_node * self.node_size); + for (nid, perm_index) in self.perm_indexes.iter().enumerate() { + index.extend(perm_index.iter().map(|&x| x + nid * self.node_size)); + } + index + } + + fn get_reverse_indexes(&self, step: S, indexes: Vec) -> Vec { + let mapping = self.get_shuffle_index(step, true); + indexes.iter().map(|&x| mapping[x]).collect() + } + // def get_shuffle_index(self, step, reverse=False): +// # """ +// # get chunk shuffle index +// # """ +// # stepped = torch.arange(0, 
self.num_node * self.node_size * step).reshape(self.num_node * self.node_size, step) +// # indexes = stepped[self.get_global_perm_index(), :].flatten() +// # if reverse: +// # indexes = torch.argsort(indexes) +// # return indexes + fn get_shuffle_index(&self, step: S, reverse: bool) -> Vec { + let mut stepped = Vec::with_capacity(self.num_node * self.node_size * step); + for i in 0..self.num_node * self.node_size { + for j in 0..step { + stepped.push(i * step + j); + } + } + let mut indexes = Vec::with_capacity(self.num_node * self.node_size * step); + for &perm_index in self.get_global_perm_index().iter() { + for j in 0..step { + indexes.push(stepped[perm_index * step + j]); + } + } + if reverse { + let mut raw_indices = (0..indexes.len()).collect::>(); + raw_indices.sort_by_key(|&i| &indexes[i]); + indexes = raw_indices + } + indexes + } + + fn reverse_index(&self, index: S) -> (S, S) { + let nid = index / self.node_size; + let bid = index % self.node_size; + (nid, self.perm_indexes[nid][bid]) + } +} + + +pub(crate) fn register(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/histogram/mod.rs b/rust/fate_utils/crates/fate_utils/src/histogram/mod.rs new file mode 100644 index 0000000000..c2d4e585ae --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/histogram/mod.rs @@ -0,0 +1,13 @@ +mod indexer; + +use pyo3::prelude::*; + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule = PyModule::new(py, "histogram")?; + indexer::register(py, submodule)?; + m.add_submodule(submodule)?; + py.import("sys")? + .getattr("modules")? + .set_item("fate_utils.histogram", submodule)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/lib.rs b/rust/fate_utils/crates/fate_utils/src/lib.rs new file mode 100644 index 0000000000..9b8a4056da --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/lib.rs @@ -0,0 +1,24 @@ +extern crate core; + +mod hash; +mod histogram; +mod psi; +mod quantile; +mod secure_aggregation_helper; +mod paillier; + +mod ou; + +use pyo3::prelude::*; + +#[pymodule] +fn fate_utils(py: Python, m: &PyModule) -> PyResult<()> { + quantile::register(py, m)?; + hash::register(py, m)?; + psi::register(py, m)?; + histogram::register(py, m)?; + paillier::register(py, m)?; + ou::register(py, m)?; + secure_aggregation_helper::register(py, m)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/ou/mod.rs b/rust/fate_utils/crates/fate_utils/src/ou/mod.rs new file mode 100644 index 0000000000..e3e21b16a6 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/ou/mod.rs @@ -0,0 +1,12 @@ +mod ou; +use pyo3::prelude::*; + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule = PyModule::new(py, "ou")?; + ou::register(py, submodule)?; + m.add_submodule(submodule)?; + py.import("sys")? + .getattr("modules")? 
+ .set_item("fate_utils.ou", submodule)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/ou/ou.rs b/rust/fate_utils/crates/fate_utils/src/ou/ou.rs new file mode 100644 index 0000000000..237536e8d0 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/ou/ou.rs @@ -0,0 +1,443 @@ +use numpy::PyReadonlyArray1; +use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::*; +use anyhow::Error as AnyhowError; + +trait ToPyErr { + fn to_py_err(self) -> PyErr; +} + +impl ToPyErr for AnyhowError { + fn to_py_err(self) -> PyErr { + PyRuntimeError::new_err(self.to_string()) + } +} + + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct PK(fixedpoint_ou::PK); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct SK(fixedpoint_ou::SK); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Coder(fixedpoint_ou::Coder); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Ciphertext(fixedpoint_ou::Ciphertext); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default, Debug)] +pub struct CiphertextVector(fixedpoint_ou::CiphertextVector); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default, Debug)] +pub struct PlaintextVector(fixedpoint_ou::PlaintextVector); + +#[pyclass(module = "fate_utils.ou")] +#[derive(Default)] +pub struct Plaintext(fixedpoint_ou::Plaintext); + +#[pyclass] +pub struct Evaluator {} + +#[pymethods] +impl PK { + fn encrypt_encoded( + &self, + plaintext_vector: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + CiphertextVector(self.0.encrypt_encoded(&plaintext_vector.0, obfuscate)) + } + fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext(self.0.encrypt_encoded_scalar(&plaintext.0, obfuscate)) + } + + #[new] + fn __new__() -> PyResult { + Ok(PK::default()) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl SK { + fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + PlaintextVector(self.0.decrypt_to_encoded(&data.0)) + } + fn decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext(self.0.decrypt_to_encoded_scalar(&data.0)) + } + + #[new] + fn __new__() -> PyResult { + Ok(SK::default()) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl Coder { + // fn encode_f64(&self, data: f64) -> Plaintext { + // Plaintext(self.0.encode_f64(data)) + // } + // fn decode_f64(&self, data: &Plaintext) -> f64 { + // self.0.decode_f64(&data.0) + // } + // fn encode_f32(&self, data: f32) -> Plaintext { + // Plaintext(self.0.encode_f32(data)) + // } + fn encode_u64(&self, data: u64) -> Plaintext { + Plaintext(self.0.encode_u64(data)) + } + fn decode_u64(&self, data: &Plaintext) -> u64 { + self.0.decode_u64(&data.0) + } + fn encode_u32(&self, data: u32) -> Plaintext { + Plaintext(self.0.encode_u32(data)) + } + fn decode_u32(&self, data: &Plaintext) -> u32 { + self.0.decode_u32(&data.0) + } + #[new] + fn __new__() -> PyResult { + Ok(Coder::default()) + } + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = 
bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn pack_floats(&self, float_tensor: Vec, offset_bit: usize, pack_num: usize, precision: u32) -> PlaintextVector { + let data = self.0.pack_floats(&float_tensor, offset_bit, pack_num, precision); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + + fn unpack_floats(&self, packed: &PlaintextVector, offset_bit: usize, pack_num: usize, precision: u32, total_num: usize) -> Vec { + self.0.unpack_floats(&packed.0.data, offset_bit, pack_num, precision, total_num) + } + // fn encode_f64_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + // let data = data + // .as_array() + // .iter() + // .map(|x| self.0.encode_f64(*x)) + // .collect(); + // PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + // } + // fn decode_f64_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1 { + // Array1::from( + // data.0.data + // .iter() + // .map(|x| self.0.decode_f64(x)) + // .collect::>(), + // ) + // .into_pyarray(py) + // } + // fn encode_f32_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + // let data = data + // .as_array() + // .iter() + // .map(|x| self.0.encode_f32(*x)) + // .collect(); + // PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + // } + // fn decode_f32(&self, data: &Plaintext) -> f32 { + // self.0.decode_f32(&data.0) + // } + // fn decode_f32_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1 { + // Array1::from( + // data.0.data + // .iter() + // .map(|x| self.0.decode_f32(x)) + // .collect::>(), + // ) + // .into_pyarray(py) + // } + fn encode_u64_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_u64(*x)) + .collect(); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + fn decode_u64_vec(&self, data: &PlaintextVector) -> Vec { + data.0.data.iter().map(|x| self.0.decode_u64(x)).collect() + } + fn encode_u32_vec(&self, data: PyReadonlyArray1) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_u32(*x)) + .collect(); + PlaintextVector(fixedpoint_ou::PlaintextVector { data }) + } + fn decode_u32_vec(&self, data: &PlaintextVector) -> Vec { + data.0.data.iter().map(|x| self.0.decode_u32(x)).collect() + } +} + +#[pyfunction] +fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk, coder) = fixedpoint_ou::keygen(bit_length); + (SK(sk), PK(pk), Coder(coder)) +} + +#[pymethods] +impl CiphertextVector { + #[new] + fn __new__() -> PyResult { + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data: vec![] })) + } + + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn __len__(&self) -> usize { + self.0.data.len() + } + + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + + #[staticmethod] + pub fn zeros(size: usize) -> PyResult { + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector::zeros(size))) + } + + pub fn pack_squeeze(&self, pack_num: usize, offset_bit: u32, pk: &PK) -> PyResult { + Ok(CiphertextVector(self.0.pack_squeeze(&pk.0, pack_num, offset_bit))) + } + + fn slice(&mut self, start: usize, size: usize) -> CiphertextVector { + CiphertextVector(self.0.slice(start, size)) + } + + fn slice_indexes(&mut self, indexes: Vec) -> PyResult { + Ok(CiphertextVector(self.0.slice_indexes(indexes))) + } + pub fn cat(&self, others: Vec>) -> PyResult { + 
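+        // Concatenate this ciphertext vector with the given vectors into one new vector.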
Ok(CiphertextVector(self.0.cat(others.iter().map(|x| &x.0).collect()))) + } + fn i_shuffle(&mut self, indexes: Vec) { + self.0.i_shuffle(indexes); + } + + fn shuffle(&self, indexes: Vec) -> PyResult { + Ok(CiphertextVector(self.0.shuffle(indexes))) + } + fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> PyResult { + Ok(CiphertextVector(self.0.intervals_slice(intervals).map_err(|e| e.to_py_err())?)) + } + fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec>) { + self.0.iadd_slice(&pk.0, position, other.iter().map(|x| &x.0).collect()); + } + fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec>, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate(&other.0, indexes, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec>, masks: Vec, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate_with_masks(&other.0, indexes, masks, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.0.iadd(&pk.0, &other.0); + } + fn idouble(&mut self, pk: &PK) { + self.0.idouble(&pk.0); + } + fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec, step: usize) { + self.0.chunking_cumsum_with_step(&pk.0, chunk_sizes, step); + } + fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + CiphertextVector(self.0.intervals_sum_with_step(&pk.0, intervals, step)) + } + + fn tolist(&self) -> Vec { + self.0.tolist().iter().map(|x| CiphertextVector(x.clone())).collect() + } + + fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.add(&pk.0, &other.0)) + } + fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.add_scalar(&pk.0, &other.0)) + } + fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.sub(&pk.0, &other.0)) + } + fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.sub_scalar(&pk.0, &other.0)) + } + fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.rsub(&pk.0, &other.0)) + } + fn rsub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.rsub_scalar(&pk.0, &other.0)) + } + fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector { + CiphertextVector(self.0.mul(&pk.0, &other.0)) + } + fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + CiphertextVector(self.0.mul_scalar(&pk.0, &other.0)) + } + + fn matmul( + &self, + pk: &PK, + other: 
&PlaintextVector, + lshape: Vec, + rshape: Vec, + ) -> CiphertextVector { + CiphertextVector(self.0.matmul(&pk.0, &other.0, lshape, rshape)) + } + + fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec, + rshape: Vec, + ) -> CiphertextVector { + CiphertextVector(self.0.rmatmul(&pk.0, &other.0, lshape, rshape)) + } +} + +#[pymethods] +impl PlaintextVector { + #[new] + fn __new__() -> PyResult { + Ok(PlaintextVector(fixedpoint_ou::PlaintextVector { data: vec![] })) + } + fn __getstate__(&self) -> PyResult> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + PlaintextVector(self.0.get_stride(index, stride)) + } + fn tolist(&self) -> Vec { + self.0.tolist().iter().map(|x| Plaintext(x.clone())).collect() + } +} + +#[pymethods] +impl Evaluator { + #[staticmethod] + fn cat(vec_list: Vec<PyRef<CiphertextVector>>) -> PyResult<CiphertextVector> { + let mut data = vec![fixedpoint_ou::Ciphertext::zero(); 0]; + for vec in vec_list { + data.extend(vec.0.data.clone()); + } + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data })) + } + #[staticmethod] + fn slice_indexes(a: &CiphertextVector, indexes: Vec<usize>) -> PyResult<CiphertextVector> { + let data = indexes + .iter() + .map(|i| a.0.data[*i].clone()) + .collect::<Vec<_>>(); + Ok(CiphertextVector(fixedpoint_ou::CiphertextVector { data })) + } +} + +pub(crate) fn register(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::<CiphertextVector>()?; + m.add_class::<PlaintextVector>()?; + m.add_class::<PK>()?; + m.add_class::<SK>()?; + m.add_class::<Coder>()?; + m.add_class::<Ciphertext>()?; + m.add_class::<Evaluator>()?; + m.add_function(wrap_pyfunction!(keygen, m)?)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/paillier/mod.rs b/rust/fate_utils/crates/fate_utils/src/paillier/mod.rs new file mode 100644 index 0000000000..101655612d --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/paillier/mod.rs @@ -0,0 +1,12 @@ +mod paillier; +use pyo3::prelude::*; + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule = PyModule::new(py, "paillier")?; + paillier::register(py, submodule)?; + m.add_submodule(submodule)?; + py.import("sys")? + .getattr("modules")? 
+ .set_item("fate_utils.paillier", submodule)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs b/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs new file mode 100644 index 0000000000..4a95fa270b --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/paillier/paillier.rs @@ -0,0 +1,444 @@ +use ndarray::prelude::*; +use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1}; +use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::*; +use anyhow::Error as AnyhowError; + +trait ToPyErr { + fn to_py_err(self) -> PyErr; +} + +impl ToPyErr for AnyhowError { + fn to_py_err(self) -> PyErr { + PyRuntimeError::new_err(self.to_string()) + } +} + + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default)] +pub struct PK(fixedpoint_paillier::PK); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default)] +pub struct SK(fixedpoint_paillier::SK); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default)] +pub struct Coder(fixedpoint_paillier::Coder); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default)] +pub struct Ciphertext(fixedpoint_paillier::Ciphertext); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default, Debug)] +pub struct CiphertextVector(fixedpoint_paillier::CiphertextVector); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default, Debug)] +pub struct PlaintextVector(fixedpoint_paillier::PlaintextVector); + +#[pyclass(module = "fate_utils.paillier")] +#[derive(Default)] +pub struct Plaintext(fixedpoint_paillier::Plaintext); + +#[pyclass] +pub struct Evaluator {} + +#[pymethods] +impl PK { + fn encrypt_encoded( + &self, + plaintext_vector: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + CiphertextVector(self.0.encrypt_encoded(&plaintext_vector.0, obfuscate)) + } + fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext(self.0.encrypt_encoded_scalar(&plaintext.0, obfuscate)) + } + + #[new] + fn __new__() -> PyResult<Self> { + Ok(PK::default()) + } + + fn __getstate__(&self) -> PyResult<Vec<u8>> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl SK { + fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + PlaintextVector(self.0.decrypt_to_encoded(&data.0)) + } + fn decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext(self.0.decrypt_to_encoded_scalar(&data.0)) + } + + #[new] + fn __new__() -> PyResult<Self> { + Ok(SK::default()) + } + + fn __getstate__(&self) -> PyResult<Vec<u8>> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } +} + +#[pymethods] +impl Coder { + fn encode_f64(&self, data: f64) -> Plaintext { + Plaintext(self.0.encode_f64(data)) + } + fn decode_f64(&self, data: &Plaintext) -> f64 { + self.0.decode_f64(&data.0) + } + fn encode_f32(&self, data: f32) -> Plaintext { + Plaintext(self.0.encode_f32(data)) + } + fn encode_i64(&self, data: i64) -> Plaintext { + Plaintext(self.0.encode_i64(data)) + } + fn decode_i64(&self, data: &Plaintext) -> i64 { + self.0.decode_i64(&data.0) + } + fn encode_i32(&self, data: i32) -> Plaintext { + Plaintext(self.0.encode_i32(data)) + } + fn decode_i32(&self, data: &Plaintext) -> i32 { + self.0.decode_i32(&data.0) + } + #[new] + fn __new__() -> PyResult<Self> { + 
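+        // Default-constructed coder, mainly to support pickling via `__setstate__`;
+        // coders used for real encoding come from `keygen`.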
Ok(Coder::default()) + } + fn __getstate__(&self) -> PyResult<Vec<u8>> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn pack_floats(&self, float_tensor: Vec<f64>, offset_bit: usize, pack_num: usize, precision: u32) -> PlaintextVector { + let data = self.0.pack_floats(&float_tensor, offset_bit, pack_num, precision); + PlaintextVector(fixedpoint_paillier::PlaintextVector { data }) + } + + fn unpack_floats(&self, packed: &PlaintextVector, offset_bit: usize, pack_num: usize, precision: u32, total_num: usize) -> Vec<f64> { + self.0.unpack_floats(&packed.0.data, offset_bit, pack_num, precision, total_num) + } + fn encode_f64_vec(&self, data: PyReadonlyArray1<f64>) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_f64(*x)) + .collect(); + PlaintextVector(fixedpoint_paillier::PlaintextVector { data }) + } + fn decode_f64_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1<f64> { + Array1::from( + data.0.data + .iter() + .map(|x| self.0.decode_f64(x)) + .collect::<Vec<f64>>(), + ) + .into_pyarray(py) + } + fn encode_f32_vec(&self, data: PyReadonlyArray1<f32>) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_f32(*x)) + .collect(); + PlaintextVector(fixedpoint_paillier::PlaintextVector { data }) + } + fn decode_f32(&self, data: &Plaintext) -> f32 { + self.0.decode_f32(&data.0) + } + fn decode_f32_vec<'py>(&self, data: &PlaintextVector, py: Python<'py>) -> &'py PyArray1<f32> { + Array1::from( + data.0.data + .iter() + .map(|x| self.0.decode_f32(x)) + .collect::<Vec<f32>>(), + ) + .into_pyarray(py) + } + fn encode_i64_vec(&self, data: PyReadonlyArray1<i64>) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_i64(*x)) + .collect(); + PlaintextVector(fixedpoint_paillier::PlaintextVector { data }) + } + fn decode_i64_vec(&self, data: &PlaintextVector) -> Vec<i64> { + data.0.data.iter().map(|x| self.0.decode_i64(x)).collect() + } + fn encode_i32_vec(&self, data: PyReadonlyArray1<i32>) -> PlaintextVector { + let data = data + .as_array() + .iter() + .map(|x| self.0.encode_i32(*x)) + .collect(); + PlaintextVector(fixedpoint_paillier::PlaintextVector { data }) + } + fn decode_i32_vec(&self, data: &PlaintextVector) -> Vec<i32> { + data.0.data.iter().map(|x| self.0.decode_i32(x)).collect() + } +} + +#[pyfunction] +fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk, coder) = fixedpoint_paillier::keygen(bit_length); + (SK(sk), PK(pk), Coder(coder)) +} + +#[pymethods] +impl CiphertextVector { + #[new] + fn __new__() -> PyResult<Self> { + Ok(CiphertextVector(fixedpoint_paillier::CiphertextVector { data: vec![] })) + } + + fn __getstate__(&self) -> PyResult<Vec<u8>> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + + fn __len__(&self) -> usize { + self.0.data.len() + } + + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + + #[staticmethod] + pub fn zeros(size: usize) -> PyResult<Self> { + Ok(CiphertextVector(fixedpoint_paillier::CiphertextVector::zeros(size))) + } + + pub fn pack_squeeze(&self, pack_num: usize, offset_bit: u32, pk: &PK) -> PyResult<CiphertextVector> { + Ok(CiphertextVector(self.0.pack_squeeze(&pk.0, pack_num, offset_bit))) + } + + fn slice(&mut self, start: usize, size: usize) -> 
CiphertextVector { + CiphertextVector(self.0.slice(start, size)) + } + + fn slice_indexes(&mut self, indexes: Vec<usize>) -> PyResult<Self> { + Ok(CiphertextVector(self.0.slice_indexes(indexes))) + } + pub fn cat(&self, others: Vec<PyRef<CiphertextVector>>) -> PyResult<Self> { + Ok(CiphertextVector(self.0.cat(others.iter().map(|x| &x.0).collect()))) + } + fn i_shuffle(&mut self, indexes: Vec<usize>) { + self.0.i_shuffle(indexes); + } + + fn shuffle(&self, indexes: Vec<usize>) -> PyResult<Self> { + Ok(CiphertextVector(self.0.shuffle(indexes))) + } + fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> PyResult<Self> { + Ok(CiphertextVector(self.0.intervals_slice(intervals).map_err(|e| e.to_py_err())?)) + } + fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec<PyRef<Ciphertext>>) { + self.0.iadd_slice(&pk.0, position, other.iter().map(|x| &x.0).collect()); + } + fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec_self(sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> PyResult<()> { + self.0.iadd_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> PyResult<()> { + self.0.isub_vec(&other.0, sa, sb, size, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + + fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate(&other.0, indexes, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, masks: Vec<bool>, stride: usize, pk: &PK) -> PyResult<()> { + self.0.iupdate_with_masks(&other.0, indexes, masks, stride, &pk.0).map_err(|e| e.to_py_err())?; + Ok(()) + } + fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.0.iadd(&pk.0, &other.0); + } + fn idouble(&mut self, pk: &PK) { + self.0.idouble(&pk.0); + } + fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec<usize>, step: usize) { + self.0.chunking_cumsum_with_step(&pk.0, chunk_sizes, step); + } + fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + CiphertextVector(self.0.intervals_sum_with_step(&pk.0, intervals, step)) + } + + fn tolist(&self) -> Vec<CiphertextVector> { + self.0.tolist().iter().map(|x| CiphertextVector(x.clone())).collect() + } + + fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.add(&pk.0, &other.0)) + } + fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.add_scalar(&pk.0, &other.0)) + } + fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.sub(&pk.0, &other.0)) + } + fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.sub_scalar(&pk.0, &other.0)) + } + fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + CiphertextVector(self.0.rsub(&pk.0, &other.0)) + } + fn rsub_scalar(&self, pk: &PK, 
other: &Ciphertext) -> CiphertextVector { + CiphertextVector(self.0.rsub_scalar(&pk.0, &other.0)) + } + fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector { + CiphertextVector(self.0.mul(&pk.0, &other.0)) + } + fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + CiphertextVector(self.0.mul_scalar(&pk.0, &other.0)) + } + + fn matmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + CiphertextVector(self.0.matmul(&pk.0, &other.0, lshape, rshape)) + } + + fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + CiphertextVector(self.0.rmatmul(&pk.0, &other.0, lshape, rshape)) + } +} + +#[pymethods] +impl PlaintextVector { + #[new] + fn __new__() -> PyResult<Self> { + Ok(PlaintextVector(fixedpoint_paillier::PlaintextVector { data: vec![] })) + } + fn __getstate__(&self) -> PyResult<Vec<u8>> { + Ok(bincode::serialize(&self.0).unwrap()) + } + + fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> { + self.0 = bincode::deserialize(&state).unwrap(); + Ok(()) + } + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + PlaintextVector(self.0.get_stride(index, stride)) + } + fn tolist(&self) -> Vec<Plaintext> { + self.0.tolist().iter().map(|x| Plaintext(x.clone())).collect() + } +} + +#[pymethods] +impl Evaluator { + #[staticmethod] + fn cat(vec_list: Vec<PyRef<CiphertextVector>>) -> PyResult<CiphertextVector> { + let mut data = vec![fixedpoint_paillier::Ciphertext::zero(); 0]; + for vec in vec_list { + data.extend(vec.0.data.clone()); + } + Ok(CiphertextVector(fixedpoint_paillier::CiphertextVector { data })) + } + #[staticmethod] + fn slice_indexes(a: &CiphertextVector, indexes: Vec<usize>) -> PyResult<CiphertextVector> { + let data = indexes + .iter() + .map(|i| a.0.data[*i].clone()) + .collect::<Vec<_>>(); + Ok(CiphertextVector(fixedpoint_paillier::CiphertextVector { data })) + } +} + +pub(crate) fn register(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::<CiphertextVector>()?; + m.add_class::<PlaintextVector>()?; + m.add_class::<PK>()?; + m.add_class::<SK>()?; + m.add_class::<Coder>()?; + m.add_class::<Ciphertext>()?; + m.add_class::<Evaluator>()?; + m.add_function(wrap_pyfunction!(keygen, m)?)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/psi/curve25519.rs b/rust/fate_utils/crates/fate_utils/src/psi/curve25519.rs new file mode 100644 index 0000000000..2e712e8463 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/psi/curve25519.rs @@ -0,0 +1,111 @@ +use curve25519_dalek::edwards::EdwardsPoint; +use curve25519_dalek::montgomery::MontgomeryPoint; +use curve25519_dalek::scalar::Scalar; +use pyo3::exceptions::PyTypeError; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::ToPyObject; +use rand::rngs::StdRng; +use rand::{RngCore, SeedableRng}; + +#[pyclass(module = "fate_utils.psi", name = "Curve25519")] +struct Secret(Scalar); + +impl Secret { + fn new(byte32: Option<[u8; 32]>) -> Self { + Self(Scalar::from_bytes_mod_order(byte32.unwrap_or_else(|| { + let mut bytes: [u8; 32] = [0; 32]; + StdRng::from_entropy().fill_bytes(&mut bytes); + bytes + }))) + } +} + +#[pymethods] +impl Secret { + #[new] + #[args(args = "*")] + fn pynew(args: &PyTuple) -> PyResult<Self> { + match args.len() { + 0 => Ok(Secret::new(None)), + 1 => args + .get_item(0) + .unwrap() + 
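+                // interpret the single positional argument as an optional 32-byte secret scalar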
.extract::<Option<[u8; 32]>>() + .map_err(|e| PyTypeError::new_err(e.to_string())) // convert error to pyerr + .map(Secret::new), + _ => Err(PyTypeError::new_err("accept zero or one positional args")), + } + } + pub fn get_private_key(&self, py: Python) -> PyObject { + PyBytes::new(py, self.0.as_bytes()).to_object(py) + } + pub fn __getstate__(&self, py: Python) -> PyObject { + PyBytes::new(py, self.0.as_bytes()).to_object(py) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.0 = + Scalar::from_bytes_mod_order(s.as_bytes().try_into().expect("invalid state")); + Ok(()) + } + Err(e) => Err(e), + } + } + #[pyo3(text_signature = "($self, bytes)")] + fn encrypt(&self, bytes: &[u8], py: Python) -> PyObject { + PyBytes::new( + py, + (EdwardsPoint::hash_from_bytes::<sha2::Sha512>(bytes).to_montgomery() * self.0) + .as_bytes(), + ) + .into() + } + fn encrypt_vec(&self, vec: Vec<&[u8]>, py: Python) -> PyResult<Py<PyList>> { + let encrypted: Vec<&PyBytes> = vec + .iter() + .map(|bytes| { + PyBytes::new( + py, + (EdwardsPoint::hash_from_bytes::<sha2::Sha512>(bytes).to_montgomery() * self.0) + .as_bytes(), + ) + }) + .collect(); + Ok(PyList::new(py, &encrypted).into_py(py)) + } + #[pyo3(text_signature = "($self, their_public)")] + fn diffie_hellman(&self, their_public: &[u8], py: Python) -> PyObject { + PyBytes::new( + py, + (MontgomeryPoint( + their_public + .try_into() + .expect("diffie_hellman accpet 32 bytes pubkey"), + ) * self.0) + .as_bytes(), + ) + .into() + } + fn diffie_hellman_vec(&self, vec: Vec<&[u8]>, py: Python) -> PyResult<Py<PyList>> { + let dh: Vec<&PyBytes> = vec + .iter() + .map(|their_public| { + let their_public_array = <[u8; 32]>::try_from(*their_public) + .expect("diffie_hellman accepts 32 bytes pubkey"); + + PyBytes::new( + py, + (MontgomeryPoint(their_public_array) * self.0).as_bytes(), + ) + }) + .collect(); + Ok(PyList::new(py, &dh).into_py(py)) + } +} + +pub(crate) fn register(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::<Secret>()?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/psi/mod.rs b/rust/fate_utils/crates/fate_utils/src/psi/mod.rs new file mode 100644 index 0000000000..ec2513d24a --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/psi/mod.rs @@ -0,0 +1,12 @@ +mod curve25519; +use pyo3::prelude::*; + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule_psi = PyModule::new(py, "psi")?; + curve25519::register(py, submodule_psi)?; + m.add_submodule(submodule_psi)?; + py.import("sys")? + .getattr("modules")? 
+ .set_item("fate_utils.psi", submodule_psi)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/quantile.rs b/rust/fate_utils/crates/fate_utils/src/quantile.rs new file mode 100644 index 0000000000..14f8b802c1 --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/quantile.rs @@ -0,0 +1,145 @@ +use bincode::{deserialize, serialize}; +use ndarray::prelude::*; +use ndarray::{Array, ArrayView1, Axis}; +use numpy::{IntoPyArray, PyArray2, PyReadonlyArray1, PyReadonlyArray2}; +use pyo3::exceptions::PyTypeError; +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use pyo3::types::PyTuple; +use quantile::greenwald_khanna; +use serde::{Deserialize, Serialize}; + +#[derive(PartialEq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)] +struct Ordf64(f64); +impl Eq for Ordf64 {} +impl Ord for Ordf64 { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.total_cmp(&other.0) + } + fn max(self, other: Self) -> Self { + Ordf64(self.0.max(other.0)) + } + fn min(self, other: Self) -> Self { + Ordf64(self.0.min(other.0)) + } +} + +#[pyclass(module = "fate_utils.quantile")] +pub struct QuantileSummaryStream(Option<greenwald_khanna::Stream<Ordf64>>); + +impl QuantileSummaryStream { + fn new(epsilon: Option<f64>) -> Self { + match epsilon { + Some(e) => Self(Some(greenwald_khanna::Stream::new(e))), + None => Self(None), + } + } +} + +#[pymethods] +impl QuantileSummaryStream { + #[new] + #[args(args = "*")] + fn __new__(args: &PyTuple) -> PyResult<Self> { + match args.len() { + 0 => Ok(QuantileSummaryStream::new(None)), + 1 => args + .get_item(0) + .unwrap() + .extract::<f64>() + .map_err(|e| PyTypeError::new_err(e.to_string())) // convert error to pyerr + .map(|epsion| QuantileSummaryStream::new(Some(epsion))), + _ => Err(PyTypeError::new_err("accept zero or one positional args")), + } + } + + pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { + Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.0 = deserialize(s.as_bytes()).unwrap(); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn insert_array(&mut self, data: PyReadonlyArray1<f64>) -> PyResult<()> { + for d in data.as_array().into_iter() { + self.0.as_mut().unwrap().insert(Ordf64(*d)) + } + Ok(()) + } + pub fn queries(&self, phi: Vec<f64>) -> Vec<f64> { + phi.iter() + .map(|p| self.0.as_ref().unwrap().quantile(*p).0) + .collect() + } + pub fn merge(&self, other: &QuantileSummaryStream) -> QuantileSummaryStream { + QuantileSummaryStream(Some( + self.0.as_ref().unwrap().merge(&other.0.as_ref().unwrap()), + )) + } +} + +#[pyfunction] +fn summary_f64_ix2(data: PyReadonlyArray2<f64>, epsilon: f64) -> Vec<QuantileSummaryStream> { + let array = data.as_array(); + let mut outputs: Vec<_> = (0..array.shape()[1]) + .map(|_x| QuantileSummaryStream::new(Some(epsilon))) + .collect(); + for j in 0..array.shape()[1] { + let arr = array.index_axis(Axis(1), j); + for d in arr.into_iter() { + outputs[j].0.as_mut().unwrap().insert(Ordf64(*d)); + } + } + outputs +} + +fn quantile_f64(data: ArrayView1<f64>, q: &Vec<f64>, epsilon: f64) -> Vec<f64> { + let mut stream = greenwald_khanna::Stream::new(epsilon); + for d in data.into_iter() { + stream.insert(Ordf64(*d)) + } + println!("size is {}", stream.s()); + q.iter().map(|phi| stream.quantile(*phi).0).collect() +} + +#[pyfunction] +fn quantile_f64_ix1(data: PyReadonlyArray1<f64>, q: Vec<f64>, epsilon: f64) -> Vec<f64> { + 
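+    // Approximate quantiles of a 1-D array via a Greenwald-Khanna summary with error bound `epsilon`.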
quantile_f64(data.as_array(), &q, epsilon) +} + +#[pyfunction] +fn quantile_f64_ix2<'py>( + py: Python<'py>, + data: PyReadonlyArray2<f64>, + q: Vec<f64>, + epsilon: f64, +) -> &'py PyArray2<f64> { + let array = data.as_array(); + let mut a = Array::<f64, _>::zeros((q.len(), array.shape()[1]).f()); + for (j, mut col) in a.axis_iter_mut(Axis(1)).enumerate() { + let arr = array.index_axis(Axis(1), j); + let quantile_sub = quantile_f64(arr, &q, epsilon); + for (i, row) in col.iter_mut().enumerate() { + *row = quantile_sub[i]; + } + } + a.into_pyarray(py) +} + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule_quantile = PyModule::new(py, "quantile")?; + submodule_quantile.add_class::<QuantileSummaryStream>()?; + m.add_submodule(submodule_quantile)?; + submodule_quantile.add_function(wrap_pyfunction!(quantile_f64_ix1, submodule_quantile)?)?; + submodule_quantile.add_function(wrap_pyfunction!(quantile_f64_ix2, submodule_quantile)?)?; + submodule_quantile.add_function(wrap_pyfunction!(summary_f64_ix2, submodule_quantile)?)?; + py.import("sys")? + .getattr("modules")? + .set_item("fate_utils.quantile", submodule_quantile)?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs b/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs new file mode 100644 index 0000000000..ef670ff97e --- /dev/null +++ b/rust/fate_utils/crates/fate_utils/src/secure_aggregation_helper/mod.rs @@ -0,0 +1,216 @@ +use std::collections::HashMap; + +use ndarray; +use ndarray::prelude::*; +use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; +use pyo3::exceptions::{PyIndexError, PyTypeError}; +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use rand::distributions::Uniform; +use rand::Rng; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use rand_core::OsRng; +use x25519_dalek::{EphemeralSecret, PublicKey}; + +#[pyclass] +struct DiffieHellman { + private_key: Option<EphemeralSecret>, + public_key: PublicKey, +} + +#[pymethods] +impl DiffieHellman { + #[new] + fn new() -> Self { + let private_key = EphemeralSecret::new(OsRng); + let public_key = PublicKey::from(&private_key); + Self { + private_key: Some(private_key), + public_key, + } + } + + fn get_public_key(&self) -> &[u8] { + return self.public_key.as_bytes(); + } + pub fn diffie_hellman(&mut self, py: Python, other_public_key: &[u8]) -> PyResult<Py<PyBytes>> { + let private_key = match self.private_key.take() { + Some(key) => key, + None => return Err(PyTypeError::new_err("Private key not found")), + }; + + let other_public_key: [u8; 32] = match other_public_key.try_into() { + Ok(key) => key, + Err(_) => { + return Err(PyTypeError::new_err( + "slice with incorrect length, should be 32 bytes", + )) + } + }; + + let shared_secret = private_key.diffie_hellman(&PublicKey::from(other_public_key)); + Ok(PyBytes::new(py, shared_secret.as_bytes()).into()) + } +} + +#[pyclass] +struct RandomMixState { + rank: usize, + random_state: ChaCha20Rng, + index: usize, +} + +#[pyclass] +struct RandomMix { + rank: usize, + states: Vec<RandomMixState>, +} + +#[pymethods] +impl RandomMix { + #[new] + fn new(seeds: HashMap<usize, Vec<u8>>, rank: usize) -> PyResult<Self> { + let states: Result<Vec<_>, _> = seeds + .iter() + .map(|(k, v)| { + if k == &rank { + return Err(PyErr::new::<PyIndexError, _>( + "Seed should not contain the rank", + )); + } + let seed_arr = <&[u8; 32]>::try_from(&v[..]) + .map_err(|_| PyErr::new::<PyTypeError, _>("Seed should be a 32-byte array"))?; + let 
random_state = ChaCha20Rng::from_seed(*seed_arr); + Ok(RandomMixState { + rank: *k, + random_state, + index: 0, + }) + }) + .collect(); + match states { + Ok(states) => Ok(Self { rank, states }), + Err(e) => Err(e), + } + } + + fn mix_one( + &mut self, + py: Python, + input: PyReadonlyArrayDyn<f64>, + weight: Option<f64>, + ) -> (Py<PyArrayDyn<f64>>, Py<PyArrayDyn<f64>>) { + let (mut output_decimal_array, mut output_integer_array) = { + if let Some(w) = weight { + let input = input.as_array().map(|x| x * w); + (input.map(|x| x.fract()), input.map(|x| x.trunc())) + } else { + let input = input.as_array(); + (input.map(|x| x.fract()), input.map(|x| x.trunc())) + } + }; + let range = Uniform::new(-1e7f64, 1e7f64); + output_decimal_array + .iter_mut() + .zip(output_integer_array.iter_mut()) + .for_each(|(output_decimal, output_integer)| { + for state in self.states.iter_mut() { + let rand = state.random_state.sample(range); + state.index += 1; + if state.rank < self.rank { + *output_decimal += rand.fract(); + *output_integer += rand.trunc(); + } else { + *output_decimal -= rand.fract(); + *output_integer -= rand.trunc(); + } + } + }); + ( + output_decimal_array.into_pyarray(py).to_owned(), + output_integer_array.into_pyarray(py).to_owned(), + ) + } + fn mix( + &mut self, + py: Python, + inputs: Vec<PyReadonlyArrayDyn<f64>>, + weight: Option<f64>, + ) -> Vec<(Py<PyArrayDyn<f64>>, Py<PyArrayDyn<f64>>)> { + inputs + .into_iter() + .map(|input| self.mix_one(py, input, weight)) + .collect() + } + + fn get_index(&self, rank: usize) -> PyResult<usize> { + let state = self + .states + .iter() + .find(|state| state.rank == rank) + .ok_or(PyErr::new::<PyIndexError, _>(format!( + "Rank {} not found", + rank + )))?; + Ok(state.index) + } +} + +#[pyclass] +struct MixAggregate { + decimal_sum: Vec<ArrayD<f64>>, + integer_sum: Vec<ArrayD<f64>>, +} + +#[pymethods] +impl MixAggregate { + #[new] + fn new() -> Self { + Self { + decimal_sum: Vec::new(), + integer_sum: Vec::new(), + } + } + fn aggregate(&mut self, inputs: Vec<(PyReadonlyArrayDyn<f64>, PyReadonlyArrayDyn<f64>)>) { + inputs + .into_iter() + .enumerate() + .for_each(|(i, (decimal, integer))| { + if i >= self.decimal_sum.len() { + self.decimal_sum.push(decimal.as_array().to_owned()); + self.integer_sum.push(integer.as_array().to_owned()); + } else { + self.decimal_sum[i] += &decimal.as_array(); + self.integer_sum[i] += &integer.as_array(); + } + }); + } + fn finalize(&self, py: Python, weight: Option<f64>) -> Vec<Py<PyArrayDyn<f64>>> { + self.decimal_sum + .iter() + .zip(self.integer_sum.iter()) + .map(|(decimal, integer)| { + let mut output = decimal.clone(); + output += integer; + if let Some(w) = weight { + output /= w; + } + output.into_pyarray(py).to_owned() + }) + .collect() + } +} + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule_secure_aggregation_helper = PyModule::new(py, "secure_aggregation_helper")?; + submodule_secure_aggregation_helper.add_class::<RandomMix>()?; + submodule_secure_aggregation_helper.add_class::<MixAggregate>()?; + submodule_secure_aggregation_helper.add_class::<DiffieHellman>()?; + m.add_submodule(submodule_secure_aggregation_helper)?; + py.import("sys")?.getattr("modules")?.set_item( + "fate_utils.secure_aggregation_helper", + submodule_secure_aggregation_helper, + )?; + Ok(()) +} diff --git a/rust/fate_utils/crates/fixedpoint/Cargo.toml b/rust/fate_utils/crates/fixedpoint/Cargo.toml new file mode 100644 index 0000000000..27bbe31f14 --- /dev/null +++ 
b/rust/fate_utils/crates/fixedpoint/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "fixedpoint" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { workspace = true} +rug = { workspace = true } +math = { path = "../math" } +paillier = { path = "../paillier" } diff --git a/rust/tensor/rust_paillier/src/fixedpoint/coder.rs b/rust/fate_utils/crates/fixedpoint/src/coder.rs similarity index 80% rename from rust/tensor/rust_paillier/src/fixedpoint/coder.rs rename to rust/fate_utils/crates/fixedpoint/src/coder.rs index 0dc56520bf..60cc565fba 100644 --- a/rust/tensor/rust_paillier/src/fixedpoint/coder.rs +++ b/rust/fate_utils/crates/fixedpoint/src/coder.rs @@ -1,8 +1,9 @@ +use std::ops::{AddAssign, ShlAssign, SubAssign}; use super::frexp::Frexp; use super::PT; -use crate::math::BInt; use crate::paillier; -use rug::{self, ops::Pow}; +use math::BInt; +use rug::{self, Integer, ops::Pow}; use serde::{Deserialize, Serialize}; const FLOAT_MANTISSA_BITS: u32 = 53; @@ -11,7 +12,7 @@ const BASE: u32 = 16; const MAX_INT_FRACTION: u8 = 2; /// fixedpoint encoder -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct FixedpointCoder { pub n: BInt, pub max_int: BInt, @@ -35,6 +36,31 @@ impl FixedpointCoder { exp: 0, } } + pub fn pack(&self, plaintexts: &[u64], num_shift_bit: usize) -> PT { + let significant = plaintexts.iter().fold(Integer::default(), |mut x, v| { + x.shl_assign(num_shift_bit); + x.add_assign(v); + x + }); + PT { + significant: paillier::PT(BInt(significant)), + exp: 0, + } + } + pub fn unpack(&self, encoded: &PT, num_shift_bit: usize, num: usize) -> Vec<u64> { + let mut significant = encoded.significant.0.0.clone(); + let mut mask = Integer::from(1u64 << num_shift_bit); + mask.sub_assign(1); + + let mut result = Vec::with_capacity(num); + for _ in 0..num { + let value = Integer::from(significant.clone() & mask.clone()).to_u64().unwrap(); + result.push(value); + significant >>= num_shift_bit; + } + result.reverse(); + result + } pub fn encode_i32(&self, plaintext: i32) -> PT { let significant = paillier::PT(if plaintext < 0 { BInt::from(&self.n + plaintext) @@ -115,6 +141,7 @@ pub trait CouldCode { fn encode(&self, coder: &FixedpointCoder) -> PT; fn decode(pt: &PT, coder: &FixedpointCoder) -> Self; } + impl CouldCode for f64 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_f64(*self) @@ -132,6 +159,7 @@ impl CouldCode for i64 { coder.decode_i64(pt) } } + impl CouldCode for i32 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_i32(*self) @@ -140,6 +168,7 @@ impl CouldCode for i32 { coder.decode_i32(pt) } } + impl CouldCode for f32 { fn encode(&self, coder: &FixedpointCoder) -> PT { coder.encode_f32(*self) diff --git a/rust/tensor/rust_paillier/src/fixedpoint/frexp.rs b/rust/fate_utils/crates/fixedpoint/src/frexp.rs similarity index 100% rename from rust/tensor/rust_paillier/src/fixedpoint/frexp.rs rename to rust/fate_utils/crates/fixedpoint/src/frexp.rs diff --git a/rust/tensor/rust_paillier/src/fixedpoint/mod.rs b/rust/fate_utils/crates/fixedpoint/src/lib.rs similarity index 83% rename from rust/tensor/rust_paillier/src/fixedpoint/mod.rs rename to rust/fate_utils/crates/fixedpoint/src/lib.rs index abd5754754..4d4bd5ef9b 100644 --- a/rust/tensor/rust_paillier/src/fixedpoint/mod.rs +++ b/rust/fate_utils/crates/fixedpoint/src/lib.rs @@ -1,15 +1,16 @@ mod coder; mod frexp; 
-use crate::math::BInt; -use crate::paillier; -pub(crate) use coder::CouldCode; + +pub use coder::CouldCode; pub use coder::FixedpointCoder; +use math::BInt; +use paillier; use serde::{Deserialize, Serialize}; const BASE: u32 = 16; /// fixedpoint plaintext -#[derive(Debug)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct PT { pub significant: paillier::PT, pub exp: i32, @@ -17,13 +18,13 @@ pub struct PT { /// fixedpoint ciphertext /// raw paillier ciphertext represents encryped significant -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct CT { - significant_encryped: paillier::CT, - exp: i32, + pub significant_encryped: paillier::CT, + pub exp: i32, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PK { pub pk: paillier::PK, pub coder: coder::FixedpointCoder, @@ -35,7 +36,8 @@ impl PK { PK { pk, coder } } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct SK { pub sk: paillier::SK, pub coder: coder::FixedpointCoder, @@ -104,15 +106,38 @@ impl CT { let b = pk.encrypt(b, false); self.sub(&b, pk) } + /* + other - self + */ + pub fn rsub_pt(&self, b: &PT, pk: &PK) -> CT { + let b = pk.encrypt(b, false); + b.sub(self, pk) + } pub fn sub(&self, b: &CT, pk: &PK) -> CT { self.add(&b.neg(pk), pk) } + pub fn rsub(&self, b: &CT, pk: &PK) -> CT { + self.neg(pk).add(&b, pk) + } pub fn add_assign(&mut self, b: &CT, pk: &PK) { // FIXME *self = self.add(b, pk); } + pub fn i_double(&mut self, pk: &PK) { + self.significant_encryped.0 = self + .significant_encryped + .0 + .pow_mod_ref(&BInt::from(2), &pk.pk.ns); + } + pub fn add(&self, b: &CT, pk: &PK) -> CT { let a = self; + if a.significant_encryped.0.0 == 1 { + return b.clone(); + } + if b.significant_encryped.0.0 == 1 { + return a.clone(); + } if a.exp > b.exp { let a = &a.decrese_exp_to(b.exp, &pk.pk); CT { diff --git a/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml b/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml new file mode 100644 index 0000000000..da5809fb78 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "fixedpoint_ou" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { workspace = true} +rug = { workspace = true } +anyhow = { workspace = true } +math = { path = "../math" } +ou = { path = "../ou" } + +[dev-dependencies] +rand = { workspace = true } diff --git a/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs b/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs new file mode 100644 index 0000000000..019f28b900 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/src/frexp.rs @@ -0,0 +1,26 @@ +use std::os::raw::{c_double, c_float, c_int}; + +extern "C" { + fn frexp(x: c_double, exp: *mut c_int) -> c_double; + fn frexpf(x: c_float, exp: *mut c_int) -> c_float; +} + +pub trait Frexp: Sized { + fn frexp(self) -> (Self, i32); +} + +impl Frexp for f64 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexp(self, &mut exp) }; + (res, exp) + } +} + +impl Frexp for f32 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexpf(self, &mut exp) }; + (res, exp) + } +} diff --git a/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs 
b/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs new file mode 100644 index 0000000000..f57ffe43e1 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_ou/src/lib.rs @@ -0,0 +1,865 @@ +use math::BInt; +use ou; +use anyhow::Result; +use anyhow::anyhow; +use std::ops::{AddAssign, BitAnd, Mul, ShlAssign, SubAssign}; +use rug::{self, Integer, ops::Pow, Float, Rational}; +use serde::{Deserialize, Serialize}; + +mod frexp; + +// use frexp::Frexp; + +const BASE: u32 = 16; +// const MAX_INT_FRACTION: u8 = 2; +// const FLOAT_MANTISSA_BITS: u32 = 53; +const LOG2_BASE: u32 = 4; + +#[derive(Default, Serialize, Deserialize)] +pub struct PK { + pub pk: ou::PK, + // pub max_int: BInt, +} + +impl PK { + #[inline] + pub fn encrypt(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + let exp = plaintext.exp; + let encode = self.pk.encrypt(&plaintext.significant, obfuscate); + Ciphertext { + significant_encryped: encode, + exp, + } + } +} + +#[derive(Default, Serialize, Deserialize)] +pub struct SK { + pub sk: ou::SK, +} + +impl SK { + #[inline] + pub fn decrypt(&self, ciphertext: &Ciphertext) -> Plaintext { + let exp = ciphertext.exp; + Plaintext { + significant: self.sk.decrypt(&ciphertext.significant_encryped), + exp, + } + } +} + + +/// fixedpoint encoder +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Coder {} + +impl Coder { + pub fn new() -> Self { + Coder {} + } + + pub fn encode_u64(&self, plaintext: u64) -> Plaintext { + let significant = ou::PT(BInt::from(plaintext)); + Plaintext { + significant, + exp: 0, + } + } + pub fn pack_floats(&self, floats: &Vec<f64>, offset_bit: usize, pack_num: usize, precision: u32) -> Vec<Plaintext> { + let int_scale = Integer::from(2).pow(precision); + floats.chunks(pack_num).map(|data| { + let significant = data.iter().fold(Integer::default(), |mut x, v| { + x.shl_assign(offset_bit); + x.add_assign(Float::with_val(64, v).mul(&int_scale).round().to_integer().unwrap()); + x + }); + Plaintext { + significant: ou::PT(BInt(significant)), + exp: 0, + } + }) + .collect() + } + pub fn unpack_floats(&self, encoded: &[Plaintext], offset_bit: usize, pack_num: usize, precision: u32, expect_total_num: usize) -> Vec<f64> { + let int_scale = Integer::from(2).pow(precision); + let mut mask = Integer::from(1); + mask <<= offset_bit; + mask.sub_assign(1); + let mut result = Vec::with_capacity(expect_total_num); + let mut total_num = expect_total_num; + for x in encoded { + let n = std::cmp::min(total_num, pack_num); + let mut significant = x.significant.0.0.clone(); + let mut temp = Vec::with_capacity(n); + for _ in 0..n { + let value = Rational::from(((&significant).bitand(&mask), &int_scale)).to_f64(); + temp.push(value); + significant >>= offset_bit; + } + temp.reverse(); + result.extend(temp); + total_num -= n; + } + #[cfg(debug_assertions)] + assert_eq!(result.len(), expect_total_num); + + result + } + pub fn encode_u32(&self, plaintext: u32) -> Plaintext { + let significant = ou::PT( + BInt::from(plaintext) + ); + Plaintext { + significant, + exp: 0, + } + } + pub fn decode_u64(&self, encoded: &Plaintext) -> u64 { + let significant = encoded.significant.0.clone(); + let mantissa = significant; + (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_i128() as u64 + } + pub fn decode_u32(&self, encoded: &Plaintext) -> u32 { + // Todo: could be improved + self.decode_u64(encoded) as u32 + } + + // pub fn encode_f64(&self, plaintext: f64) -> Plaintext { + // let bin_flt_exponent = plaintext.frexp().1; + // let bin_lsb_exponent = 
bin_flt_exponent - (FLOAT_MANTISSA_BITS as i32); + // let exp = (bin_lsb_exponent as f64 / LOG2_BASE as f64).floor() as i32; + // let significant = BInt( + // (plaintext * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(-exp)) + // .round() + // .to_integer() + // .unwrap(), + // ); + // if significant.abs_ref() > self.max_int { + // panic!( + // "Integer needs to be within +/- {} but got {}", + // self.max_int.0, &significant.0 + // ) + // } + // Plaintext { + // significant: ou::PT(significant), + // exp, + // } + // } + // pub fn decode_f64(&self, encoded: &Plaintext) -> f64 { + // let significant = encoded.significant.0.clone(); + // let mantissa = if significant > self.n { + // panic!("Attempted to decode corrupted number") + // } else if significant <= self.max_int { + // significant + // } else if significant >= BInt::from(&self.n - &self.max_int) { + // significant - &self.n + // } else { + // format!("Overflow detected in decrypted number: {:?}", significant); + // panic!("Overflow detected in decrypted number") + // }; + // if encoded.exp >= 0 { + // (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_f64() + // } else { + // (mantissa * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(encoded.exp)).to_f64() + // } + // } + // pub fn encode_f32(&self, plaintext: f32) -> Plaintext { + // self.encode_f64(plaintext as f64) + // } + // pub fn decode_f32(&self, encoded: &Plaintext) -> f32 { + // self.decode_f64(encoded) as f32 + // } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Ciphertext { + pub significant_encryped: ou::CT, + pub exp: i32, +} + +impl Ciphertext { + pub fn zero() -> Ciphertext { + Ciphertext { + significant_encryped: ou::CT::zero(), + exp: 0, + } + } + fn decrese_exp_to(&self, exp: i32, pk: &ou::PK) -> Ciphertext { + assert!(exp < self.exp); + let factor = BInt::from(BASE).pow((self.exp - exp) as u32); + let significant_encryped = self.significant_encryped.mul_pt(&ou::PT(factor), pk); + Ciphertext { + significant_encryped, + exp, + } + } + pub fn neg(&self, pk: &PK) -> Ciphertext { + Ciphertext { + significant_encryped: ou::CT(self.significant_encryped.0.invert_ref(&pk.pk.n)), + exp: self.exp, + } + } + pub fn add_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.add(&b, pk) + } + pub fn sub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.sub(&b, pk) + } + /* + other - self + */ + pub fn rsub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + b.sub(self, pk) + } + pub fn sub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.add(&b.neg(pk), pk) + } + pub fn rsub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.neg(pk).add(&b, pk) + } + pub fn add_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.add(b, pk); + } + pub fn sub_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.sub(b, pk); + } + pub fn i_double(&mut self, pk: &PK) { + self.significant_encryped.0 = self + .significant_encryped + .0 + .pow_mod_ref(&BInt::from(2), &pk.pk.n); + } + + pub fn add(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + let a = self; + if a.significant_encryped.0.0 == 1 { + return b.clone(); + } + if b.significant_encryped.0.0 == 1 { + return a.clone(); + } + if a.exp > b.exp { + let a = &a.decrese_exp_to(b.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: b.exp, + } + } else if a.exp < b.exp { + let b = 
&b.decrese_exp_to(a.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } else { + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } + } + pub fn mul(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let inside = (&self.significant_encryped.0).pow_mod_ref(&b.significant.0, &pk.pk.n); + Ciphertext { + significant_encryped: ou::CT(inside), + exp: self.exp + b.exp, + } + } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct CiphertextVector { + pub data: Vec<Ciphertext>, +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Plaintext { + pub significant: ou::PT, + pub exp: i32, +} + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct PlaintextVector { + pub data: Vec<Plaintext>, +} + +impl PK { + pub fn encrypt_encoded( + &self, + plaintext: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + let data = plaintext + .data + .iter() + .map(|x| Ciphertext { significant_encryped: self.pk.encrypt(&x.significant, obfuscate), exp: x.exp }) + .collect(); + CiphertextVector { data } + } + pub fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext { + significant_encryped: self.pk.encrypt(&plaintext.significant, obfuscate), + exp: plaintext.exp, + } + } +} + + +impl SK { + pub fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + let data = data.data.iter().map(|x| Plaintext { + significant: + self.sk.decrypt(&x.significant_encryped), + exp: x.exp, + }).collect(); + PlaintextVector { data } + } + pub fn decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext { + significant: self.sk.decrypt(&data.significant_encryped), + exp: data.exp, + } + } +} + +pub fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk) = ou::keygen(bit_length); + let coder = Coder::new(); + // let max_int = &sk.p / MAX_INT_FRACTION; + (SK { sk }, PK { pk: pk }, coder) +} + +impl CiphertextVector { + #[inline] + fn iadd_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.add_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + #[inline] + fn isub_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.sub_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + pub fn zeros(size: usize) -> Self { + let data = vec![Ciphertext::zero(); size]; + CiphertextVector { data } + } + + pub fn pack_squeeze(&self, pk: &PK, pack_num: usize, shift_bit: u32) -> CiphertextVector { + let base = BInt::from(2).pow(shift_bit); + let data = self.data.chunks(pack_num).map(|x| { + let mut result = x[0].significant_encryped.0.clone(); + for y in &x[1..] 
{ + result.pow_mod_mut(&base, &pk.pk.n); + result = result.mul(&y.significant_encryped.0) % &pk.pk.n; + } + Ciphertext { significant_encryped: ou::CT(result), exp: 0 } + }).collect(); + CiphertextVector { data } + } + + pub fn slice(&mut self, start: usize, size: usize) -> CiphertextVector { + let data = self.data[start..start + size].to_vec(); + CiphertextVector { data } + } + + pub fn slice_indexes(&mut self, indexes: Vec<usize>) -> Self { + let data = indexes + .iter() + .map(|i| self.data[*i].clone()) + .collect::<Vec<_>>(); + CiphertextVector { data } + } + + pub fn cat(&self, others: Vec<&CiphertextVector>) -> Self { + let mut data = self.data.clone(); + for other in others { + data.extend(other.data.clone()); + } + CiphertextVector { data } + } + + pub fn i_shuffle(&mut self, indexes: Vec<usize>) { + let mut visited = vec![false; self.data.len()]; + for i in 0..self.data.len() { + if visited[i] || indexes[i] == i { + continue; + } + + let mut current = i; + let mut next = indexes[current]; + while !visited[next] && next != i { + self.data.swap(current, next); + visited[current] = true; + current = next; + next = indexes[current]; + } + visited[current] = true; + } + } + + pub fn shuffle(&self, indexes: Vec<usize>) -> Self { + let data = self.data.clone(); + let mut result = CiphertextVector { data }; + result.i_shuffle(indexes); + result + } + + pub fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> Result<Self> { + let mut data = vec![]; + for (start, end) in intervals { + if end > self.data.len() { + return Err(anyhow!( + "end index out of range: start={}, end={}, data_size={}", + start, + end, + self.data.len() + )); + } + data.extend_from_slice(&self.data[start..end]); + } + Ok(CiphertextVector { data }) + } + + pub fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec<&Ciphertext>) { + for (i, x) in other.iter().enumerate() { + self.data[position + i] = self.data[position + i].add(&x, &pk); + } + } + + pub fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| x.i_double(&pk)); + } else { + self.data[sa..].iter_mut().for_each(|x| x.i_double(&pk)); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sb, sa, s); + } else { + self.iadd_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sa, sb, s); + } else { + self.iadd_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + pub fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| *x = Ciphertext::zero()); + } else { + self.data[sa..].iter_mut().for_each(|x| *x = 
Ciphertext::zero()); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sb, sa, s); + } else { + self.isub_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sa, sb, s); + } else { + self.isub_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + + pub fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.add_assign(y, &pk) + }); + } + None => { + self.data[sa..] + .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.sub_assign(y, &pk) + }); + } + None => { + self.data[sa..] 
+ .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.sub_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, stride: usize, pk: &PK) -> Result<()> { + for (i, x) in indexes.iter().enumerate() { + let sb = i * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + pub fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, masks: Vec<bool>, stride: usize, pk: &PK) -> Result<()> { + for (value_pos, x) in masks.iter().enumerate().filter(|(_, &mask)| mask).map(|(i, _)| i).zip(indexes.iter()) { + let sb = value_pos * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + + pub fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.data + .iter_mut() + .zip(other.data.iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + + pub fn idouble(&mut self, pk: &PK) { + // TODO: fix me, remove clone + self.data + .iter_mut() + .for_each(|x| x.add_assign(&x.clone(), &pk)); + } + + pub fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec<usize>, step: usize) { + let mut placeholder = Ciphertext::zero(); + let mut i = 0; + for chunk_size in chunk_sizes { + for j in step..chunk_size { + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + placeholder.add_assign(&self.data[i + j - step], &pk); + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + } + i += chunk_size; + } + } + + pub fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); intervals.len() * step]; + for (i, (s, e)) in intervals.iter().enumerate() { + let chunk = &mut data[i * step..(i + 1) * step]; + let sub_vec = &self.data[*s..*e]; + for (val, c) in sub_vec.iter().zip((0..step).cycle()) { + chunk[c].add_assign(val, &pk); + } + } + CiphertextVector { data } + } + + pub fn tolist(&self) -> Vec<CiphertextVector> { + self.data.iter().map(|x| CiphertextVector { data: vec![x.clone()] }).collect() + } + + pub fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.add(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.add(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.sub(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.sub(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| y.sub(x, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn rsub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| other.sub(x, &pk)).collect(); + CiphertextVector { data } + } + + pub fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector 
{ + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.mul(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + let data = self + .data + .iter() + .map(|x| x.mul(&other, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn matmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); lshape[0] * rshape[1]]; + for i in 0..lshape[0] { + for j in 0..rshape[1] { + for k in 0..lshape[1] { + data[i * rshape[1] + j].add_assign( + &self.data[i * lshape[1] + k].mul(&other.data[k * rshape[1] + j], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } + + pub fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + // rshape, lshape -> rshape[0] x lshape[1] + // other, self + // 4 x 2, 2 x 5 + // ik, kj -> ij + let mut data = vec![Ciphertext::zero(); lshape[1] * rshape[0]]; + for i in 0..rshape[0] { + // 4 + for j in 0..lshape[1] { + // 5 + for k in 0..rshape[1] { + // 2 + data[i * lshape[1] + j].add_assign( + &self.data[k * lshape[1] + j].mul(&other.data[i * rshape[1] + k], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } +} + +impl PlaintextVector { + pub fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + let start = index * stride; + let end = start + stride; + let data = self.data[start..end].to_vec(); + PlaintextVector { data } + } + pub fn tolist(&self) -> Vec<Plaintext> { + self.data + .iter() + .map(|x| x.clone()) + .collect() + } +} + +#[test] +fn test_decrypt() { + let (sk, pk, coder) = keygen(1024); + let mut data = vec![0.5, -0.5]; + let encoded = PlaintextVector { data: data.iter().map(|x| coder.encode_f64(*x)).collect() }; + let encrypted = pk.encrypt_encoded(&encoded, false); + let decrypted = sk.decrypt_to_encoded(&encrypted); + let decoded = decrypted.data.iter().map(|x| coder.decode_f64(x)).collect::<Vec<_>>(); + assert_eq!(data, decoded); +} \ No newline at end of file diff --git a/rust/fate_utils/crates/fixedpoint_paillier/Cargo.toml b/rust/fate_utils/crates/fixedpoint_paillier/Cargo.toml new file mode 100644 index 0000000000..793d836fe5 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_paillier/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "fixedpoint_paillier" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { workspace = true} +rug = { workspace = true } +anyhow = { workspace = true } +math = { path = "../math" } +paillier = { path = "../paillier" } diff --git a/rust/fate_utils/crates/fixedpoint_paillier/src/frexp.rs b/rust/fate_utils/crates/fixedpoint_paillier/src/frexp.rs new file mode 100644 index 0000000000..019f28b900 --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_paillier/src/frexp.rs @@ -0,0 +1,26 @@ +use std::os::raw::{c_double, c_float, c_int}; + +extern "C" { + fn frexp(x: c_double, exp: *mut c_int) -> c_double; + fn frexpf(x: c_float, exp: *mut c_int) -> c_float; +} + +pub trait Frexp: Sized { + fn frexp(self) -> (Self, i32); +} + +impl Frexp for f64 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexp(self, &mut exp) }; + (res, exp) + } +} + +impl Frexp for f32 { + fn frexp(self) -> (Self, i32) { + let mut exp: c_int = 0; + let res = unsafe { frexpf(self, &mut exp) 
}; + (res, exp) + } +} diff --git a/rust/fate_utils/crates/fixedpoint_paillier/src/lib.rs b/rust/fate_utils/crates/fixedpoint_paillier/src/lib.rs new file mode 100644 index 0000000000..ec8c9794ec --- /dev/null +++ b/rust/fate_utils/crates/fixedpoint_paillier/src/lib.rs @@ -0,0 +1,924 @@ +use math::BInt; +use paillier; +use anyhow::Result; +use anyhow::anyhow; +use std::ops::{AddAssign, BitAnd, Mul, ShlAssign, SubAssign}; +use rug::{self, Integer, ops::Pow, Float, Rational}; +use serde::{Deserialize, Serialize}; + +mod frexp; + +use frexp::Frexp; + +const BASE: u32 = 16; +const MAX_INT_FRACTION: u8 = 2; +const FLOAT_MANTISSA_BITS: u32 = 53; +const LOG2_BASE: u32 = 4; + +#[derive(Default, Serialize, Deserialize)] +pub struct PK { + pub pk: paillier::PK, + pub max_int: BInt, +} + +impl PK { + #[inline] + pub fn encrypt(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + let exp = plaintext.exp; + let encode = self.pk.encrypt(&plaintext.significant, obfuscate); + Ciphertext { + significant_encryped: encode, + exp, + } + } +} + +#[derive(Default, Serialize, Deserialize)] +pub struct SK { + pub sk: paillier::SK, +} + +impl SK { + #[inline] + pub fn decrypt(&self, ciphertext: &Ciphertext) -> Plaintext { + let exp = ciphertext.exp; + Plaintext { + significant: self.sk.decrypt(&ciphertext.significant_encryped), + exp, + } + } +} + + +/// fixedpoint encoder +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Coder { + pub n: BInt, + pub max_int: BInt, +} + +impl Coder { + pub fn new(n: &BInt) -> Self { + Coder { + n: n.clone(), + max_int: n / MAX_INT_FRACTION, + } + } + + pub fn encode_i64(&self, plaintext: i64) -> Plaintext { + let significant = paillier::PT(if plaintext < 0 { + BInt::from(&self.n + plaintext) + } else { + BInt::from(plaintext) + }); + Plaintext { + significant, + exp: 0, + } + } + pub fn pack_floats(&self, floats: &Vec<f64>, offset_bit: usize, pack_num: usize, precision: u32) -> Vec<Plaintext> { + let int_scale = Integer::from(2).pow(precision); + floats.chunks(pack_num).map(|data| { + let significant = data.iter().fold(Integer::default(), |mut x, v| { + x.shl_assign(offset_bit); + x.add_assign(Float::with_val(64, v).mul(&int_scale).round().to_integer().unwrap()); + x + }); + Plaintext { + significant: paillier::PT(BInt(significant)), + exp: 0, + } + }) + .collect() + } + pub fn unpack_floats(&self, encoded: &[Plaintext], offset_bit: usize, pack_num: usize, precision: u32, expect_total_num: usize) -> Vec<f64> { + let int_scale = Integer::from(2).pow(precision); + let mut mask = Integer::from(1); + mask <<= offset_bit; + mask.sub_assign(1); + let mut result = Vec::with_capacity(expect_total_num); + let mut total_num = expect_total_num; + for x in encoded { + let n = std::cmp::min(total_num, pack_num); + let mut significant = x.significant.0.0.clone(); + let mut temp = Vec::with_capacity(n); + for _ in 0..n { + let value = Rational::from(((&significant).bitand(&mask), &int_scale)).to_f64(); + temp.push(value); + significant >>= offset_bit; + } + temp.reverse(); + result.extend(temp); + total_num -= n; + } + #[cfg(debug_assertions)] + assert_eq!(result.len(), expect_total_num); + + result + } + pub fn encode_i32(&self, plaintext: i32) -> Plaintext { + let significant = paillier::PT(if plaintext < 0 { + BInt::from(&self.n + plaintext) + } else { + BInt::from(plaintext) + }); + Plaintext { + significant, + exp: 0, + } + } + pub fn decode_i64(&self, encoded: &Plaintext) -> i64 { + let significant = encoded.significant.0.clone(); + let mantissa = 
if significant > self.n { + panic!("Attempted to decode corrupted number") + } else if significant <= self.max_int { + significant + } else if significant >= BInt::from(&self.n - &self.max_int) { + significant - &self.n + } else { + panic!("Overflow detected in decrypted number") + }; + (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_i128() as i64 + } + pub fn decode_i32(&self, encoded: &Plaintext) -> i32 { + // Todo: could be improved + self.decode_f64(encoded) as i32 + } + + pub fn encode_f64(&self, plaintext: f64) -> Plaintext { + let bin_flt_exponent = plaintext.frexp().1; + let bin_lsb_exponent = bin_flt_exponent - (FLOAT_MANTISSA_BITS as i32); + let exp = (bin_lsb_exponent as f64 / LOG2_BASE as f64).floor() as i32; + let significant = BInt( + (plaintext * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(-exp)) + .round() + .to_integer() + .unwrap(), + ); + if significant.abs_ref() > self.max_int { + panic!( + "Integer needs to be within +/- {} but got {}", + self.max_int.0, &significant.0 + ) + } + Plaintext { + significant: paillier::PT(significant), + exp, + } + } + pub fn decode_f64(&self, encoded: &Plaintext) -> f64 { + let significant = encoded.significant.0.clone(); + let mantissa = if significant > self.n { + panic!("Attempted to decode corrupted number") + } else if significant <= self.max_int { + significant + } else if significant >= BInt::from(&self.n - &self.max_int) { + significant - &self.n + } else { + format!("Overflow detected in decrypted number: {:?}", significant); + panic!("Overflow detected in decrypted number") + }; + if encoded.exp >= 0 { + (mantissa << (LOG2_BASE as i32 * encoded.exp)).to_f64() + } else { + (mantissa * rug::Float::with_val(FLOAT_MANTISSA_BITS, BASE).pow(encoded.exp)).to_f64() + } + } + pub fn encode_f32(&self, plaintext: f32) -> Plaintext { + self.encode_f64(plaintext as f64) + } + pub fn decode_f32(&self, encoded: &Plaintext) -> f32 { + self.decode_f64(encoded) as f32 + } +} + +pub trait CouldCode { + fn encode(&self, coder: &Coder) -> Plaintext; + fn decode(plaintext: &Plaintext, coder: &Coder) -> Self; +} + +impl CouldCode for f64 { + fn encode(&self, coder: &Coder) -> Plaintext { + coder.encode_f64(*self) + } + fn decode(plaintext: &Plaintext, coder: &Coder) -> Self { + coder.decode_f64(plaintext) + } +} + +impl CouldCode for i64 { + fn encode(&self, coder: &Coder) -> Plaintext { + coder.encode_i64(*self) + } + fn decode(plaintext: &Plaintext, coder: &Coder) -> Self { + coder.decode_i64(plaintext) + } +} + +impl CouldCode for i32 { + fn encode(&self, coder: &Coder) -> Plaintext { + coder.encode_i32(*self) + } + fn decode(plaintext: &Plaintext, coder: &Coder) -> Self { + coder.decode_i32(plaintext) + } +} + +impl CouldCode for f32 { + fn encode(&self, coder: &Coder) -> Plaintext { + coder.encode_f32(*self) + } + fn decode(plaintext: &Plaintext, coder: &Coder) -> Self { + coder.decode_f32(plaintext) + } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Ciphertext { + pub significant_encryped: paillier::CT, + pub exp: i32, +} + +impl Ciphertext { + pub fn zero() -> Ciphertext { + Ciphertext { + significant_encryped: paillier::CT::zero(), + exp: 0, + } + } + fn decrese_exp_to(&self, exp: i32, pk: &paillier::PK) -> Ciphertext { + assert!(exp < self.exp); + let factor = BInt::from(BASE).pow((self.exp - exp) as u32); + let significant_encryped = self.significant_encryped.mul_pt(&paillier::PT(factor), pk); + Ciphertext { + significant_encryped, + exp, + } + } + pub fn neg(&self, pk: &PK) -> Ciphertext { + 
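// (annotation, not part of the patch) Homomorphic negation: in Paillier, Enc(m)^{-1} mod n^2
// is a valid encryption of -m (mod n), so inverting the raw ciphertext below negates the
// encoded significant while the fixed-point exponent is left unchanged.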
Ciphertext { + significant_encryped: paillier::CT(self.significant_encryped.0.invert_ref(&pk.pk.ns)), + exp: self.exp, + } + } + pub fn add_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.add(&b, pk) + } + pub fn sub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + self.sub(&b, pk) + } + /* + other - self + */ + pub fn rsub_pt(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let b = pk.encrypt(b, false); + b.sub(self, pk) + } + pub fn sub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.add(&b.neg(pk), pk) + } + pub fn rsub(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + self.neg(pk).add(&b, pk) + } + pub fn add_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.add(b, pk); + } + pub fn sub_assign(&mut self, b: &Ciphertext, pk: &PK) { + // FIXME + *self = self.sub(b, pk); + } + pub fn i_double(&mut self, pk: &PK) { + self.significant_encryped.0 = self + .significant_encryped + .0 + .pow_mod_ref(&BInt::from(2), &pk.pk.ns); + } + + pub fn add(&self, b: &Ciphertext, pk: &PK) -> Ciphertext { + let a = self; + if a.significant_encryped.0.0 == 1 { + return b.clone(); + } + if b.significant_encryped.0.0 == 1 { + return a.clone(); + } + if a.exp > b.exp { + let a = &a.decrese_exp_to(b.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: b.exp, + } + } else if a.exp < b.exp { + let b = &b.decrese_exp_to(a.exp, &pk.pk); + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } else { + Ciphertext { + significant_encryped: a + .significant_encryped + .add_ct(&b.significant_encryped, &pk.pk), + exp: a.exp, + } + } + } + pub fn mul(&self, b: &Plaintext, pk: &PK) -> Ciphertext { + let inside = if &pk.pk.n - &pk.max_int <= b.significant.0 { + // large plaintext + let neg_c = self.significant_encryped.0.invert_ref(&pk.pk.ns); + let neg_scalar = &pk.pk.n - &b.significant.0; + neg_c.pow_mod_ref(&neg_scalar, &pk.pk.ns) + } else if b.significant.0 <= pk.max_int { + (&self.significant_encryped.0).pow_mod_ref(&b.significant.0, &pk.pk.ns) + } else { + panic!("invalid plaintext: {:?}", b) + }; + Ciphertext { + significant_encryped: paillier::CT(inside), + exp: self.exp + b.exp, + } + } +} + + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct CiphertextVector { + pub data: Vec<Ciphertext>, +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct Plaintext { + pub significant: paillier::PT, + pub exp: i32, +} + +#[derive(Default, Debug, Serialize, Deserialize)] +pub struct PlaintextVector { + pub data: Vec<Plaintext>, +} + +impl PK { + pub fn encrypt_encoded( + &self, + plaintext: &PlaintextVector, + obfuscate: bool, + ) -> CiphertextVector { + let data = plaintext + .data + .iter() + .map(|x| Ciphertext { significant_encryped: self.pk.encrypt(&x.significant, obfuscate), exp: x.exp }) + .collect(); + CiphertextVector { data } + } + pub fn encrypt_encoded_scalar(&self, plaintext: &Plaintext, obfuscate: bool) -> Ciphertext { + Ciphertext { + significant_encryped: self.pk.encrypt(&plaintext.significant, obfuscate), + exp: plaintext.exp, + } + } +} + + +impl SK { + pub fn decrypt_to_encoded(&self, data: &CiphertextVector) -> PlaintextVector { + let data = data.data.iter().map(|x| Plaintext { + significant: + self.sk.decrypt(&x.significant_encryped), + exp: x.exp, + }).collect(); + PlaintextVector { data } + } + pub fn 
decrypt_to_encoded_scalar(&self, data: &Ciphertext) -> Plaintext { + Plaintext { + significant: self.sk.decrypt(&data.significant_encryped), + exp: data.exp, + } + } +} + +pub fn keygen(bit_length: u32) -> (SK, PK, Coder) { + let (sk, pk) = paillier::keygen(bit_length); + let coder = Coder::new(&pk.n); + let max_int = &pk.n / MAX_INT_FRACTION; + (SK { sk }, PK { pk: pk, max_int: max_int }, coder) +} + +impl CiphertextVector { + #[inline] + fn iadd_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.add_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + #[inline] + fn isub_i_j(&mut self, pk: &PK, i: usize, j: usize, size: usize) { + let mut placeholder = Ciphertext::default(); + for k in 0..size { + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + placeholder.sub_assign(&self.data[j + k], &pk); + placeholder = std::mem::replace(&mut self.data[i + k], placeholder); + } + } + pub fn zeros(size: usize) -> Self { + let data = vec![Ciphertext::zero(); size]; + CiphertextVector { data } + } + + pub fn pack_squeeze(&self, pk: &PK, pack_num: usize, shift_bit: u32) -> CiphertextVector { + let base = BInt::from(2).pow(shift_bit); + let data = self.data.chunks(pack_num).map(|x| { + let mut result = x[0].significant_encryped.0.clone(); + for y in &x[1..] { + result.pow_mod_mut(&base, &pk.pk.ns); + result = result.mul(&y.significant_encryped.0) % &pk.pk.ns; + } + Ciphertext { significant_encryped: paillier::CT(result), exp: 0 } + }).collect(); + CiphertextVector { data } + } + + pub fn slice(&mut self, start: usize, size: usize) -> CiphertextVector { + let data = self.data[start..start + size].to_vec(); + CiphertextVector { data } + } + + pub fn slice_indexes(&mut self, indexes: Vec<usize>) -> Self { + let data = indexes + .iter() + .map(|i| self.data[*i].clone()) + .collect::<Vec<_>>(); + CiphertextVector { data } + } + + pub fn cat(&self, others: Vec<&CiphertextVector>) -> Self { + let mut data = self.data.clone(); + for other in others { + data.extend(other.data.clone()); + } + CiphertextVector { data } + } + + pub fn i_shuffle(&mut self, indexes: Vec<usize>) { + let mut visited = vec![false; self.data.len()]; + for i in 0..self.data.len() { + if visited[i] || indexes[i] == i { + continue; + } + + let mut current = i; + let mut next = indexes[current]; + while !visited[next] && next != i { + self.data.swap(current, next); + visited[current] = true; + current = next; + next = indexes[current]; + } + visited[current] = true; + } + } + + pub fn shuffle(&self, indexes: Vec<usize>) -> Self { + let data = self.data.clone(); + let mut result = CiphertextVector { data }; + result.i_shuffle(indexes); + result + } + + pub fn intervals_slice(&mut self, intervals: Vec<(usize, usize)>) -> Result<Self> { + let mut data = vec![]; + for (start, end) in intervals { + if end > self.data.len() { + return Err(anyhow!( + "end index out of range: start={}, end={}, data_size={}", + start, + end, + self.data.len() + )); + } + data.extend_from_slice(&self.data[start..end]); + } + Ok(CiphertextVector { data }) + } + + pub fn iadd_slice(&mut self, pk: &PK, position: usize, other: Vec<&Ciphertext>) { + for (i, x) in other.iter().enumerate() { + self.data[position + i] = self.data[position + i].add(&x, &pk); + } + } + + pub fn iadd_vec_self( + &mut self, + sa: usize, + sb: usize, + size: 
Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| x.i_double(&pk)); + } else { + self.data[sa..].iter_mut().for_each(|x| x.i_double(&pk)); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sb, sa, s); + } else { + self.iadd_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.iadd_i_j(&pk, sa, sb, s); + } else { + self.iadd_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + pub fn isub_vec_self( + &mut self, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + if sa == sb { + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.data[sa..sa + s] + .iter_mut() + .for_each(|x| *x = Ciphertext::zero()); + } else { + self.data[sa..].iter_mut().for_each(|x| *x = Ciphertext::zero()); + } + } else if sa < sb { + // it's safe to update from left to right + if let Some(s) = size { + if sb + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, s={}, data_size={}", + sb, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sb, sa, s); + } else { + self.isub_i_j(&pk, sb, sa, self.data.len() - sb); + } + } else { + // it's safe to update from right to left + if let Some(s) = size { + if sa + s > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, s={}, data_size={}", + sa, + s, + self.data.len() + )); + } + self.isub_i_j(&pk, sa, sb, s); + } else { + self.isub_i_j(&pk, sa, sb, self.data.len() - sa); + } + } + Ok(()) + } + + pub fn iadd_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.add_assign(y, &pk) + }); + } + None => { + self.data[sa..] 
+ .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn isub_vec( + &mut self, + other: &CiphertextVector, + sa: usize, + sb: usize, + size: Option<usize>, + pk: &PK, + ) -> Result<()> { + match size { + Some(s) => { + let ea = sa + s; + let eb = sb + s; + if ea > self.data.len() { + return Err(anyhow!( + "end index out of range: sa={}, ea={}, data_size={}", + sa, + ea, + self.data.len() + )); + } + if eb > other.data.len() { + return Err(anyhow!( + "end index out of range: sb={}, eb={}, data_size={}", + sb, + eb, + other.data.len() + )); + } + self.data[sa..ea] + .iter_mut() + .zip(other.data[sb..eb].iter()) + .for_each(|(x, y)| { + x.sub_assign(y, &pk) + }); + } + None => { + self.data[sa..] + .iter_mut() + .zip(other.data[sb..].iter()) + .for_each(|(x, y)| x.sub_assign(y, &pk)); + } + }; + Ok(()) + } + + pub fn iupdate(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, stride: usize, pk: &PK) -> Result<()> { + for (i, x) in indexes.iter().enumerate() { + let sb = i * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + pub fn iupdate_with_masks(&mut self, other: &CiphertextVector, indexes: Vec<Vec<usize>>, masks: Vec<bool>, stride: usize, pk: &PK) -> Result<()> { + for (value_pos, x) in masks.iter().enumerate().filter(|(_, &mask)| mask).map(|(i, _)| i).zip(indexes.iter()) { + let sb = value_pos * stride; + for pos in x.iter() { + let sa = pos * stride; + for i in 0..stride { + self.data[sa + i].add_assign(&other.data[sb + i], &pk); + } + } + } + Ok(()) + } + + pub fn iadd(&mut self, pk: &PK, other: &CiphertextVector) { + self.data + .iter_mut() + .zip(other.data.iter()) + .for_each(|(x, y)| x.add_assign(y, &pk)); + } + + pub fn idouble(&mut self, pk: &PK) { + // TODO: fix me, remove clone + self.data + .iter_mut() + .for_each(|x| x.add_assign(&x.clone(), &pk)); + } + + pub fn chunking_cumsum_with_step(&mut self, pk: &PK, chunk_sizes: Vec<usize>, step: usize) { + let mut placeholder = Ciphertext::zero(); + let mut i = 0; + for chunk_size in chunk_sizes { + for j in step..chunk_size { + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + placeholder.add_assign(&self.data[i + j - step], &pk); + placeholder = std::mem::replace(&mut self.data[i + j], placeholder); + } + i += chunk_size; + } + } + + pub fn intervals_sum_with_step( + &mut self, + pk: &PK, + intervals: Vec<(usize, usize)>, + step: usize, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); intervals.len() * step]; + for (i, (s, e)) in intervals.iter().enumerate() { + let chunk = &mut data[i * step..(i + 1) * step]; + let sub_vec = &self.data[*s..*e]; + for (val, c) in sub_vec.iter().zip((0..step).cycle()) { + chunk[c].add_assign(val, &pk); + } + } + CiphertextVector { data } + } + + pub fn tolist(&self) -> Vec<CiphertextVector> { + self.data.iter().map(|x| CiphertextVector { data: vec![x.clone()] }).collect() + } + + pub fn add(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.add(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn add_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.add(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn sub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data 
+ .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.sub(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn sub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| x.sub(&other, &pk)).collect(); + CiphertextVector { data } + } + + pub fn rsub(&self, pk: &PK, other: &CiphertextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| y.sub(x, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn rsub_scalar(&self, pk: &PK, other: &Ciphertext) -> CiphertextVector { + let data = self.data.iter().map(|x| other.sub(x, &pk)).collect(); + CiphertextVector { data } + } + + pub fn mul(&self, pk: &PK, other: &PlaintextVector) -> CiphertextVector { + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(x, y)| x.mul(y, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn mul_scalar(&self, pk: &PK, other: &Plaintext) -> CiphertextVector { + let data = self + .data + .iter() + .map(|x| x.mul(&other, &pk)) + .collect(); + CiphertextVector { data } + } + + pub fn matmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + let mut data = vec![Ciphertext::zero(); lshape[0] * rshape[1]]; + for i in 0..lshape[0] { + for j in 0..rshape[1] { + for k in 0..lshape[1] { + data[i * rshape[1] + j].add_assign( + &self.data[i * lshape[1] + k].mul(&other.data[k * rshape[1] + j], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } + + pub fn rmatmul( + &self, + pk: &PK, + other: &PlaintextVector, + lshape: Vec<usize>, + rshape: Vec<usize>, + ) -> CiphertextVector { + // rshape, lshape -> rshape[0] x lshape[1] + // other, self + // 4 x 2, 2 x 5 + // ik, kj -> ij + let mut data = vec![Ciphertext::zero(); lshape[1] * rshape[0]]; + for i in 0..rshape[0] { + // 4 + for j in 0..lshape[1] { + // 5 + for k in 0..rshape[1] { + // 2 + data[i * lshape[1] + j].add_assign( + &self.data[k * lshape[1] + j].mul(&other.data[i * rshape[1] + k], &pk), + &pk, + ); + } + } + } + CiphertextVector { data } + } +} + +impl PlaintextVector { + pub fn get_stride(&mut self, index: usize, stride: usize) -> PlaintextVector { + let start = index * stride; + let end = start + stride; + let data = self.data[start..end].to_vec(); + PlaintextVector { data } + } + pub fn tolist(&self) -> Vec<Plaintext> { + self.data + .iter() + .map(|x| x.clone()) + .collect() + } +} \ No newline at end of file diff --git a/rust/fate_utils/crates/math/Cargo.toml b/rust/fate_utils/crates/math/Cargo.toml new file mode 100644 index 0000000000..906a58ef07 --- /dev/null +++ b/rust/fate_utils/crates/math/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "math" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rug = { workspace = true } +rand = { workspace = true } +rand_core = { workspace = true } +serde = { workspace = true } + +[features] +default = ["rug"] +rug = [] diff --git a/rust/tensor/rust_paillier/src/math/mod.rs b/rust/fate_utils/crates/math/src/lib.rs similarity index 89% rename from rust/tensor/rust_paillier/src/math/mod.rs rename to rust/fate_utils/crates/math/src/lib.rs index 1341fb6ba6..0fc81e62a1 100644 --- a/rust/tensor/rust_paillier/src/math/mod.rs +++ b/rust/fate_utils/crates/math/src/lib.rs @@ -10,4 +10,4 @@ mod rug; pub use self::rug::BInt; #[cfg(feature = "rug")] -pub(crate) use self::rug::ONE; +pub use self::rug::ONE; diff --git 
a/rust/tensor/rust_paillier/src/math/rug/mod.rs b/rust/fate_utils/crates/math/src/rug/mod.rs similarity index 81% rename from rust/tensor/rust_paillier/src/math/rug/mod.rs rename to rust/fate_utils/crates/math/src/rug/mod.rs index 2db58dedc7..c0251dac81 100644 --- a/rust/tensor/rust_paillier/src/math/rug/mod.rs +++ b/rust/fate_utils/crates/math/src/rug/mod.rs @@ -1,14 +1,16 @@ mod ops; mod random; -mod serde; +// mod serde; use core::cmp::{PartialEq, PartialOrd}; use rug::Integer; use rug::{self, ops::Pow}; +use serde::{Serialize, Deserialize}; + /// newtype of rug::Integer -#[derive(PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Debug)] +#[derive(Default, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Debug, Serialize, Deserialize)] pub struct BInt(pub Integer); -pub(crate) const ONE: u8 = 1u8; +pub const ONE: u8 = 1u8; impl BInt { pub fn from_str_radix(src: &str, radix: i32) -> BInt { @@ -17,8 +19,8 @@ impl BInt { pub fn significant_bits(&self) -> u32 { self.0.significant_bits() } - pub fn pow_mod(self, exp: &BInt, modulo: &BInt) -> BInt { - BInt(self.0.pow_mod(&exp.0, &modulo.0).unwrap()) + pub fn pow_mod_mut(&mut self, exp: &BInt, modulo: &BInt) { + self.0.pow_mod_mut(&exp.0, &modulo.0).unwrap(); } pub fn pow_mod_ref(&self, exp: &BInt, modulo: &BInt) -> BInt { BInt(Integer::from( diff --git a/rust/tensor/rust_paillier/src/math/rug/ops.rs b/rust/fate_utils/crates/math/src/rug/ops.rs similarity index 100% rename from rust/tensor/rust_paillier/src/math/rug/ops.rs rename to rust/fate_utils/crates/math/src/rug/ops.rs diff --git a/rust/tensor/rust_paillier/src/math/rug/random.rs b/rust/fate_utils/crates/math/src/rug/random.rs similarity index 96% rename from rust/tensor/rust_paillier/src/math/rug/random.rs rename to rust/fate_utils/crates/math/src/rug/random.rs index 1d78d00c4a..7b5d18817e 100644 --- a/rust/tensor/rust_paillier/src/math/rug/random.rs +++ b/rust/fate_utils/crates/math/src/rug/random.rs @@ -1,6 +1,7 @@ use super::BInt; use rand::rngs::StdRng; -use rand_core::{RngCore, SeedableRng}; +use rand::RngCore; +use rand::SeedableRng; use rug::rand::{RandGen, RandState}; use rug::Integer; pub(crate) struct StdRngGen(StdRng); diff --git a/rust/tensor/rust_paillier/src/math/rug/serde.rs b/rust/fate_utils/crates/math/src/rug/serde.rs similarity index 100% rename from rust/tensor/rust_paillier/src/math/rug/serde.rs rename to rust/fate_utils/crates/math/src/rug/serde.rs diff --git a/rust/fate_utils/crates/ou/Cargo.toml b/rust/fate_utils/crates/ou/Cargo.toml new file mode 100644 index 0000000000..121b7ab678 --- /dev/null +++ b/rust/fate_utils/crates/ou/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "ou" +version = "0.1.0" +edition = "2021" + +[dependencies] +math = { path = "../math" } +serde = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +iai = { workspace = true} + +[[bench]] +name = "ou_bench" +harness = false \ No newline at end of file diff --git a/rust/fate_utils/crates/ou/benches/ou_bench.rs b/rust/fate_utils/crates/ou/benches/ou_bench.rs new file mode 100644 index 0000000000..fcccd733e4 --- /dev/null +++ b/rust/fate_utils/crates/ou/benches/ou_bench.rs @@ -0,0 +1,36 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use math::BInt; +use std::time::Duration; + +fn paillier_benchmark(c: &mut Criterion) { + let (sk, pk) = ou::keygen(1024); + let plaintext = ou::PT(BInt::from_str_radix("1234567890987654321", 10)); + let ciphertext = pk.encrypt(&plaintext, true); + let mut group = c.benchmark_group("paillier"); + + 
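// (annotation, not part of the patch) Despite the `paillier_benchmark` / "paillier" naming,
// this bench file exercises the OU scheme; note that both the "keygen-1024" and "keygen-2048"
// cases below call ou::keygen(1024) as written.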
group.bench_function("keygen-1024", |b| { + b.iter(|| ou::keygen(black_box(1024))) + }); + group.bench_function("keygen-2048", |b| { + b.iter(||ou::keygen(black_box(1024))) + }); + group.bench_function("encrypt", |b| { + b.iter(|| black_box(&pk).encrypt(black_box(&plaintext), true)) + }); + group.bench_function("decrypt", |b| { + b.iter(|| black_box(&sk).decrypt(black_box(&ciphertext))) + }); + group.bench_function("add ciphertext", |b| { + b.iter(|| black_box(&ciphertext).add_ct(black_box(&ciphertext), black_box(&pk))) + }); + group.bench_function("mul plaintext", |b| { + b.iter(|| black_box(&ciphertext).mul_pt(black_box(&plaintext), black_box(&pk))) + }); +} + +criterion_group! { + name = benches; + config = Criterion::default().measurement_time(Duration::from_secs(10)); + targets = paillier_benchmark +} +criterion_main!(benches); diff --git a/rust/fate_utils/crates/ou/src/lib.rs b/rust/fate_utils/crates/ou/src/lib.rs new file mode 100644 index 0000000000..4437910fbf --- /dev/null +++ b/rust/fate_utils/crates/ou/src/lib.rs @@ -0,0 +1,169 @@ +use math::{BInt, ONE}; +use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; +use std::ops::AddAssign; + +#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)] +pub struct CT(pub BInt); //ciphertext + +impl Display for CT { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CT") + } +} + +impl Default for CT { + fn default() -> Self { + todo!() + } +} + +impl<'b> AddAssign<&'b CT> for CT { + fn add_assign(&mut self, _rhs: &'b CT) { + todo!() + } +} + +impl CT { + pub fn zero() -> CT { + CT(BInt::from(ONE)) + } + pub fn add_ct(&self, ct: &CT, pk: &PK) -> CT { + CT(&self.0 * &ct.0 % &pk.n) + } + pub fn i_double(&mut self, pk: &PK) { + self.0.pow_mod_mut(&BInt::from(2), &pk.n); + } + pub fn mul_pt(&self, b: &PT, pk: &PK) -> CT { + CT(self.0.pow_mod_ref(&b.0, &pk.n)) + } +} + +#[derive(Default, Clone, Deserialize, Serialize, Debug, PartialEq)] +pub struct PT(pub BInt); // plaintest + +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct PK { + pub n: BInt, + pub g: BInt, + pub h: BInt, +} + +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct SK { + pub p: BInt, + pub q: BInt, + pub g: BInt, + // pub n: BInt, + // // n = p * q + p_minus_one: BInt, + // q_minus_one: BInt, + // ps: BInt, + // qs: BInt, + // p_invert: BInt, + // // p^{-1} mod q + // hp: BInt, + // hq: BInt, +} + +/// generate Okamoto–Uchiyama cryptosystem with providing bit length +pub fn keygen(bit_length: u32) -> (SK, PK) { + let prime_bit_size = bit_length / 3; + let (mut p, mut q, mut n, mut g): (BInt, BInt, BInt, BInt); + loop { + p = BInt::gen_prime(prime_bit_size); + q = BInt::gen_prime(bit_length - 2 * prime_bit_size); + n = &p * &p * &q; + if p != q && n.significant_bits() == bit_length { + break; + } + } + let p2 = &p * &p; + let p_minus_one = &p - 1; + let n_minus_one = &n - 1; + loop { + g = BInt::gen_positive_integer(&n_minus_one) + 1; + if g.pow_mod_ref(&p_minus_one, &p2).ne(&BInt::from(1u8)) { + break; + } + } + let h = g.pow_mod_ref(&n, &n); + (SK::new(p, p_minus_one, q, g.clone()), PK::new(n, g, h)) +} + +impl PK { + fn new(n: BInt, g: BInt, h: BInt) -> PK { + PK { n, g, h } + } + /// encrypt plaintext + /// + /// ```math + /// g^plaintext \cdot h^r \pmod{n} + /// ``` + pub fn encrypt(&self, plaintext: &PT, _obfuscate: bool) -> CT { + let r = BInt::gen_positive_integer(&self.n); + let c = self.g.pow_mod_ref(&plaintext.0, &self.n) * self.h.pow_mod_ref(&r, &self.n); + CT(c) 
+ } +} + +impl SK { + fn new(p: BInt, p_minus_one: BInt, q: BInt, g: BInt) -> SK { + assert!(p != q, "p == q"); + SK { + p, + q, + g, + p_minus_one, + } + } + /// decrypt ciphertext + /// + pub fn decrypt(&self, c: &CT) -> PT { + let ps = &self.p * &self.p; + let dp = SK::h_function(&c.0, &self.p, &self.p_minus_one, &ps); + let dq = SK::h_function(&self.g, &self.p, &self.p_minus_one, &ps); + let mut m = (dp * dq.invert(&self.p)) % &self.p; + // TODO: any better way to do this? + if m < BInt::from(0) { + m.0.add_assign(&self.p.0) + } + PT(m) + } + #[inline] + fn h_function(c: &BInt, p: &BInt, p_1: &BInt, ps: &BInt) -> BInt { + let x = c.pow_mod_ref(p_1, ps) - ONE; + (x / p) % p + } +} + +#[test] +fn keygen_even_size() { + keygen(1024); +} + +#[test] +#[should_panic] +fn keygen_odd_size() { + keygen(1023); +} + +#[test] +fn test_decrypt() { + let (private, public) = keygen(1024); + let plaintext = PT(BInt::from(25519u32)); + let ciphertext = public.encrypt(&plaintext, true); + let decrypted = private.decrypt(&ciphertext); + assert_eq!(plaintext, decrypted) +} +#[test] +fn test_add() { + let (private, public) = keygen(1024); + let plaintext1 = PT(BInt::from(25519u32)); + let plaintext2 = PT(BInt::from(12345u32)); + let ciphertext1 = public.encrypt(&plaintext1, true); + let ciphertext2 = public.encrypt(&plaintext2, true); + let ciphertext3 = ciphertext1.add_ct(&ciphertext2, &public); + let decrypted = private.decrypt(&ciphertext3); + assert_eq!(PT(BInt::from(25519u32 + 12345u32)), decrypted) +} diff --git a/rust/fate_utils/crates/paillier/Cargo.toml b/rust/fate_utils/crates/paillier/Cargo.toml new file mode 100644 index 0000000000..1183c9fea3 --- /dev/null +++ b/rust/fate_utils/crates/paillier/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "paillier" +version = "0.1.0" +edition = "2021" + +[dependencies] +math = { path = "../math" } +serde = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +iai = { workspace = true} + +[[bench]] +name = "iai_bench" +harness = false + +[[bench]] +name = "paillier_bench" +harness = false \ No newline at end of file diff --git a/rust/tensor/rust_paillier/benches/iai_bench.rs b/rust/fate_utils/crates/paillier/benches/iai_bench.rs similarity index 78% rename from rust/tensor/rust_paillier/benches/iai_bench.rs rename to rust/fate_utils/crates/paillier/benches/iai_bench.rs index f73c4168ba..b9aaaae47a 100644 --- a/rust/tensor/rust_paillier/benches/iai_bench.rs +++ b/rust/fate_utils/crates/paillier/benches/iai_bench.rs @@ -1,5 +1,4 @@ -use fate_tensor::math::BInt; -use fate_tensor::paillier; +use math::BInt; fn encrypt() { let (_sk, pk) = paillier::keygen(1024); diff --git a/rust/tensor/rust_paillier/benches/paillier_bench.rs b/rust/fate_utils/crates/paillier/benches/paillier_bench.rs similarity index 96% rename from rust/tensor/rust_paillier/benches/paillier_bench.rs rename to rust/fate_utils/crates/paillier/benches/paillier_bench.rs index dd0926a2a7..a7c862ed76 100644 --- a/rust/tensor/rust_paillier/benches/paillier_bench.rs +++ b/rust/fate_utils/crates/paillier/benches/paillier_bench.rs @@ -1,6 +1,5 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use fate_tensor::math::BInt; -use fate_tensor::paillier; +use math::BInt; use std::time::Duration; fn paillier_benchmark(c: &mut Criterion) { diff --git a/rust/tensor/rust_paillier/src/paillier/mod.rs b/rust/fate_utils/crates/paillier/src/lib.rs similarity index 78% rename from rust/tensor/rust_paillier/src/paillier/mod.rs rename to 
rust/fate_utils/crates/paillier/src/lib.rs index 2d448add88..98ba5b40fc 100644 --- a/rust/tensor/rust_paillier/src/paillier/mod.rs +++ b/rust/fate_utils/crates/paillier/src/lib.rs @@ -1,9 +1,29 @@ -use crate::math::{BInt, ONE}; +use math::{BInt, ONE}; use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; +use std::ops::AddAssign; #[derive(Clone, Deserialize, Serialize, Debug, PartialEq)] pub struct CT(pub BInt); //ciphertext +impl Display for CT { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CT") + } +} + +impl Default for CT { + fn default() -> Self { + todo!() + } +} + +impl<'b> AddAssign<&'b CT> for CT { + fn add_assign(&mut self, _rhs: &'b CT) { + todo!() + } +} + impl CT { pub fn zero() -> CT { CT(BInt::from(ONE)) @@ -15,34 +35,42 @@ impl CT { pub fn add_ct(&self, ct: &CT, pk: &PK) -> CT { CT(&self.0 * &ct.0 % &pk.ns) } + pub fn i_double(&mut self, pk: &PK) { + self.0.pow_mod_mut(&BInt::from(2), &pk.ns); + } pub fn mul_pt(&self, b: &PT, pk: &PK) -> CT { CT(self.0.pow_mod_ref(&b.0, &pk.ns)) } } -#[derive(Clone, Deserialize, Serialize, Debug, PartialEq)] + +#[derive(Default, Clone, Deserialize, Serialize, Debug, PartialEq)] pub struct PT(pub BInt); // plaintest -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct PK { pub n: BInt, pub ns: BInt, // n * n } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] + +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct SK { p: BInt, q: BInt, - pub n: BInt, // n = p * q + pub n: BInt, + // n = p * q p_minus_one: BInt, q_minus_one: BInt, ps: BInt, qs: BInt, - p_invert: BInt, // p^{-1} mod q + p_invert: BInt, + // p^{-1} mod q hp: BInt, hq: BInt, } + /// generate paillier keypairs with providing bit lenght pub fn keygen(bit_lenght: u32) -> (SK, PK) { - assert!(bit_lenght % 2 == 0); + assert_eq!(bit_lenght % 2, 0); let prime_bit_size = bit_lenght / 2; // generate prime p and q such that num_bit(p) * num_bit(q) == bit_length // and p != q @@ -57,13 +85,16 @@ pub fn keygen(bit_lenght: u32) -> (SK, PK) { } (SK::new(p, q, n.clone()), PK::new(n)) } + impl PK { fn new(n: BInt) -> PK { let ns = &n * &n; PK { n, ns } } fn random_rn(&self) -> BInt { - BInt::gen_positive_integer(&self.n).pow_mod(&self.n, &self.ns) + let mut r = BInt::gen_positive_integer(&self.n); + r.pow_mod_mut(&self.n, &self.ns); + r } /// encrypt plaintext /// @@ -132,7 +163,12 @@ impl SK { pub fn decrypt(&self, c: &CT) -> PT { let dp = SK::h_function(&c.0, &self.p, &self.p_minus_one, &self.ps, &self.hp); let dq = SK::h_function(&c.0, &self.q, &self.q_minus_one, &self.qs, &self.hq); - PT((((dq - &dp) * &self.p_invert) % &self.q) * &self.p + &dp) + let mut o = (((dq - &dp) * &self.p_invert) % &self.q) * &self.p + &dp; + // TODO: any better way to do this? 
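// (annotation, not part of the patch) The CRT recombination above uses rug's `%`, a
// truncated-division remainder that keeps the dividend's sign, so `o` can come out negative;
// adding n once maps it back into the canonical range [0, n).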
+ if o < BInt::from(0) { + o.0.add_assign(&self.n.0) + } + PT(o) } #[inline] fn h_function(c: &BInt, p: &BInt, p_1: &BInt, ps: &BInt, hp: &BInt) -> BInt { diff --git a/rust/fate_utils/crates/quantile/Cargo.toml b/rust/fate_utils/crates/quantile/Cargo.toml new file mode 100644 index 0000000000..d9f714fea7 --- /dev/null +++ b/rust/fate_utils/crates/quantile/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "quantile" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde = { workspace = true } + +[dev-dependencies.quickcheck] +version = "0.5" \ No newline at end of file diff --git a/rust/fate_utils/crates/quantile/src/greenwald_khanna.rs b/rust/fate_utils/crates/quantile/src/greenwald_khanna.rs new file mode 100644 index 0000000000..4d19977556 --- /dev/null +++ b/rust/fate_utils/crates/quantile/src/greenwald_khanna.rs @@ -0,0 +1,455 @@ +//! Greenwald Khanna calculates epsilon-approximate quantiles. +//! If the desired quantile is phi, the epsilon-approximate +//! quantile is any element in the range of elements that rank +//! between `lbound((phi-epsilon) x N)` and `lbound((phi+epsilon) x N)` +//! +//! terminology from the paper: +//! +//! * S: set of observations +//! * n: number of observations in S +//! * v[i]: observation i in S +//! * r: rank of observation in S from 1 to n. +//! * `r_min(v[i])`: lower bound on rank r of v[i] +//! * `r_max(v[i])`: upper bound on rank r of v[i] +//! * `g[i] = r_min(v[i]) - r_min(v[i - 1])` +//! * `delta[i] = r_max(v[i]) - r_min(v[i])` +//! * `t[i] = tuple(v[i], g[i], delta[i])` +//! * phi: quantile as a real number in the range [0,1] +//! * r: ubound(phi * n) +//! +//! identities: +//! +//! * `r_min(v[i]) = forall j<=i sum of g[j]` +//! * `r_max(v[i]) = ( forall j<=i sum of g[j] ) + delta[i]` +//! * g[i] + delta[i] - 1 is an upper bound on the total number of observations +//! * between v[i] and v[i-1] +//! * sum of g[i] = n +//! +//! results: +//! +//! * `max_i(g[i] + delta[i]) <= 2 * epsilon * n` +//! * a tuple is full if g[i] + delta[i] = floor(2 * epsilon * n) +//! +//! `@inproceedings{Greenwald:2001:SOC:375663.375670, +//! author = {Greenwald, Michael and Khanna, Sanjeev}, +//! title = {Space-efficient Online Computation of Quantile Summaries}, +//! booktitle = {Proceedings of the 2001 ACM SIGMOD International +//! Conference +//! on Management of Data}, +//! series = {SIGMOD '01}, +//! year = {2001}, +//! isbn = {1-58113-332-4}, +//! location = {Santa Barbara, California, USA}, +//! pages = {58--66}, +//! numpages = {9}, +//! url = {http://doi.acm.org/10.1145/375663.375670}, +//! doi = {10.1145/375663.375670}, +//! acmid = {375670}, +//! publisher = {ACM}, +//! address = {New York, NY, USA}, +//! }` +//! +//! # Examples +//! +//! ``` +//! use quantile::greenwald_khanna::*; +//! +//! let epsilon = 0.01; +//! +//! let mut stream = Stream::new(epsilon); +//! +//! let n = 1001; +//! for i in 1..n { +//! stream.insert(i); +//! } +//! let in_range = |phi: f64, value: u32| { +//! let lower = ((phi - epsilon) * (n as f64)) as u32; +//! let upper = ((phi + epsilon) * (n as f64)) as u32; +//! (epsilon > phi || lower <= value) && value <= upper +//! }; +//! assert!(in_range(0f64, *stream.quantile(0f64))); +//! assert!(in_range(0.1f64, *stream.quantile(0.1f64))); +//! assert!(in_range(0.2f64, *stream.quantile(0.2f64))); +//! assert!(in_range(0.3f64, *stream.quantile(0.3f64))); +//! assert!(in_range(0.4f64, *stream.quantile(0.4f64))); +//! assert!(in_range(1f64, *stream.quantile(1f64))); +//! 
``` + +use serde::{Deserialize, Serialize}; +use std::cmp; +/// Locates the proper position of v in a vector vs +/// such that when v is inserted at position i, +/// it is less then the element at i+1 if any, +/// and greater than or equal to the element at i-1 if any. +pub fn find_insert_pos<T>(vs: &[T], v: &T) -> usize +where + T: Ord, +{ + if vs.len() <= 10 { + return find_insert_pos_linear(vs, v); + } + + let middle = vs.len() / 2; + let pivot = &vs[middle]; + + if v < pivot { + find_insert_pos(&vs[0..middle], v) + } else { + middle + find_insert_pos(&vs[middle..], v) + } +} + +/// Locates the proper position of v in a vector vs +/// such that when v is inserted at position i, +/// it is less then the element at i+1 if any, +/// and greater than or equal to the element at i-1 if any. +/// Works by scanning the slice from start to end. +pub fn find_insert_pos_linear<T>(vs: &[T], v: &T) -> usize +where + T: Ord, +{ + for (i, vi) in vs.iter().enumerate() { + if v < vi { + return i; + } + } + + vs.len() +} + +/// 3-tuple of a value v[i], g[i] and delta[i]. +#[derive(Eq, Ord, Debug, Clone, Copy, Serialize, Deserialize)] +pub struct Tuple<T> +where + T: Ord, +{ + /// v[i], an observation in the set of observations + pub v: T, + + /// the difference between the rank lowerbounds of t[i] and t[i-1] + /// g = r_min(v[i]) - r_min(v[i - 1]) + pub g: usize, + + /// the difference betweeh the rank upper and lower bounds for this tuple + pub delta: usize, +} + +impl<T> Tuple<T> +where + T: Ord, +{ + /// Creates a new instance of a Tuple + pub fn new(v: T, g: usize, delta: usize) -> Tuple<T> { + Tuple { + v: v, + g: g, + delta: delta, + } + } +} + +impl<T> PartialEq for Tuple<T> +where + T: Ord, +{ + fn eq(&self, other: &Self) -> bool { + self.v == other.v + } +} + +impl<T> PartialOrd for Tuple<T> +where + T: Ord, +{ + fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> { + self.v.partial_cmp(&other.v) + } +} + +/// The summary S of the observations seen so far. 
+#[derive(Debug, Serialize, Deserialize)] +pub struct Stream<T> +where + T: Ord, +{ + /// An ordered sequence of the selected observations + summary: Vec<Tuple<T>>, + + /// The error factor + epsilon: f64, + + /// The number of observations + n: usize, +} + +impl<T: Ord + Copy> Stream<T> { + pub fn merge(&self, other: &Stream<T>) -> Stream<T> { + assert!(self.epsilon == other.epsilon); + let mut summary: Vec<Tuple<T>> = vec![]; + let epsilon = self.epsilon; + let n = self.n + other.n; + let additional_self_delta = (2f64 * self.epsilon * self.n as f64).floor() as usize; + let additional_other_delta = (2f64 * other.epsilon * other.n as f64).floor() as usize; + let mut self_idx = 0; + let mut other_idx = 0; + while self_idx < self.summary.len() && other_idx < other.summary.len() { + let self_summary = self.summary[self_idx]; + let other_summary = other.summary[other_idx]; + let (next_summary, additional_delta) = if self_summary.v < other_summary.v { + self_idx += 1; + ( + self_summary, + if other_idx > 0 { + additional_self_delta + } else { + 0 + }, + ) + } else { + other_idx += 1; + ( + other_summary, + if self_idx > 0 { + additional_other_delta + } else { + 0 + }, + ) + }; + summary.push(Tuple { + delta: next_summary.delta + additional_delta, + ..next_summary + }); + } + while self_idx < self.summary.len() { + summary.push(self.summary[self_idx]); + self_idx += 1; + } + while other_idx < other.summary.len() { + summary.push(other.summary[other_idx]); + other_idx += 1; + } + Stream { + epsilon, + n, + summary, + } + } +} + +impl<T> Stream<T> +where + T: Ord + Copy, +{ + /// Creates a new instance of a Stream + pub fn new(epsilon: f64) -> Stream<T> { + Stream { + summary: vec![], + epsilon: epsilon, + n: 0, + } + } + + /// Locates the correct position in the summary data set + /// for the observation v, and inserts a new tuple (v,1,floor(2en)) + /// If v is the new minimum or maximum, then instead insert + /// tuple (v,1,0). + pub fn insert(&mut self, v: T) { + let mut t = Tuple::new(v, 1, 0); + + let pos = find_insert_pos(&self.summary, &t); + + if pos != 0 && pos != self.summary.len() { + t.delta = (2f64 * self.epsilon * (self.n as f64)).floor() as usize; + } + + self.summary.insert(pos, t); + + self.n += 1; + + if self.should_compress() { + self.compress(); + } + } + + /// Compute the epsilon-approximate phi-quantile + /// from the summary data structure. 
+ pub fn quantile(&self, phi: f64) -> &T { + assert!(self.summary.len() >= 1); + assert!(phi >= 0f64 && phi <= 1f64); + + let r = (phi * self.n as f64).floor() as usize; + let en = (self.epsilon * self.n as f64) as usize; + + let first = &self.summary[0]; + + let mut prev = &first.v; + let mut prev_rmin = first.g; + + for t in self.summary.iter().skip(1) { + let rmax = prev_rmin + t.g + t.delta; + + if rmax > r + en { + return prev; + } + + prev_rmin += t.g; + prev = &t.v; + } + + prev + } + + fn should_compress(&self) -> bool { + let period = (1f64 / (2f64 * self.epsilon)).floor() as usize; + + self.n % period == 0 + } + + fn compress(&mut self) { + let s = self.s(); + for i in (1..(s - 1)).rev() { + if self.can_delete(i) { + self.delete(i); + } + } + } + + fn can_delete(&self, i: usize) -> bool { + assert!(self.summary.len() >= 2); + assert!(i < self.summary.len() - 1); + + let t = &self.summary[i]; + let tnext = &self.summary[i + 1]; + let p = self.p(); + + let safety_property = t.g + tnext.g + tnext.delta < p; + + let optimal = Self::band(t.delta, p) <= Self::band(tnext.delta, p); + + safety_property && optimal + } + + /// Remove the ith tuple from the summary. + /// Panics if i is not in the range [0,summary.len() - 1) + /// Only permitted if g[i] + g[i+1] + delta[i+1] < 2 * epsilon * n + fn delete(&mut self, i: usize) { + assert!(self.summary.len() >= 2); + assert!(i < self.summary.len() - 1); + + let t = self.summary.remove(i); + let tnext = &mut self.summary[i]; + + tnext.g += t.g; + } + + /// Compute which band a delta lies in. + fn band(delta: usize, p: usize) -> usize { + assert!(p >= delta); + + let diff = p - delta + 1; + + (diff as f64).log(2f64).floor() as usize + } + + /// Calculate p = 2epsilon * n + pub fn p(&self) -> usize { + (2f64 * self.epsilon * (self.n as f64)).floor() as usize + } + + /// The number of observations inserted into the stream. + pub fn n(&self) -> usize { + self.n + } + + /// Indication of the space usage of the summary data structure + /// Returns the number of tuples in the summary + /// data structure. + pub fn s(&self) -> usize { + self.summary.len() + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::ops::Range; + + #[test] + fn test_find_insert_pos() { + let mut vs = vec![]; + for v in 0..10 { + vs.push(v); + } + + for v in 0..10 { + assert_eq!(find_insert_pos_linear(&vs, &v), v + 1); + } + } + + fn get_quantile_for_range(r: &Range<u32>, phi: f64) -> u32 { + (phi * ((r.end - 1) - r.start) as f64).floor() as u32 + r.start + } + + fn get_quantile_bounds_for_range(r: Range<u32>, phi: f64, epsilon: f64) -> (u32, u32) { + let lower = get_quantile_for_range(&r, (phi - epsilon).max(0f64)); + let upper = get_quantile_for_range(&r, phi + epsilon); + + (lower, upper) + } + + fn quantile_in_bounds(r: Range<u32>, s: &Stream<u32>, phi: f64, epsilon: f64) -> bool { + let approx_quantile = *s.quantile(phi); + let (lower, upper) = get_quantile_bounds_for_range(r, phi, epsilon); + + // println!("approx_quantile={} lower={} upper={} phi={} epsilon={}", + // approx_quantile, lower, upper, phi, epsilon); + + approx_quantile >= lower && approx_quantile <= upper + } + + #[test] + fn test_basics() { + let epsilon = 0.01; + + let mut stream = Stream::new(epsilon); + + for i in 1..1001 { + stream.insert(i); + } + + for phi in 0..100 { + assert!(quantile_in_bounds( + 1..1001, + &stream, + (phi as f64) / 100f64, + epsilon + )); + } + } + + quickcheck! 
{ + fn find_insert_pos_log_equals_find_insert_pos_linear(vs: Vec<i32>) -> bool { + let mut vs = vs; + vs.sort(); + + for v in -100..100 { + if find_insert_pos(&vs, &v) != find_insert_pos_linear(&vs, &v) { + return false; + } + } + + true + } + + fn test_gk(vs: Vec<u32>) -> bool { + let mut s = Stream::new(0.25); + + for v in vs { + s.insert(v); + } + + true + } + } +} diff --git a/rust/fate_utils/crates/quantile/src/lib.rs b/rust/fate_utils/crates/quantile/src/lib.rs new file mode 100644 index 0000000000..9179c75da5 --- /dev/null +++ b/rust/fate_utils/crates/quantile/src/lib.rs @@ -0,0 +1,5 @@ +pub mod greenwald_khanna; + +#[cfg(test)] +#[macro_use] +extern crate quickcheck; diff --git a/rust/tensor/rust_paillier/pyproject.toml b/rust/fate_utils/pyproject.toml similarity index 65% rename from rust/tensor/rust_paillier/pyproject.toml rename to rust/fate_utils/pyproject.toml index fa7c44f1ae..99505a865c 100644 --- a/rust/tensor/rust_paillier/pyproject.toml +++ b/rust/fate_utils/pyproject.toml @@ -3,12 +3,17 @@ requires = ["maturin>=0.12,<0.13"] build-backend = "maturin" [project] -name = "rust_paillier" +name = "fate_utils" requires-python = ">=3.6" -description = "paillier tensor implemented using rust" -long_description_content_type = "text/markdown" +readme = "README.md" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] + +[tool.maturin] +bindings = "pyo3" +manifest-path = "crates/fate_utils/Cargo.toml" +python-source = "python" +strip = true diff --git a/rust/fate_utils/python/fate_utils/__init__.py b/rust/fate_utils/python/fate_utils/__init__.py new file mode 100644 index 0000000000..f3f446d79a --- /dev/null +++ b/rust/fate_utils/python/fate_utils/__init__.py @@ -0,0 +1 @@ +from .fate_utils import * diff --git a/rust/tensor/rust_paillier/rust_paillier/__init__.pyi b/rust/fate_utils/python/fate_utils/__init__.pyi similarity index 100% rename from rust/tensor/rust_paillier/rust_paillier/__init__.pyi rename to rust/fate_utils/python/fate_utils/__init__.pyi diff --git a/rust/fate_utils/python/fate_utils/histogram/__init__.py b/rust/fate_utils/python/fate_utils/histogram/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rust/fate_utils/python/fate_utils/histogram/__init__.pyi b/rust/fate_utils/python/fate_utils/histogram/__init__.pyi new file mode 100644 index 0000000000..de373962da --- /dev/null +++ b/rust/fate_utils/python/fate_utils/histogram/__init__.pyi @@ -0,0 +1,33 @@ +from typing import Tuple, List, Dict, Union + +class HistogramIndexer: + def __new__(node_size: int, feature_bin_sizes: List[int]) -> "HistogramIndexer": ... + def get_position(self, nid: int, fid: int, bid: int) -> int: ... + def get_positions(self, nids: List[int], bid_vec: List[List[int]]) -> List[List[int]]: ... + def get_reverse_position(self, position: int) -> Tuple[int, int, int]: ... + def get_bin_num(self, fid: int) -> int: ... + def get_bin_interval(self, nid: int, fid: int) -> Tuple[int, int]: ... + def get_node_intervals(self) -> List[Tuple[int, int]]: ... + def get_feature_position_ranges(self) -> List[Tuple[int, int]]: ... + def splits_into_k(self, k: int) -> List[Tuple[int, Tuple[int, int], List[Tuple[int, int]]]]: ... + def total_data_size(self) -> int: ... + def one_node_data_size(self) -> int: ... + def global_flatten_bin_sizes(self) -> List[int]: ... + def flatten_in_node(self) -> "HistogramIndexer": ... + def squeeze_bins(self) -> "HistogramIndexer": ... 
+ def reshape(self, feature_bin_sizes: List[int]) -> "HistogramIndexer": ... + def get_shuffler(self, seed: int) -> "Shuffler": ... + def get_node_size(self) -> int: ... + def get_node_axis_stride(self) -> int: ... + def get_feature_size(self) -> int: ... + def get_feature_axis_stride(self) -> List[int]: ... + def get_feature_bin_sizes(self) -> List[int]: ... + def get_num_nodes(self) -> int: ... + def unflatten_indexes(self) -> Dict[int, Dict[int, List[int]]]: ... + +class Shuffler: + def __new__(num_node: int, node_size: int, seed: int) -> "Shuffler": ... + def get_global_perm_index(self) -> List[int]: ... + def get_reverse_indexes(self, step: int, indexes: List[int]) -> List[int]: ... + def get_shuffle_index(self, step: int, reverse: bool) -> List[int]: ... + def reverse_index(self, index: int) -> Tuple[int, int]: ... diff --git a/rust/fate_utils/python/fate_utils/paillier/__init__.py b/rust/fate_utils/python/fate_utils/paillier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rust/fate_utils/python/fate_utils/paillier/__init__.pyi b/rust/fate_utils/python/fate_utils/paillier/__init__.pyi new file mode 100644 index 0000000000..08ad659fd1 --- /dev/null +++ b/rust/fate_utils/python/fate_utils/paillier/__init__.pyi @@ -0,0 +1,57 @@ +from typing import Any, List, Tuple, Union, Optional +import numpy as np +from numpy import ndarray as Array + +class PK: + def __init__(self) -> None: ... + def encrypt_encoded(self, fixedpoint: FixedpointVector, obfuscate: bool) -> FixedpointPaillierVector: ... + def encrypt_encoded_scalar(self, fixedpoint: FixedpointEncoded, obfuscate: bool) -> PyCT: ... + def __new__(self) -> "PK": ... + def __getstate__(self) -> List[bytes]: ... + def __setstate__(self, state: List[bytes]) -> None: ... + +class SK: + def __init__(self) -> None: ... + def decrypt_to_encoded(self, data: FixedpointPaillierVector) -> FixedpointVector: ... + def decrypt_to_encoded_scalar(self, data: PyCT) -> FixedpointEncoded: ... + def __new__(self) -> "SK": ... + def __getstate__(self) -> List[bytes]: ... + def __setstate__(self, state: List[bytes]) -> None: ... + +class Coders: + def __init__(self) -> None: ... + def encode_f64(self, data: float) -> FixedpointEncoded: ... + def encode_f64_vec(self, data: Array[float]) -> FixedpointVector: ... + def decode_f64(self, data: FixedpointEncoded) -> float: ... + def decode_f64_vec(self, data: FixedpointVector) -> Array[float]: ... + def encode_f32(self, data: float) -> FixedpointEncoded: ... + def encode_f32_vec(self, data: Array[float]) -> FixedpointVector: ... + def decode_f32(self, data: FixedpointEncoded) -> float: ... + def decode_f32_vec(self, data: FixedpointVector) -> Array[float]: ... + def encode_i64(self, data: int) -> FixedpointEncoded: ... + def encode_i64_vec(self, data: Array[int]) -> FixedpointVector: ... + def decode_i64(self, data: FixedpointEncoded) -> int: ... + def decode_i64_vec(self, data: FixedpointVector) -> List[int]: ... + def encode_i32(self, data: int) -> FixedpointEncoded: ... + def encode_i32_vec(self, data: Array[int]) -> FixedpointVector: ... + def decode_i32(self, data: FixedpointEncoded) -> int: ... + def decode_i32_vec(self, data: FixedpointVector) -> List[int]: ... + def __getstate__(self) -> List[bytes]: ... + def __setstate__(self, state: List[bytes]) -> None: ... + +class FixedpointPaillierVector: + def __init__(self) -> None: ... + def zeros(self, size: int) -> "FixedpointPaillierVector": ... + # Other methods... + +class FixedpointVector: + def __init__(self) -> None: ... 
+ # Other methods... + +class PyCT: + ct: Any + +class FixedpointEncoded: + data: Any + +def keygen(bit_length: int) -> Tuple[SK, PK, Coders]: ... diff --git a/rust/fate_utils/python/fate_utils/par/__init__.py b/rust/fate_utils/python/fate_utils/par/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi b/rust/fate_utils/python/fate_utils/par/__init__.pyi similarity index 100% rename from rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi rename to rust/fate_utils/python/fate_utils/par/__init__.pyi diff --git a/rust/fate_utils/python/fate_utils/quantile.pyi b/rust/fate_utils/python/fate_utils/quantile.pyi new file mode 100644 index 0000000000..7b18ce3116 --- /dev/null +++ b/rust/fate_utils/python/fate_utils/quantile.pyi @@ -0,0 +1,14 @@ +from typing import List, Optional +import numpy as np + +class QuantileSummaryStream: + def __init__(self, epsilon: Optional[float] = None): ... + def __getstate__(self) -> bytes: ... + def __setstate__(self, state: bytes) -> None: ... + def insert_array(self, data: np.ndarray) -> None: ... + def queries(self, phi: List[float]) -> List[float]: ... + def merge(self, other: "QuantileSummaryStream") -> "QuantileSummaryStream": ... + +def summary_f64_ix2(data: np.ndarray, epsilon: float) -> List[QuantileSummaryStream]: ... +def quantile_f64_ix1(data: np.ndarray, q: List[float], epsilon: float) -> List[float]: ... +def quantile_f64_ix2(data: np.ndarray, q: List[float], epsilon: float) -> np.ndarray: ... diff --git a/rust/fate_utils/python/fate_utils/secure_aggregation_helper/__init__.py b/rust/fate_utils/python/fate_utils/secure_aggregation_helper/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rust/fate_utils/python/fate_utils/secure_aggregation_helper/__init__.pyi b/rust/fate_utils/python/fate_utils/secure_aggregation_helper/__init__.pyi new file mode 100644 index 0000000000..1cea9ad48a --- /dev/null +++ b/rust/fate_utils/python/fate_utils/secure_aggregation_helper/__init__.pyi @@ -0,0 +1,18 @@ +from typing import List, Dict, Optional, Tuple +import numpy as np + +class DiffieHellman: + def __init__(self) -> None: ... + def get_public_key(self) -> bytes: ... + def diffie_hellman(self, other_public_key: bytes) -> bytes: ... + +class RandomMix: + def __init__(self, seeds: Dict[int, bytes], rank: int) -> None: ... + def mix_one(self, input: np.ndarray, weight: Optional[float] = None) -> Tuple[np.ndarray, np.ndarray]: ... + def mix(self, inputs: List[np.ndarray], weight: Optional[float] = None) -> List[Tuple[np.ndarray, np.ndarray]]: ... + def get_index(self, rank: int) -> int: ... + +class MixAggregate: + def __init__(self) -> None: ... + def aggregate(self, inputs: List[Tuple[np.ndarray, np.ndarray]]) -> None: ... + def finalize(self, weight: Optional[float] = None) -> List[np.ndarray]: ... 
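The rewritten `decrypt` in `rust/fate_utils/crates/paillier/src/lib.rs` above recombines the two CRT residues as `((dq - dp) * p_invert % q) * p + dp` and then patches up a possibly negative remainder (the `TODO` in that hunk). The sketch below is a minimal pure-Python analogue of that structure, using tiny primes purely for illustration; the helper names (`L`, `hp`, `hq`, `p_inv_mod_q`) are ours and not part of the fate_utils API, and nothing here is cryptographically secure.

```python
# Toy Paillier keypair -- structure mirrors the CRT-style decrypt() in the
# lib.rs hunk above; NOT secure, illustration only (needs Python >= 3.8).

def L(x, p):
    # Paillier's "L function": (x - 1) / p, exact integer division
    return (x - 1) // p

p, q = 293, 433          # p != q, both prime
n, ns = p * q, (p * q) ** 2
g = n + 1                # common generator choice

# precomputed constants playing the role of hp / hq / p_invert in the Rust SK
hp = pow(L(pow(g, p - 1, p * p), p), -1, p)
hq = pow(L(pow(g, q - 1, q * q), q), -1, q)
p_inv_mod_q = pow(p, -1, q)

def encrypt(m, r):
    # c = g^m * r^n mod n^2
    return (pow(g, m, ns) * pow(r, n, ns)) % ns

def decrypt(c):
    dp = (L(pow(c, p - 1, p * p), p) * hp) % p      # m mod p
    dq = (L(pow(c, q - 1, q * q), q) * hq) % q      # m mod q
    # CRT recombination, same shape as the Rust code:
    # ((dq - dp) * p^{-1} mod q) * p + dp
    m = ((dq - dp) * p_inv_mod_q % q) * p + dp
    # Python's % is always non-negative, so the "if o < 0 { o += n }" fix-up
    # from the Rust TODO is not needed here.
    return m % n

c1, c2 = encrypt(25519, 17), encrypt(12345, 23)
assert decrypt(c1) == 25519
assert decrypt((c1 * c2) % ns) == 25519 + 12345     # additive homomorphism
```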
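The `quantile.pyi` stub above declares the Python surface over the Greenwald-Khanna summary added in `crates/quantile`. Below is a hypothetical usage sketch, assuming the `fate_utils` extension has been built and installed (e.g. with maturin, per the pyproject.toml above) and that `insert_array` accepts a 1-D float64 array as the stub implies.

```python
# Hypothetical usage of the QuantileSummaryStream bindings declared in
# rust/fate_utils/python/fate_utils/quantile.pyi; requires the built extension.
import numpy as np
from fate_utils.quantile import QuantileSummaryStream

stream = QuantileSummaryStream(0.01)            # epsilon = 0.01
stream.insert_array(np.arange(1.0, 1001.0))     # observations 1..1000
print(stream.queries([0.1, 0.5, 0.9]))          # epsilon-approximate quantiles

# summaries built on disjoint shards can be merged, as in Stream::merge above
other = QuantileSummaryStream(0.01)
other.insert_array(np.arange(1001.0, 2001.0))
print(stream.merge(other).queries([0.5]))
```

As documented in greenwald_khanna.rs, the returned values are only guaranteed to lie within epsilon * n ranks of the exact quantiles.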
diff --git a/rust/tensor/rust_paillier/tests/test_base.py b/rust/fate_utils/tests/test_base.py similarity index 98% rename from rust/tensor/rust_paillier/tests/test_base.py rename to rust/fate_utils/tests/test_base.py index 533b7f42dd..8a1efab268 100644 --- a/rust/tensor/rust_paillier/tests/test_base.py +++ b/rust/fate_utils/tests/test_base.py @@ -9,7 +9,7 @@ def get_suites(): suites = [] - packages = ["rust_paillier", "rust_paillier.par"] + packages = ["fate_utils.tensor"] for package in packages: module = importlib.import_module(package) suites.append(Suite(module.keygen)) diff --git a/rust/fate_utils/tests/test_psi_bench.py b/rust/fate_utils/tests/test_psi_bench.py new file mode 100644 index 0000000000..77ca8a551d --- /dev/null +++ b/rust/fate_utils/tests/test_psi_bench.py @@ -0,0 +1,46 @@ +import random +import hashlib +from fate_utils.psi import Curve25519 + + +def ecdh(k, m): + return k.encrypt(m) + + +def dh(k, e): + return k.diffie_hellman(e) + + +def sha256(value): + return hashlib.sha256(bytes(value, encoding="utf-8")).digest() + + +def test_ecdh_encrypt_bench(benchmark): + k = Curve25519() + m = random.SystemRandom().getrandbits(256).to_bytes(32, "little") + result = benchmark(ecdh, k, m) + + +def test_ecdh_dh_bench(benchmark): + k = Curve25519() + m = random.SystemRandom().getrandbits(256).to_bytes(32, "little") + e = k.encrypt(m) + result = benchmark(dh, k, e) + + +def test_sha256_bench(benchmark): + m = "1000000000" + result = benchmark(sha256, m) + + +def test_ecdh_encrypt_vec_bench(benchmark): + k = Curve25519() + m = [random.SystemRandom().getrandbits(256).to_bytes(32, "little") for _ in range(10000)] + result = benchmark(k.encrypt_vec, m) + + +def test_ecdh_dh_vec_bench(benchmark): + k = Curve25519() + m = [random.SystemRandom().getrandbits(256).to_bytes(32, "little") for _ in range(10000)] + e = k.encrypt_vec(m) + result = benchmark(k.diffie_hellman_vec, e) diff --git a/rust/fate_utils/tests/test_psi_ecdh.py b/rust/fate_utils/tests/test_psi_ecdh.py new file mode 100644 index 0000000000..b4904a1e7a --- /dev/null +++ b/rust/fate_utils/tests/test_psi_ecdh.py @@ -0,0 +1,33 @@ +from fate_utils.psi import Curve25519 +import pickle +import unittest +import random + + +class TestStringMethods(unittest.TestCase): + def test_ecdh(self): + k1 = Curve25519() + k2 = Curve25519() + m = random.SystemRandom().getrandbits(33 * 8).to_bytes(33, "little") + self.assertEqual(k2.diffie_hellman(k1.encrypt(m)), k1.diffie_hellman(k2.encrypt(m))) + + def test_ecdh_vec(self): + k1 = Curve25519() + k2 = Curve25519() + m = [random.SystemRandom().getrandbits(33 * 8).to_bytes(33, "little") for _ in range(100)] + s1 = k1.encrypt_vec(m) + s12 = k2.diffie_hellman_vec(s1) + s2 = k2.encrypt_vec(m) + s21 = k1.diffie_hellman_vec(s2) + self.assertEqual(s12, s21) + + def test_pickle(self): + k1 = Curve25519() + m = random.SystemRandom().getrandbits(33 * 8).to_bytes(33, "little") + pickled = pickle.dumps(k1) + k2 = pickle.loads(pickled) + self.assertEqual(k1.encrypt(m), k2.encrypt(m)) + + +if __name__ == "__main__": + unittest.main() diff --git a/rust/fate_utils/tests/test_sm3_hash.py b/rust/fate_utils/tests/test_sm3_hash.py new file mode 100644 index 0000000000..70de2afe0a --- /dev/null +++ b/rust/fate_utils/tests/test_sm3_hash.py @@ -0,0 +1,18 @@ +import unittest +from fate_utils.hash import sm3_hash + + +class TestCorrect(unittest.TestCase): + def test_hash_1(self): + data = b"abc" + expected = "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0" + self.assertEqual(sm3_hash(data).hex(), 
expected) + + def test_hash_2(self): + data = b"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd" + expected = "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732" + self.assertEqual(sm3_hash(data).hex(), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/rust/tensor/rust_paillier/Cargo.lock b/rust/tensor/rust_paillier/Cargo.lock deleted file mode 100644 index e9be0c7c77..0000000000 --- a/rust/tensor/rust_paillier/Cargo.lock +++ /dev/null @@ -1,948 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi", - "libc", - "winapi", -] - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "az" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f771a5d1f5503f7f4279a30f3643d3421ba149848b89ecaaec0ea2acf04a5ac4" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - -[[package]] -name = "bumpalo" -version = "3.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" - -[[package]] -name = "cast" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a" -dependencies = [ - "rustc_version", -] - -[[package]] -name = "cfg-if" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "clap" -version = "2.34.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" -dependencies = [ - "bitflags", - "textwrap", - "unicode-width", -] - -[[package]] -name = "criterion" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1604dafd25fba2fe2d5895a9da139f8dc9b319a5fe5354ca137cbbce4e178d10" -dependencies = [ - "atty", - "cast", - "clap", - "criterion-plot", - "csv", - "itertools", - "lazy_static", - "num-traits", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_cbor", - "serde_derive", - "serde_json", - "tinytemplate", - "walkdir", -] - -[[package]] -name = 
"criterion-plot" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00996de9f2f7559f7f4dc286073197f83e92256a59ed395f9aac01fe717da57" -dependencies = [ - "cast", - "itertools", -] - -[[package]] -name = "crossbeam-channel" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c02a4d71819009c192cf4872265391563fd6a84c81ff2c0f2a7026ca4c1d85c" -dependencies = [ - "cfg-if 1.0.0", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" -dependencies = [ - "cfg-if 1.0.0", - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07db9d94cbd326813772c968ccd25999e5f8ae22f4f8d1b11effa37ef6ce281d" -dependencies = [ - "autocfg", - "cfg-if 1.0.0", - "crossbeam-utils", - "memoffset", - "once_cell", - "scopeguard", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d82ee10ce34d7bc12c2122495e7593a9c41347ecdd64185af4ecf72cb1a7f83" -dependencies = [ - "cfg-if 1.0.0", - "once_cell", -] - -[[package]] -name = "csv" -version = "1.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" -dependencies = [ - "bstr", - "csv-core", - "itoa 0.4.8", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" -dependencies = [ - "memchr", -] - -[[package]] -name = "either" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" - -[[package]] -name = "getrandom" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" -dependencies = [ - "cfg-if 1.0.0", - "libc", - "wasi", -] - -[[package]] -name = "gmp-mpfr-sys" -version = "1.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d00b0ef965511028498a1668c4a6ef9b0b2501a4a5ab26fb8156408869306e" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - -[[package]] -name = "iai" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71a816c97c42258aa5834d07590b718b4c9a598944cd39a52dc25b351185d678" - -[[package]] -name = "indoc" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" -dependencies = [ - "indoc-impl", - "proc-macro-hack", -] - -[[package]] -name = "indoc-impl" -version = "0.3.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" -dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", - "unindent", -] - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if 1.0.0", -] - -[[package]] -name = "itertools" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "itoa" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" - -[[package]] -name = "js-sys" -version = "0.3.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fac17f7123a73ca62df411b1bf727ccc805daa070338fda671c86dac1bdc27" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.126" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" - -[[package]] -name = "lock_api" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if 1.0.0", -] - -[[package]] -name = "matrixmultiply" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" -dependencies = [ - "rawpointer", -] - -[[package]] -name = "memchr" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" - -[[package]] -name = "memoffset" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" -dependencies = [ - "autocfg", -] - -[[package]] -name = "ndarray" -version = "0.15.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec23e6762830658d2b3d385a75aa212af2f67a4586d4442907144f3bb6a1ca8" -dependencies = [ - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "rawpointer", - "rayon", -] - -[[package]] -name = "num-complex" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "numpy" -version = "0.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f3a190dd1aa88ee0de91e59e970d5b85cfa079a9ff6531b69f811ccd0c2a6e1" -dependencies = [ - "cfg-if 0.1.10", - "libc", - "ndarray", - "num-complex", - "num-traits", - "pyo3", -] - -[[package]] -name = "once_cell" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7709cef83f0c1f58f666e746a08b21e0085f7440fa6a29cc194d68aac97a4225" - -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" -dependencies = [ - "cfg-if 1.0.0", - "instant", - "libc", - "redox_syscall", - "smallvec", - "winapi", -] - -[[package]] -name = "paste" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" -dependencies = [ - "paste-impl", - "proc-macro-hack", -] - -[[package]] -name = "paste-impl" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" -dependencies = [ - "proc-macro-hack", -] - -[[package]] -name = "plotters" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a3fd9ec30b9749ce28cd91f255d569591cdf937fe280c312143e3c4bad6f2a" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d88417318da0eaf0fdcdb51a0ee6c3bed624333bff8f946733049380be67ac1c" - -[[package]] -name = "plotters-svg" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521fa9638fa597e1dc53e9412a4f9cefb01187ee1f7413076f9e6749e2885ba9" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" - -[[package]] -name = "proc-macro-hack" -version = "0.5.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" - -[[package]] -name = "proc-macro2" -version = "1.0.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd96a1e8ed2596c337f8eae5f24924ec83f5ad5ab21ea8e455d3566c69fbcaf7" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d41d50a7271e08c7c8a54cd24af5d62f73ee3a6f6a314215281ebdec421d5752" -dependencies = [ - "cfg-if 1.0.0", - "indoc", - "libc", - "parking_lot", - "paste", - "pyo3-build-config", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "779239fc40b8e18bc8416d3a37d280ca9b9fb04bda54b98037bb6748595c2410" -dependencies = [ - "once_cell", -] - -[[package]] -name = "pyo3-macros" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b247e8c664be87998d8628e86f282c25066165f1f8dda66100c48202fdb93a" -dependencies = [ - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a8c2812c412e00e641d99eeb79dd478317d981d938aa60325dfa7157b607095" -dependencies = [ - "proc-macro2", - "pyo3-build-config", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcdf212e9776fbcb2d23ab029360416bb1706b1aea2d1a5ba002727cbcab804" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - -[[package]] -name = "rayon" -version = "1.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd99e5772ead8baa5215278c9b15bf92087709e9c1b2d1f97cdb5a183c933a7d" -dependencies = [ - "autocfg", - "crossbeam-deque", - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258bcdb5ac6dad48491bb2992db6b7cf74878b0384908af124823d118c99683f" -dependencies = [ - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-utils", - "num_cpus", -] - -[[package]] -name = "redox_syscall" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" -dependencies = [ - "bitflags", -] - -[[package]] -name = "regex" -version = "1.5.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1" -dependencies = [ - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" - -[[package]] -name = "regex-syntax" -version = "0.6.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b3de9ec5dc0a3417da371aab17d729997c15010e7fd24ff707773a33bddb64" - -[[package]] -name = "rug" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f829d980ca193fa33fdd1decaebe72ec07cf2d8afdd0be60b3e5391f18014c91" -dependencies = [ - "az", - "gmp-mpfr-sys", - "libc", -] - -[[package]] -name = "rust_paillier" -version = "0.1.0" -dependencies = [ - "bincode", - "criterion", - "iai", - "ndarray", - "numpy", - "pyo3", - "rand", - "rand_core", - "rayon", - "rug", - "serde", -] - -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - -[[package]] -name = "ryu" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" - -[[package]] -name = "semver" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a41d061efea015927ac527063765e73601444cdc344ba855bc7bd44578b25e1c" - -[[package]] -name = "serde" -version = "1.0.137" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_cbor" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" -dependencies = [ - "half", - "serde", -] - -[[package]] -name = "serde_derive" -version = "1.0.137" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c" -dependencies = [ - "itoa 1.0.2", - "ryu", - "serde", -] - -[[package]] -name = "smallvec" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc88c725d61fc6c3132893370cac4a0200e3fedf5da8331c570664b1987f5ca2" - -[[package]] -name = "syn" -version = "1.0.98" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c50aef8a904de4c23c788f104b7dddc7d6f79c647c7c8ce4cc8f73eb0ca773dd" -dependencies = [ - "proc-macro2", - "quote", - 
"unicode-ident", -] - -[[package]] -name = "textwrap" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width", -] - -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "unicode-ident" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bd2fe26506023ed7b5e1e315add59d6f584c621d037f9368fea9cfb988f368c" - -[[package]] -name = "unicode-width" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" - -[[package]] -name = "unindent" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52fee519a3e570f7df377a06a1a7775cdbfb7aa460be7e08de2b1f0e69973a44" - -[[package]] -name = "walkdir" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" -dependencies = [ - "same-file", - "winapi", - "winapi-util", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c53b543413a17a202f4be280a7e5c62a1c69345f5de525ee64f8cfdbc954994" -dependencies = [ - "cfg-if 1.0.0", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5491a68ab4500fa6b4d726bd67408630c3dbe9c4fe7bda16d5c82a1fd8c7340a" -dependencies = [ - "bumpalo", - "lazy_static", - "log", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a89911bd99e5f3659ec4acf9c4d93b0a90fe4a2a11f15328472058edc5261be" - -[[package]] -name = "web-sys" -version = "0.3.58" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fed94beee57daf8dd7d51f2b15dc2bcde92d7a72304cdf662a4371008b71b90" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name 
= "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/rust/tensor/rust_paillier/Cargo.toml b/rust/tensor/rust_paillier/Cargo.toml deleted file mode 100644 index f16f4cefaf..0000000000 --- a/rust/tensor/rust_paillier/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -[package] -name = "rust_paillier" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "rust_paillier" -crate-type = ["cdylib", "staticlib", "rlib"] -bench = false - -[dependencies] -numpy = "0.15.1" -pyo3 = { version = "0.15.2", features = ["extension-module"] } -rug = "1.16.0" -rand = { version = "0.8.3", features = ["getrandom"] } -rand_core = "0.6.3" -ndarray = { version = "0.15.4" } -serde = { version = "1.0.137", features = ["derive", "rc"] } -bincode = "1.3.3" -rayon = { version = "1.5.3", optional = true} - -[features] -default = ["rug", "rayon"] -rug = [] -rayon = ["dep:rayon", "ndarray/rayon"] - -[dev-dependencies] -criterion = { version = "0.3", features = ["html_reports"] } -iai = "0.1.1" - -[[bench]] -name = "paillier_bench" -harness = false - -[[bench]] -name = "iai_bench" -harness = false - -[package.metadata.docs.rs] -# To build locally use -# RUSTDOCFLAGS="--html-in-header katex-header.html" cargo doc --no-deps --document-private-items --open -rustdoc-args = ["--html-in-header", "docs/katex-header.html"] diff --git a/rust/tensor/rust_paillier/README.md b/rust/tensor/rust_paillier/README.md deleted file mode 100644 index 5b02790feb..0000000000 --- a/rust/tensor/rust_paillier/README.md +++ /dev/null @@ -1,3 +0,0 @@ -## rust_paillier - -paillier tensor implemented using rust. diff --git a/rust/tensor/rust_paillier/rust_paillier/__init__.py b/rust/tensor/rust_paillier/rust_paillier/__init__.py deleted file mode 100644 index 86fa586d02..0000000000 --- a/rust/tensor/rust_paillier/rust_paillier/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .rust_paillier import * diff --git a/rust/tensor/rust_paillier/src/block/matmul.rs b/rust/tensor/rust_paillier/src/block/matmul.rs deleted file mode 100644 index 2a269e943d..0000000000 --- a/rust/tensor/rust_paillier/src/block/matmul.rs +++ /dev/null @@ -1,298 +0,0 @@ -use super::{fixedpoint, Cipherblock, CouldCode}; -use ndarray::{ArrayView1, ArrayView2}; -use rayon::prelude::*; - -/// help function for generic matrix multiply -/// -/// assuming we want to multiply matrix A(with shape m x s) with B(with shape s, n): -/// 1. the output matrix C has shape (m, n) -/// 2. C(i,j) = \sum_k A(i,k)B(k,j) -/// 3. 
the function F(i, k, j, v): v += A(i,k)B(k,j) -fn matmul_apply<F>(m: usize, s: usize, n: usize, func: F) -> Vec<fixedpoint::CT> -where - F: Fn(usize, usize, usize, &mut fixedpoint::CT) -> (), // (i, k, j, v) -> () -{ - let mut data: Vec<fixedpoint::CT> = vec![fixedpoint::CT::zero(); m * n]; - (0..m) - .flat_map(|i| (0..n).map(move |j| (i, j))) - .into_iter() - .zip(data.iter_mut()) - .for_each(|((i, j), v)| (0..s).for_each(|k| func(i, k, j, v))); - data -} - -/// parallel version of help function for generic matrix multiply -/// -/// assuming we want to multiply matrix A(with shape m x s) with B(with shape s, n): -/// 1. the output matrix C has shape (m, n) -/// 2. C(i,j) = \sum_k A(i,k)B(k,j) -/// 3. the function F(i, k, j, v): v += A(i,k)B(k,j) -fn matmul_apply_par<F>(m: usize, s: usize, n: usize, func: F) -> Vec<fixedpoint::CT> -where - F: Fn(usize, usize, usize, &mut fixedpoint::CT) -> () + Sync, // (i, k, j, v) -> () -{ - let mut data: Vec<fixedpoint::CT> = vec![fixedpoint::CT::zero(); m * n]; - let output_indexes = (0..m) - .flat_map(|i| (0..n).map(move |j| (i, j))) - .collect::<Vec<(usize, usize)>>(); - output_indexes - .into_par_iter() - .zip(data.par_iter_mut()) - .for_each(|((i, j), v)| (0..s).for_each(|k| func(i, k, j, v))); - data -} - -/// v += lhs[i, k]rhs[k, j] -#[inline] -fn matmul_ops_cipherblock_plaintext_ix2<T: CouldCode>( - v: &mut fixedpoint::CT, - i: usize, - k: usize, - j: usize, - lhs: &Cipherblock, - rhs: ArrayView2<T>, -) { - let l = &lhs[(i, k)]; // lhs[i, k] - let r = rhs[(k, j)].encode(&lhs.pk.coder); // rhs[k, j] - v.add_assign(&l.mul(&r, &lhs.pk), &lhs.pk); // v += Self[i, k]Other[k, j] -} - -/// v += lhs[i, k]rhs[k] -#[inline] -fn matmul_ops_cipherblock_plaintext_ix1<T: CouldCode>( - v: &mut fixedpoint::CT, - i: usize, - k: usize, - lhs: &Cipherblock, - rhs: ArrayView1<T>, -) { - let l = &lhs[(i, k)]; // lhs[i, k] - let r = rhs[k].encode(&lhs.pk.coder); // rhs[k] - v.add_assign(&l.mul(&r, &lhs.pk), &lhs.pk); // v += Self[i, k]Other[k] -} - -/// v += lhs[i, k]rhs[k, j] -#[inline] -fn rmatmul_ops_cipherblock_plaintext_ix2<T: CouldCode>( - v: &mut fixedpoint::CT, - i: usize, - k: usize, - j: usize, - lhs: ArrayView2<T>, - rhs: &Cipherblock, -) { - let l = lhs[(i, k)].encode(&rhs.pk.coder); // lhs[i, k] - let r = &rhs[(k, j)]; // rhs[k, j] - v.add_assign(&r.mul(&l, &rhs.pk), &rhs.pk); // v += Self[i, k]Other[k, j] -} - -/// v += lhs[k]rhs[k, j] -#[inline] -fn rmatmul_ops_cipherblock_plaintext_ix1<T: CouldCode>( - v: &mut fixedpoint::CT, - k: usize, - j: usize, - lhs: ArrayView1<T>, - rhs: &Cipherblock, -) { - let l = lhs[k].encode(&rhs.pk.coder); // lhs[k] - let r = &rhs[(k, j)]; // rhs[k, j] - v.add_assign(&r.mul(&l, &rhs.pk), &rhs.pk); // v += Self[k]Other[k, j] -} - -fn checked_shape_cipherblock_matmul_plaintext_ix2<T>( - lhs: &Cipherblock, - rhs: ArrayView2<T>, -) -> (usize, usize, usize) { - if lhs.shape.len() != 2 || lhs.shape[1] != rhs.dim().0 { - panic!("dot shape error: ({:?}) x ({:?})", lhs.shape, rhs.dim()); - } - (lhs.shape[0], lhs.shape[1], rhs.dim().1) -} -fn checked_shape_cipherblock_matmul_plaintext_ix1<T>( - lhs: &Cipherblock, - rhs: ArrayView1<T>, -) -> (usize, usize, usize) { - if lhs.shape.len() != 2 || lhs.shape[1] != rhs.dim() { - panic!("dot shape error: ({:?}) x ({:?})", lhs.shape, rhs.dim()); - } - (lhs.shape[0], lhs.shape[1], 1) -} -fn checked_shape_cipherblock_rmatmul_plaintext_ix1<T>( - lhs: ArrayView1<T>, - rhs: &Cipherblock, -) -> (usize, usize, usize) { - if rhs.shape.len() != 2 || rhs.shape[0] != lhs.dim() { - panic!("dot shape 
error: ({:?}) x ({:?})", lhs.dim(), rhs.shape); - } - (1, rhs.shape[0], rhs.shape[1]) -} -fn checked_shape_cipherblock_rmatmul_plaintext_ix2<T>( - lhs: ArrayView2<T>, - rhs: &Cipherblock, -) -> (usize, usize, usize) { - if rhs.shape.len() != 2 || rhs.shape[0] != lhs.dim().1 { - panic!("dot shape error: ({:?}) x ({:?})", lhs.dim(), rhs.shape); - } - (lhs.dim().0, rhs.shape[0], rhs.shape[1]) -} - -pub fn cipherblock_matmul_plaintext_ix1<T: CouldCode>( - lhs: &Cipherblock, - rhs: ArrayView1<T>, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_matmul_plaintext_ix1(lhs, rhs); - let data = matmul_apply(m, s, n, |i, k, _, v| { - matmul_ops_cipherblock_plaintext_ix1(v, i, k, lhs, rhs); - }); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: vec![m, n], - } -} - -pub fn cipherblock_matmul_plaintext_ix2<T: CouldCode>( - lhs: &Cipherblock, - rhs: ArrayView2<T>, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_matmul_plaintext_ix2(lhs, rhs); - let data = matmul_apply(m, s, n, |i, k, j, v| { - matmul_ops_cipherblock_plaintext_ix2(v, i, k, j, lhs, rhs); - }); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: vec![m, n], - } -} -pub fn cipherblock_rmatmul_plaintext_ix1<T: CouldCode>( - lhs: ArrayView1<T>, - rhs: &Cipherblock, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_rmatmul_plaintext_ix1(lhs, rhs); - let data = matmul_apply(m, s, n, |_, k, j, v| { - rmatmul_ops_cipherblock_plaintext_ix1(v, k, j, lhs, rhs); - }); - Cipherblock { - pk: rhs.pk.clone(), - data, - shape: vec![m, n], - } -} - -pub fn cipherblock_rmatmul_plaintext_ix2<T: CouldCode>( - lhs: ArrayView2<T>, - rhs: &Cipherblock, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_rmatmul_plaintext_ix2(lhs, rhs); - let data = matmul_apply(m, s, n, |i, k, j, v| { - rmatmul_ops_cipherblock_plaintext_ix2(v, i, k, j, lhs, rhs); - }); - Cipherblock { - pk: rhs.pk.clone(), - data, - shape: vec![m, n], - } -} - -pub fn cipherblock_matmul_plaintext_ix1_par<T: CouldCode + Sync>( - lhs: &Cipherblock, - rhs: ArrayView1<T>, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_matmul_plaintext_ix1(lhs, rhs); - let data = matmul_apply_par(m, s, n, |i, k, _, v| { - matmul_ops_cipherblock_plaintext_ix1(v, i, k, lhs, rhs); - }); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: vec![m, n], - } -} -pub fn cipherblock_matmul_plaintext_ix2_par<T: CouldCode + Sync>( - lhs: &Cipherblock, - rhs: ArrayView2<T>, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_matmul_plaintext_ix2(lhs, rhs); - let data = matmul_apply_par(m, s, n, |i, k, j, v| { - matmul_ops_cipherblock_plaintext_ix2(v, i, k, j, lhs, rhs); - }); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: vec![m, n], - } -} -pub fn cipherblock_rmatmul_plaintext_ix1_par<T: CouldCode + Sync>( - lhs: ArrayView1<T>, - rhs: &Cipherblock, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_rmatmul_plaintext_ix1(lhs, rhs); - let data = matmul_apply_par(m, s, n, |_, k, j, v| { - rmatmul_ops_cipherblock_plaintext_ix1(v, k, j, lhs, rhs); - }); - Cipherblock { - pk: rhs.pk.clone(), - data, - shape: vec![m, n], - } -} - -pub fn cipherblock_rmatmul_plaintext_ix2_par<T: CouldCode + Sync>( - lhs: ArrayView2<T>, - rhs: &Cipherblock, -) -> Cipherblock { - // (m x s) x (s x n) - let (m, s, n) = checked_shape_cipherblock_rmatmul_plaintext_ix2(lhs, rhs); - let 
data = matmul_apply_par(m, s, n, |i, k, j, v| { - rmatmul_ops_cipherblock_plaintext_ix2(v, i, k, j, lhs, rhs); - }); - Cipherblock { - pk: rhs.pk.clone(), - data, - shape: vec![m, n], - } -} - -impl Cipherblock { - pub fn matmul_plaintext_ix1<T: CouldCode>(&self, rhs: ArrayView1<T>) -> Cipherblock { - cipherblock_matmul_plaintext_ix1(self, rhs) - } - pub fn rmatmul_plaintext_ix1<T: CouldCode>(&self, lhs: ArrayView1<T>) -> Cipherblock { - cipherblock_rmatmul_plaintext_ix1(lhs, self) - } - pub fn matmul_plaintext_ix2<T: CouldCode>(&self, rhs: ArrayView2<T>) -> Cipherblock { - cipherblock_matmul_plaintext_ix2(self, rhs) - } - pub fn rmatmul_plaintext_ix2<T: CouldCode>(&self, lhs: ArrayView2<T>) -> Cipherblock { - cipherblock_rmatmul_plaintext_ix2(lhs, self) - } - - // par - pub fn matmul_plaintext_ix1_par<T: CouldCode + Sync>(&self, rhs: ArrayView1<T>) -> Cipherblock { - cipherblock_matmul_plaintext_ix1_par(self, rhs) - } - pub fn rmatmul_plaintext_ix1_par<T: CouldCode + Sync>( - &self, - lhs: ArrayView1<T>, - ) -> Cipherblock { - cipherblock_rmatmul_plaintext_ix1_par(lhs, self) - } - pub fn matmul_plaintext_ix2_par<T: CouldCode + Sync>(&self, rhs: ArrayView2<T>) -> Cipherblock { - cipherblock_matmul_plaintext_ix2_par(self, rhs) - } - pub fn rmatmul_plaintext_ix2_par<T: CouldCode + Sync>( - &self, - lhs: ArrayView2<T>, - ) -> Cipherblock { - cipherblock_rmatmul_plaintext_ix2_par(lhs, self) - } -} diff --git a/rust/tensor/rust_paillier/src/block/mod.rs b/rust/tensor/rust_paillier/src/block/mod.rs deleted file mode 100644 index 303e995979..0000000000 --- a/rust/tensor/rust_paillier/src/block/mod.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::ops::Index; - -use super::fixedpoint; -use super::fixedpoint::CouldCode; -use ndarray::{ArrayD, ArrayViewD}; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -mod matmul; - -#[derive(Clone, Serialize, Deserialize)] -pub struct Cipherblock { - pub pk: fixedpoint::PK, - pub data: Vec<fixedpoint::CT>, - pub shape: Vec<usize>, -} - -impl Index<(usize, usize)> for Cipherblock { - type Output = fixedpoint::CT; - - #[inline] - fn index(&self, index: (usize, usize)) -> &Self::Output { - &self.data[index.0 * self.shape[1] + index.1] - } -} -impl Cipherblock { - pub fn map<F>(&self, func: F) -> Cipherblock - where - F: Fn(&fixedpoint::CT) -> fixedpoint::CT, - { - Cipherblock { - pk: self.pk.clone(), - data: self.data.iter().map(func).collect(), - shape: self.shape.clone(), - } - } - pub fn agg<F, T>(&self, init: T, f: F) -> T - where - F: Fn(T, &fixedpoint::CT) -> T, - { - self.data.iter().fold(init, f) - } - pub fn binary_cipherblock_cipherblock<F>( - lhs: &Cipherblock, - rhs: &Cipherblock, - func: F, - ) -> Cipherblock - where - F: Fn(&fixedpoint::CT, &fixedpoint::CT, &fixedpoint::PK) -> fixedpoint::CT, - { - assert_eq!(lhs.shape, rhs.shape); - assert_eq!(lhs.pk, rhs.pk); - let lhs_iter = lhs.data.iter(); - let rhs_iter = rhs.data.iter(); - let data: Vec<fixedpoint::CT> = lhs_iter - .zip(rhs_iter) - .map(|(l, r)| func(l, r, &lhs.pk)) - .collect(); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: lhs.shape.clone(), - } - } - pub fn binary_cipherblock_plaintext<F, T>( - lhs: &Cipherblock, - rhs: ArrayViewD<T>, - func: F, - ) -> Cipherblock - where - F: Fn(&fixedpoint::CT, &fixedpoint::PT, &fixedpoint::PK) -> fixedpoint::CT, - T: CouldCode, - { - assert_eq!(lhs.shape, rhs.shape().to_vec()); - let lhs_iter = lhs.data.iter(); - let rhs_iter = rhs.iter(); - let data: Vec<fixedpoint::CT> = lhs_iter - .zip(rhs_iter) - .map(|(l, r)| func(l, 
&r.encode(&lhs.pk.coder), &lhs.pk)) - .collect(); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: lhs.shape.clone(), - } - } -} -impl fixedpoint::PK { - pub fn encrypt_array<T>(&self, array: ArrayViewD<T>) -> Cipherblock - where - T: CouldCode, - { - let shape = array.shape().to_vec(); - let data: Vec<fixedpoint::CT> = array - .iter() - .map(|e| self.encrypt(&e.encode(&self.coder), true)) - .collect(); - Cipherblock { - pk: self.clone(), - data, - shape, - } - } -} - -impl fixedpoint::SK { - pub fn decrypt_array<T>(&self, array: &Cipherblock) -> ArrayD<T> - where - T: CouldCode, - { - let shape = array.shape.as_slice(); - let data = array - .data - .iter() - .map(|e| T::decode(&self.decrypt(e), &self.coder)) - .collect(); - ArrayD::from_shape_vec(shape, data).unwrap() - } -} - -impl Cipherblock { - pub fn agg_par<F, T, ID, OP>(&self, identity: ID, f: F, op: OP) -> T - where - F: Fn(T, &fixedpoint::CT) -> T + Send + Sync, - ID: Fn() -> T + Send + Sync, - OP: Fn(T, T) -> T + Send + Sync, - T: Send, - { - self.data - .par_iter() - .fold(&identity, f) - .reduce(&identity, op) - } - pub fn map_par<F>(&self, func: F) -> Cipherblock - where - F: Fn(&fixedpoint::CT) -> fixedpoint::CT + Sync + Send, - { - Cipherblock { - pk: self.pk.clone(), - data: self.data.par_iter().map(func).collect(), - shape: self.shape.clone(), - } - } - pub fn binary_cipherblock_cipherblock_par<F>( - lhs: &Cipherblock, - rhs: &Cipherblock, - func: F, - ) -> Cipherblock - where - F: Fn(&fixedpoint::CT, &fixedpoint::CT, &fixedpoint::PK) -> fixedpoint::CT + Sync, - { - assert_eq!(lhs.shape, rhs.shape); - assert_eq!(lhs.pk, rhs.pk); - let lhs_iter = lhs.data.par_iter(); - let rhs_iter = rhs.data.par_iter(); - let data: Vec<fixedpoint::CT> = lhs_iter - .zip(rhs_iter) - .map(|(l, r)| func(l, r, &lhs.pk)) - .collect(); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: lhs.shape.clone(), - } - } - pub fn binary_cipherblock_plaintext_par<F, T>( - lhs: &Cipherblock, - rhs: ArrayViewD<T>, - func: F, - ) -> Cipherblock - where - F: Fn(&fixedpoint::CT, &fixedpoint::PT, &fixedpoint::PK) -> fixedpoint::CT + Sync, - T: CouldCode + Sync + Send, - { - assert_eq!(lhs.shape, rhs.shape().to_vec()); - let lhs_iter = lhs.data.par_iter(); - let rhs_iter = rhs.as_slice().unwrap().into_par_iter(); - let data: Vec<fixedpoint::CT> = lhs_iter - .zip(rhs_iter) - .map(|(l, r)| func(l, &r.encode(&lhs.pk.coder), &lhs.pk)) - .collect(); - Cipherblock { - pk: lhs.pk.clone(), - data, - shape: lhs.shape.clone(), - } - } -} - -impl fixedpoint::PK { - pub fn encrypt_array_par<T>(&self, array: ArrayViewD<T>) -> Cipherblock - where - T: CouldCode + Send + Sync, - { - let shape = array.shape().to_vec(); - let data: Vec<fixedpoint::CT> = array - .into_par_iter() - .map(|e| self.encrypt(&e.encode(&self.coder), true)) - .collect(); - Cipherblock { - pk: self.clone(), - data, - shape, - } - } -} - -impl fixedpoint::SK { - pub fn decrypt_array_par<T>(&self, array: &Cipherblock) -> ArrayD<T> - where - T: CouldCode + Send, - { - let shape = array.shape.as_slice(); - let data = array - .data - .par_iter() - .map(|e| T::decode(&self.decrypt(e), &self.coder)) - .collect(); - ArrayD::from_shape_vec(shape, data).unwrap() - } -} diff --git a/rust/tensor/rust_paillier/src/cb.rs b/rust/tensor/rust_paillier/src/cb.rs deleted file mode 100644 index 56f12eef54..0000000000 --- a/rust/tensor/rust_paillier/src/cb.rs +++ /dev/null @@ -1,156 +0,0 @@ -use super::{block, fixedpoint, fixedpoint::CouldCode, Cipherblock, PK, SK}; -use ndarray::{ArrayD, ArrayView1, ArrayView2, 
ArrayViewD}; - -fn operation_with_arrayview_dyn<F, T>( - this: &Cipherblock, - other: ArrayViewD<T>, - func: F, -) -> Cipherblock -where - F: Fn(&block::Cipherblock, ArrayViewD<T>) -> block::Cipherblock, -{ - Cipherblock::new(func(this.unwrap(), other)) -} - -fn operation_with_cipherblock<F>(this: &Cipherblock, other: &Cipherblock, func: F) -> Cipherblock -where - F: Fn(&block::Cipherblock, &block::Cipherblock) -> block::Cipherblock, -{ - let a = this.unwrap(); - let b = other.unwrap(); - Cipherblock::new(func(a, b)) -} - -fn operation_with_scalar<F, T>(this: &Cipherblock, other: T, func: F) -> Cipherblock -where - F: Fn(&block::Cipherblock, T) -> block::Cipherblock, -{ - Cipherblock::new(func(this.unwrap(), other)) -} - -macro_rules! impl_ops_cipher_scalar { - ($name:ident,$fn:expr) => { - pub fn $name(&self, other: &fixedpoint::CT) -> Cipherblock { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map(lhs, |c| $fn(c, rhs, &lhs.pk)) - }) - } - }; -} -macro_rules! impl_ops_plaintext_scalar { - ($name:ident,$fn:expr) => { - pub fn $name<T>(&self, other: T) -> Cipherblock - where - T: CouldCode, - { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map(lhs, |c| $fn(c, &rhs.encode(&lhs.pk.coder), &lhs.pk)) - }) - } - }; -} -macro_rules! impl_ops_cipher { - ($name:ident,$fn:expr) => { - pub fn $name(&self, other: &Cipherblock) -> Cipherblock { - operation_with_cipherblock(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_cipherblock(lhs, rhs, $fn) - }) - } - }; -} -macro_rules! impl_ops_plain { - ($name:ident,$fn:expr) => { - pub fn $name<T>(&self, other: ArrayViewD<T>) -> Cipherblock - where - T: fixedpoint::CouldCode, - { - operation_with_arrayview_dyn(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_plaintext(lhs, rhs, $fn) - }) - } - }; -} -macro_rules! 
impl_ops_matmul { - ($name:ident, $fn:expr, $oty:ident) => { - pub fn $name<T: CouldCode + Sync>(&self, other: $oty<T>) -> Cipherblock { - Cipherblock::new($fn(self.unwrap(), other)) - } - }; -} -impl Cipherblock { - fn new(cb: block::Cipherblock) -> Cipherblock { - Cipherblock(Some(cb)) - } - fn unwrap(&self) -> &block::Cipherblock { - self.0.as_ref().unwrap() - } - impl_ops_cipher!(add_cb, fixedpoint::CT::add); - impl_ops_plain!(add_plaintext, fixedpoint::CT::add_pt); - impl_ops_cipher_scalar!(add_cipher_scalar, fixedpoint::CT::add); - impl_ops_plaintext_scalar!(add_plaintext_scalar, fixedpoint::CT::add_pt); - impl_ops_cipher!(sub_cb, fixedpoint::CT::sub); - impl_ops_plain!(sub_plaintext, fixedpoint::CT::sub_pt); - impl_ops_cipher_scalar!(sub_cipher_scalar, fixedpoint::CT::sub); - impl_ops_plaintext_scalar!(sub_plaintext_scalar, fixedpoint::CT::sub_pt); - impl_ops_plain!(mul_plaintext, fixedpoint::CT::mul); - impl_ops_plaintext_scalar!(mul_plaintext_scalar, fixedpoint::CT::mul); - - // matmul - impl_ops_matmul!( - matmul_plaintext_ix1, - block::Cipherblock::matmul_plaintext_ix1, - ArrayView1 - ); - impl_ops_matmul!( - rmatmul_plaintext_ix1, - block::Cipherblock::rmatmul_plaintext_ix1, - ArrayView1 - ); - impl_ops_matmul!( - matmul_plaintext_ix2, - block::Cipherblock::matmul_plaintext_ix2, - ArrayView2 - ); - impl_ops_matmul!( - rmatmul_plaintext_ix2, - block::Cipherblock::rmatmul_plaintext_ix2, - ArrayView2 - ); -} - -impl Cipherblock { - pub fn sum_cb(&self) -> Cipherblock { - let cb = self.unwrap(); - let sum = cb.agg(fixedpoint::CT::zero(), |s, c| s.add(c, &cb.pk)); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![sum], - shape: vec![1], - }) - } - pub fn mean_cb(&self) -> Cipherblock { - let cb = self.unwrap(); - let (s, n) = cb.agg((fixedpoint::CT::zero(), 0usize), |s, c| { - (s.0.add(c, &cb.pk), s.1 + 1) - }); - let mean = s.mul(&(1.0f64 / (n as f64)).encode(&cb.pk.coder), &cb.pk); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![mean], - shape: vec![1], - }) - } -} - -impl SK { - pub fn decrypt_array<T: CouldCode + numpy::Element>(&self, a: &Cipherblock) -> ArrayD<T> { - let array = a.0.as_ref().unwrap(); - self.as_ref().decrypt_array(array) - } -} - -impl PK { - pub fn encrypt_array<T: CouldCode>(&self, array: ArrayViewD<T>) -> Cipherblock { - Cipherblock::new(self.as_ref().encrypt_array(array)) - } -} diff --git a/rust/tensor/rust_paillier/src/lib.rs b/rust/tensor/rust_paillier/src/lib.rs deleted file mode 100644 index dfeb5116c8..0000000000 --- a/rust/tensor/rust_paillier/src/lib.rs +++ /dev/null @@ -1,316 +0,0 @@ -pub mod block; -pub mod cb; -pub mod fixedpoint; -pub mod math; -pub mod paillier; -mod par; - -use bincode::{deserialize, serialize}; -use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; -use pyo3::exceptions::PyTypeError; -use pyo3::prelude::*; -use pyo3::types::PyBytes; - -/// cipherblock contains ciphertexts and pubkey -/// -/// we need `new` method with zero argument (Option::None) -/// for unpickle to work. 
-#[pyclass(module = "rust_paillier")] -pub struct Cipherblock(Option<block::Cipherblock>); - -#[pyclass(module = "rust_paillier")] -pub struct PK { - pk: Option<fixedpoint::PK>, -} -impl PK { - fn new(pk: fixedpoint::PK) -> Self { - Self { pk: Some(pk) } - } - fn as_ref(&self) -> &fixedpoint::PK { - self.pk.as_ref().unwrap() - } -} - -#[pyclass(module = "rust_paillier")] -pub struct SK { - sk: Option<fixedpoint::SK>, -} - -impl SK { - fn new(sk: fixedpoint::SK) -> Self { - Self { sk: Some(sk) } - } - fn as_ref(&self) -> &fixedpoint::SK { - self.sk.as_ref().unwrap() - } -} - -#[pyfunction] -fn keygen(bit_size: u32) -> (PK, SK) { - let (sk, pk) = fixedpoint::keygen(bit_size); - (PK::new(pk), SK::new(sk)) -} - -/// public key for paillier system used to encrypt arrays -/// -/// Notes: we could not use Generics Types or rule macro here, sad. -#[pymethods] -impl PK { - #[new] - fn __new__() -> Self { - Self { pk: None } - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.pk = Some(deserialize(s.as_bytes()).unwrap()); - Ok(()) - } - Err(e) => Err(e), - } - } - pub fn __richcmp__(&self, other: &PK, cmp: pyo3::basic::CompareOp) -> PyResult<bool> { - match cmp { - pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), - _ => Err(PyTypeError::new_err( - "not supported between instances PK and PK", - )), - } - } - fn encrypt_f64(&self, a: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_f32(&self, a: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_i64(&self, a: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_i32(&self, a: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } -} - -/// secret key for paillier system used to encrypt arrays -/// -/// Notes: we could not use Generics Types or rule macro here, sad. -#[pymethods] -impl SK { - #[new] - fn __new__() -> Self { - Self { sk: None } - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.sk = Some(deserialize(s.as_bytes()).unwrap()); - Ok(()) - } - Err(e) => Err(e), - } - } - pub fn __richcmp__(&self, other: &SK, cmp: pyo3::basic::CompareOp) -> PyResult<bool> { - match cmp { - pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), - _ => Err(PyTypeError::new_err( - "not supported between instances PK and PK", - )), - } - } - fn decrypt_f64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<f64> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_f32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<f32> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_i64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<i64> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_i32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<i32> { - self.decrypt_array(a).into_pyarray(py) - } -} - -/// methods for cipherblock -/// -/// Notes: we could not use Generics Types or rule macro here, sad. 
-#[pymethods] -impl Cipherblock { - #[new] - fn __new__() -> Self { - Cipherblock(None) - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.0 = deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - #[getter] - pub fn shape(&self) -> Vec<usize> { - self.0.as_ref().map(|cb| cb.shape.clone()).unwrap() - } - // add - pub fn add_cipherblock(&self, other: &Cipherblock) -> Cipherblock { - self.add_cb(other) - } - pub fn add_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { - self.add_plaintext_scalar(other) - } - - // sub - pub fn sub_cipherblock(&self, other: &Cipherblock) -> Cipherblock { - self.sub_cb(other) - } - pub fn sub_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - - // mul - pub fn mul_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_i32(&self, other: i32) -> 
Cipherblock { - self.mul_plaintext_scalar(other) - } - - // matmul - pub fn matmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2<f64>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2<f32>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2<i64>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2<i32>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2<f64>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2<f32>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2<i64>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2<i32>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1<f64>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1<f32>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1<i64>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1<i32>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1<f64>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1<f32>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1<i64>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1<i32>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - - // agg - pub fn sum(&self) -> Cipherblock { - self.sum_cb() - } - pub fn mean(&self) -> Cipherblock { - self.sum_cb() - } -} -#[pymodule] -fn rust_paillier(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::<Cipherblock>()?; - m.add_class::<PK>()?; - m.add_class::<SK>()?; - m.add_function(wrap_pyfunction!(keygen, m)?)?; - - par::register(_py, m)?; - Ok(()) -} diff --git a/rust/tensor/rust_paillier/src/par/cb.rs b/rust/tensor/rust_paillier/src/par/cb.rs deleted file mode 100644 index 38a8193a5d..0000000000 --- a/rust/tensor/rust_paillier/src/par/cb.rs +++ /dev/null @@ -1,164 +0,0 @@ -use super::{block, fixedpoint, fixedpoint::CouldCode, Cipherblock, PK, SK}; -use ndarray::{ArrayD, ArrayView1, ArrayView2, ArrayViewD}; - -fn operation_with_arrayview_dyn<F, T>( - this: &Cipherblock, - other: ArrayViewD<T>, - func: F, -) -> Cipherblock -where - F: Fn(&block::Cipherblock, ArrayViewD<T>) -> block::Cipherblock, -{ - Cipherblock::new(func(this.unwrap(), other)) -} - -fn operation_with_cipherblock<F>(this: &Cipherblock, other: &Cipherblock, func: F) -> Cipherblock -where - F: Fn(&block::Cipherblock, &block::Cipherblock) -> block::Cipherblock, -{ - let a = this.unwrap(); - let b = other.unwrap(); - Cipherblock::new(func(a, b)) -} - -fn 
operation_with_scalar<F, T>(this: &Cipherblock, other: T, func: F) -> Cipherblock -where - F: Fn(&block::Cipherblock, T) -> block::Cipherblock, -{ - Cipherblock::new(func(this.unwrap(), other)) -} - -macro_rules! impl_ops_cipher_scalar { - ($name:ident,$fn:expr) => { - pub fn $name(&self, other: &fixedpoint::CT) -> Cipherblock { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map_par(lhs, |c| $fn(c, rhs, &lhs.pk)) - }) - } - }; -} -macro_rules! impl_ops_plaintext_scalar { - ($name:ident,$fn:expr) => { - pub fn $name<T>(&self, other: T) -> Cipherblock - where - T: CouldCode + Sync, - { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map_par(lhs, |c| $fn(c, &rhs.encode(&lhs.pk.coder), &lhs.pk)) - }) - } - }; -} -macro_rules! impl_ops_cipher { - ($name:ident,$fn:expr) => { - pub fn $name(&self, other: &Cipherblock) -> Cipherblock { - operation_with_cipherblock(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_cipherblock_par(lhs, rhs, $fn) - }) - } - }; -} -macro_rules! impl_ops_plain { - ($name:ident,$fn:expr) => { - pub fn $name<T>(&self, other: ArrayViewD<T>) -> Cipherblock - where - T: fixedpoint::CouldCode + Sync + Send, - { - operation_with_arrayview_dyn(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_plaintext_par(lhs, rhs, $fn) - }) - } - }; -} -macro_rules! impl_ops_matmul { - ($name:ident, $fn:expr, $oty:ident) => { - pub fn $name<T: CouldCode + Sync>(&self, other: $oty<T>) -> Cipherblock { - Cipherblock::new($fn(self.unwrap(), other)) - } - }; -} -impl Cipherblock { - fn new(cb: block::Cipherblock) -> Cipherblock { - Cipherblock(Some(cb)) - } - fn unwrap(&self) -> &block::Cipherblock { - self.0.as_ref().unwrap() - } - impl_ops_cipher!(add_cb, fixedpoint::CT::add); - impl_ops_plain!(add_plaintext, fixedpoint::CT::add_pt); - impl_ops_cipher_scalar!(add_cipher_scalar, fixedpoint::CT::add); - impl_ops_plaintext_scalar!(add_plaintext_scalar, fixedpoint::CT::add_pt); - - impl_ops_cipher!(sub_cb, fixedpoint::CT::sub); - impl_ops_plain!(sub_plaintext, fixedpoint::CT::sub_pt); - impl_ops_cipher_scalar!(sub_cipher_scalar, fixedpoint::CT::add); - impl_ops_plaintext_scalar!(sub_plaintext_scalar, fixedpoint::CT::sub_pt); - - impl_ops_plain!(mul_plaintext, fixedpoint::CT::mul); - impl_ops_plaintext_scalar!(mul_plaintext_scalar, fixedpoint::CT::mul); - - // matmul - impl_ops_matmul!( - matmul_plaintext_ix1, - block::Cipherblock::matmul_plaintext_ix1_par, - ArrayView1 - ); - impl_ops_matmul!( - rmatmul_plaintext_ix1, - block::Cipherblock::rmatmul_plaintext_ix1_par, - ArrayView1 - ); - impl_ops_matmul!( - matmul_plaintext_ix2, - block::Cipherblock::matmul_plaintext_ix2_par, - ArrayView2 - ); - impl_ops_matmul!( - rmatmul_plaintext_ix2, - block::Cipherblock::rmatmul_plaintext_ix2_par, - ArrayView2 - ); -} - -impl Cipherblock { - pub fn sum_cb(&self) -> Cipherblock { - let cb = self.unwrap(); - let sum = cb.agg_par( - fixedpoint::CT::zero, - |s, c| s.add(c, &cb.pk), - |s1, s2| s1.add(&s2, &cb.pk), - ); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![sum], - shape: vec![1], - }) - } - pub fn mean_cb(&self) -> Cipherblock { - let cb = self.unwrap(); - let (s, n) = cb.agg_par( - || (fixedpoint::CT::zero(), 0usize), - |s, c| (s.0.add(c, &cb.pk), s.1 + 1), - |s1, s2| (s1.0.add(&s2.0, &cb.pk), s1.1 + s2.1), - ); - let mean = s.mul(&(1.0f64 / (n as f64)).encode(&cb.pk.coder), &cb.pk); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![mean], - shape: vec![1], - }) - } -} - 
-impl SK { - pub fn decrypt_array<T: CouldCode + numpy::Element>(&self, a: &Cipherblock) -> ArrayD<T> { - let array = a.0.as_ref().unwrap(); - self.as_ref().decrypt_array_par(array) - } -} - -impl PK { - pub fn encrypt_array<T: CouldCode + Sync + Send>(&self, array: ArrayViewD<T>) -> Cipherblock { - Cipherblock::new(self.as_ref().encrypt_array_par(array)) - } -} diff --git a/rust/tensor/rust_paillier/src/par/mod.rs b/rust/tensor/rust_paillier/src/par/mod.rs deleted file mode 100644 index 3c85d1bbd2..0000000000 --- a/rust/tensor/rust_paillier/src/par/mod.rs +++ /dev/null @@ -1,311 +0,0 @@ -use crate::block; -use crate::fixedpoint; -use bincode::{deserialize, serialize}; -use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; -use pyo3::exceptions::PyTypeError; -use pyo3::prelude::*; -use pyo3::types::PyBytes; - -mod cb; - -#[pyclass(module = "rust_paillier.par")] -pub struct Cipherblock(Option<block::Cipherblock>); - -#[pyclass(module = "rust_paillier.par")] -pub struct PK { - pk: Option<fixedpoint::PK>, -} -impl PK { - fn new(pk: fixedpoint::PK) -> Self { - Self { pk: Some(pk) } - } - fn as_ref(&self) -> &fixedpoint::PK { - self.pk.as_ref().unwrap() - } -} - -#[pyclass(module = "rust_paillier.par")] -pub struct SK { - sk: Option<fixedpoint::SK>, -} - -impl SK { - fn new(sk: fixedpoint::SK) -> Self { - Self { sk: Some(sk) } - } - fn as_ref(&self) -> &fixedpoint::SK { - self.sk.as_ref().unwrap() - } -} - -#[pyfunction] -fn keygen(bit_size: u32) -> (PK, SK) { - let (sk, pk) = fixedpoint::keygen(bit_size); - (PK::new(pk), SK::new(sk)) -} - -#[pyfunction] -fn set_num_threads(num_threads: usize) { - rayon::ThreadPoolBuilder::new() - .num_threads(num_threads) - .build_global() - .unwrap(); -} - -#[pymethods] -impl PK { - #[new] - fn __new__() -> Self { - Self { pk: None } - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.pk = Some(deserialize(s.as_bytes()).unwrap()); - Ok(()) - } - Err(e) => Err(e), - } - } - pub fn __richcmp__(&self, other: &PK, cmp: pyo3::basic::CompareOp) -> PyResult<bool> { - match cmp { - pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), - _ => Err(PyTypeError::new_err( - "not supported between instances PK and PK", - )), - } - } - fn encrypt_f64(&self, a: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_f32(&self, a: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_i64(&self, a: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } - fn encrypt_i32(&self, a: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.encrypt_array(a.as_array()) - } -} - -#[pymethods] -impl SK { - #[new] - fn __new__() -> Self { - Self { sk: None } - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.sk = Some(deserialize(s.as_bytes()).unwrap()); - Ok(()) - } - Err(e) => Err(e), - } - } - pub fn __richcmp__(&self, other: &SK, cmp: pyo3::basic::CompareOp) -> PyResult<bool> { - match cmp { - pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), - _ => 
Err(PyTypeError::new_err( - "not supported between instances PK and PK", - )), - } - } - fn decrypt_f64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<f64> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_f32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<f32> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_i64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<i64> { - self.decrypt_array(a).into_pyarray(py) - } - fn decrypt_i32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn<i32> { - self.decrypt_array(a).into_pyarray(py) - } -} - -#[pymethods] -impl Cipherblock { - #[new] - fn __new__() -> Self { - Cipherblock(None) - } - pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> { - Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.0 = deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } - } - #[getter] - pub fn shape(&self) -> Vec<usize> { - self.0.as_ref().map(|cb| cb.shape.clone()).unwrap() - } - // add - pub fn add_cipherblock(&self, other: &Cipherblock) -> Cipherblock { - self.add_cb(other) - } - pub fn add_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.add_plaintext(other.as_array()) - } - pub fn add_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.add_plaintext_scalar(other) - } - pub fn add_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { - self.add_plaintext_scalar(other) - } - - // sub - pub fn sub_cipherblock(&self, other: &Cipherblock) -> Cipherblock { - self.sub_cb(other) - } - pub fn sub_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.sub_plaintext(other.as_array()) - } - pub fn sub_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - pub fn sub_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { - self.sub_plaintext_scalar(other) - } - - // mul - pub fn mul_plaintext_f64(&self, other: PyReadonlyArrayDyn<f64>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_f32(&self, other: PyReadonlyArrayDyn<f32>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn 
mul_plaintext_i64(&self, other: PyReadonlyArrayDyn<i64>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_i32(&self, other: PyReadonlyArrayDyn<i32>) -> Cipherblock { - self.mul_plaintext(other.as_array()) - } - pub fn mul_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - pub fn mul_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { - self.mul_plaintext_scalar(other) - } - - // matmul - pub fn matmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2<f64>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2<f32>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2<i64>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2<i32>) -> Cipherblock { - self.matmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2<f64>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2<f32>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2<i64>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn rmatmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2<i32>) -> Cipherblock { - self.rmatmul_plaintext_ix2(other.as_array()) - } - pub fn matmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1<f64>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1<f32>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1<i64>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn matmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1<i32>) -> Cipherblock { - self.matmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1<f64>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1<f32>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1<i64>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - pub fn rmatmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1<i32>) -> Cipherblock { - self.rmatmul_plaintext_ix1(other.as_array()) - } - // agg - pub fn sum(&self) -> Cipherblock { - self.sum_cb() - } - pub fn mean(&self) -> Cipherblock { - self.sum_cb() - } -} - -pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { - let submodule_par = PyModule::new(py, "par")?; - submodule_par.add_function(wrap_pyfunction!(keygen, submodule_par)?)?; - submodule_par.add_function(wrap_pyfunction!(set_num_threads, submodule_par)?)?; - submodule_par.add_class::<Cipherblock>()?; - submodule_par.add_class::<PK>()?; - submodule_par.add_class::<SK>()?; - m.add_submodule(submodule_par)?; - py.import("sys")? - .getattr("modules")? 
- .set_item("rust_paillier.par", submodule_par)?; - Ok(()) -} diff --git a/schemas/components/feature_scale.yaml b/schemas/components/feature_scale.yaml deleted file mode 100644 index b37b2512a6..0000000000 --- a/schemas/components/feature_scale.yaml +++ /dev/null @@ -1,67 +0,0 @@ -component: - name: feature_scale - description: '' - provider: fate - version: 2.0.0.alpha - labels: [] - roles: - - guest - - host - input_definitions: - parameters: - method: - type: str - default: standard - optional: false - artifacts: - train_data: - type: dataset - optional: false - stages: - - train - roles: - - guest - - host - input_model: - type: model - optional: false - stages: - - predict - roles: - - guest - - host - test_data: - type: dataset - optional: false - stages: - - predict - roles: - - guest - - host - output_definitions: - artifacts: - train_output_data: - type: dataset - optional: false - stages: - - train - roles: - - guest - - host - output_model: - type: model - optional: false - stages: - - train - roles: - - guest - - host - test_output_data: - type: dataset - optional: false - stages: - - predict - roles: - - guest - - host -schema_version: v1 diff --git a/schemas/components/intersection.yaml b/schemas/components/intersection.yaml deleted file mode 100644 index d62a5e4389..0000000000 --- a/schemas/components/intersection.yaml +++ /dev/null @@ -1,35 +0,0 @@ -component: - name: intersection - description: '' - provider: fate - version: 2.0.0.alpha - labels: [] - roles: - - guest - - host - input_definitions: - parameters: - method: - type: str - default: raw - optional: true - artifacts: - input_data: - type: dataset - optional: false - stages: - - default - roles: - - guest - - host - output_definitions: - artifacts: - output_data: - type: dataset - optional: false - stages: - - default - roles: - - guest - - host -schema_version: v1 diff --git a/schemas/components/lr.yaml b/schemas/components/lr.yaml deleted file mode 100644 index 792a55626f..0000000000 --- a/schemas/components/lr.yaml +++ /dev/null @@ -1,108 +0,0 @@ -component: - name: hetero_lr - description: '' - provider: fate - version: 2.0.0-alpha - labels: [] - roles: - - guest - - host - - arbiter - input_definitions: - parameters: - learning_rate: - type: ConFloat - default: 0.1 - optional: true - description: learning rate - type_meta: - gt: 0.0 - max_iter: - type: ConInt - default: 100 - optional: true - description: max iteration num - type_meta: - gt: 0 - batch_size: - type: ConInt - default: 100 - optional: true - description: batch size, value less or equals to 0 means full batch - type_meta: {} - artifacts: - validate_data: - type: dataset - optional: true - stages: - - train - roles: - - guest - - host - description: validation data - train_data: - type: dataset - optional: false - stages: - - train - roles: - - guest - - host - description: training data - test_data: - type: dataset - optional: false - stages: - - predict - roles: - - guest - - host - description: '' - input_model: - type: model - optional: false - stages: - - predict - roles: - - guest - - host - description: '' - output_definitions: - artifacts: - train_output_data: - type: dataset - optional: false - stages: - - train - roles: - - guest - - host - description: '' - test_output_data: - type: dataset - optional: false - stages: - - predict - roles: - - guest - - host - description: '' - output_model: - type: model - optional: false - stages: - - train - roles: - - guest - - host - description: '' - train_output_metric: - type: loss - optional: 
false - stages: - - train - roles: - - arbiter - description: '' -schema_version: v1 - diff --git a/schemas/components/reader.yaml b/schemas/components/reader.yaml deleted file mode 100644 index bc5334594b..0000000000 --- a/schemas/components/reader.yaml +++ /dev/null @@ -1,51 +0,0 @@ -component: - name: reader - description: '' - provider: fate - version: 2.0.0.alpha - labels: [] - roles: - - guest - - host - input_definitions: - parameters: - path: - type: str - default: - optional: false - format: - type: str - default: csv - optional: false - id_name: - type: str - default: id - optional: true - delimiter: - type: str - default: ',' - optional: true - label_name: - type: str - default: - optional: true - label_type: - type: str - default: float32 - optional: true - dtype: - type: str - default: float32 - optional: true - artifacts: {} - output_definitions: - artifacts: - output_data: - type: dataset - optional: false - stages: - - default - roles: - - guest - - host -schema_version: v1 diff --git a/schemas/jobs/training_dag.yaml b/schemas/jobs/training_dag.yaml deleted file mode 100644 index 3185785a96..0000000000 --- a/schemas/jobs/training_dag.yaml +++ /dev/null @@ -1,124 +0,0 @@ -dag: - parties: - - party_id: ['9999'] - role: guest - - party_id: ['9999'] - role: host - - party_id: ['9999'] - role: arbiter - party_tasks: - guest_9999: - parties: - - party_id: ['9999'] - role: guest - tasks: - reader_0: - inputs: - parameters: {delimiter: ',', dtype: float32, format: csv, id_name: id, - label_name: y, label_type: float32, path: 'file://${abs_path_of_data_guest}'} - host_9999: - parties: - - party_id: ['9999'] - role: host - tasks: - reader_0: - inputs: - parameters: {delimiter: ',', dtype: float32, format: csv, id_name: id, - label_name: null, path: 'file://${abs_path_of_data_host}'} - stage: train - tasks: - evaluation_0: - component_ref: evaluation - dependent_tasks: [lr_0] - inputs: - artifacts: - input_data: - task_output_artifact: - output_artifact_key: train_output_data - producer_task: lr_0 - roles: [guest] - parties: - - party_id: ['9999'] - role: guest - stage: default - feature_scale_0: - component_ref: feature_scale - dependent_tasks: [intersection_0] - inputs: - artifacts: - train_data: - task_output_artifact: {output_artifact_key: output_data, producer_task: intersection_0} - parameters: {method: standard} - parties: - - party_id: ['9999'] - role: guest - - party_id: ['9999'] - role: host - feature_scale_1: - component_ref: feature_scale - dependent_tasks: [intersection_1, feature_scale_0] - inputs: - artifacts: - input_model: - task_output_artifact: {output_artifact_key: output_model, producer_task: feature_scale_0} - test_data: - task_output_artifact: {output_artifact_key: output_data, producer_task: intersection_1} - parties: - - party_id: ['9999'] - role: guest - - party_id: ['9999'] - role: host - stage: predict - intersection_0: - component_ref: intersection - dependent_tasks: [reader_0] - inputs: - artifacts: - input_data: - task_output_artifact: {output_artifact_key: output_data, producer_task: reader_0} - parameters: {method: raw} - parties: - - party_id: ['9999'] - role: guest - - party_id: ['9999'] - role: host - stage: default - intersection_1: - component_ref: intersection - dependent_tasks: [reader_0] - inputs: - artifacts: - input_data: - task_output_artifact: {output_artifact_key: output_data, producer_task: reader_0} - parameters: {method: raw} - parties: - - party_id: ['9999'] - role: guest - - party_id: ['9999'] - role: host - stage: default - lr_0: - 
      component_ref: hetero_lr
-      dependent_tasks: [feature_scale_0, feature_scale_1]
-      inputs:
-        artifacts:
-          train_data:
-            task_output_artifact:
-              output_artifact_key: train_output_data
-              producer_task: feature_scale_0
-              roles: [guest, host]
-          validate_data:
-            task_output_artifact:
-              output_artifact_key: test_output_data
-              producer_task: feature_scale_1
-              roles: [guest, host]
-      parameters: {batch_size: 100, learning_rate: 0.01, max_iter: 1}
-    reader_0:
-      component_ref: reader
-      parties:
-      - party_id: ['9999']
-        role: guest
-      - party_id: ['9999']
-        role: host
-      stage: default
-schema_version: 2.0.0.alpha
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e5a0d9b483..0000000000
--- a/setup.py
+++ /dev/null
@@ -1 +0,0 @@
-#!/usr/bin/env python3
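
A minimal usage sketch of the Python API removed above: the deleted lib.rs exposes a PyO3 module `rust_paillier` with `keygen`, `PK.encrypt_*`, `SK.decrypt_*`, and elementwise/matmul methods on `Cipherblock`, while the deleted par/mod.rs registers a rayon-backed `rust_paillier.par` submodule with the same surface plus `set_num_threads`. The sketch assumes a wheel built from the pre-removal sources (for example with maturin) is importable; the 1024-bit key size and the array shapes are illustrative only.

# Sketch of the removed rust_paillier Python API, assuming a wheel built from
# the pre-removal sources is installed. Key size and shapes are illustrative.
import numpy as np
import rust_paillier
from rust_paillier import par

pk, sk = rust_paillier.keygen(1024)          # returns (PK, SK)

a = np.random.rand(3, 4)                     # float64 arrays map to the *_f64 methods
b = np.random.rand(3, 4)
ea = pk.encrypt_f64(a)                       # Cipherblock with shape [3, 4]
eb = pk.encrypt_f64(b)

es = ea.add_cipherblock(eb)                  # ciphertext + ciphertext
es = es.mul_plaintext_scalar_f64(0.5)        # ciphertext * plaintext scalar
ep = ea.matmul_plaintext_ix2_f64(np.random.rand(4, 2))  # (3x4) @ (4x2) -> shape [3, 2]

print(es.shape, ep.shape)                    # shape getter returns a list
print(sk.decrypt_f64(es))                    # decrypt back to an ndarray

# The rayon-backed variant lives in the rust_paillier.par submodule.
par.set_num_threads(4)
ppk, psk = par.keygen(1024)
pe = ppk.encrypt_f64(a).sum()                # parallel aggregation; sum() returns shape [1]
print(psk.decrypt_f64(pe))

As the doc comments in the deleted sources note, `PK`, `SK`, and `Cipherblock` wrap `Option<...>` so that the zero-argument `__new__` required for unpickling can construct an empty object that `__setstate__` later fills from the serialized bytes.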