diff --git a/jointContribution/AI_Climate_Diseases/ERA5_land(1).ipynb b/jointContribution/AI_Climate_Diseases/ERA5_land(1).ipynb new file mode 100644 index 0000000000..f2e4bd63f9 --- /dev/null +++ b/jointContribution/AI_Climate_Diseases/ERA5_land(1).ipynb @@ -0,0 +1,1608 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "77f152bb34f64833ad4f9e10337992f6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_79cf5f4c819246149f0b3b7cfaf03b5b", + "IPY_MODEL_c142db7f1d6c4222b3811342b46b0c30", + "IPY_MODEL_485aa27dcfbb4cb9a61da98a8ad277c1" + ], + "layout": "IPY_MODEL_652a7f72be404e2580fb4592348330cb" + } + }, + "79cf5f4c819246149f0b3b7cfaf03b5b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_76fd9553926040f5a108afe18be359b8", + "placeholder": "​", + "style": "IPY_MODEL_492b0421e1cc4ee2b02d24910fbbd367", + "value": "e5dc627d8a097bec79906e75846db42e.zip:  90%" + } + }, + "c142db7f1d6c4222b3811342b46b0c30": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4065f75702ea4b30b25b3e0adc6d1cf4", + "max": 4672516, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8d4b92424bde4065b92126463c3a88cf", + "value": 4672516 + } + }, + "485aa27dcfbb4cb9a61da98a8ad277c1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_da3c858f56284927907e863b3f7bff15", + "placeholder": "​", + "style": "IPY_MODEL_17a5e81074e042208e7c41a6fe72d705", + "value": " 4.00M/4.46M [00:01<00:00, 4.23MB/s]" + } + }, + "652a7f72be404e2580fb4592348330cb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "76fd9553926040f5a108afe18be359b8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "492b0421e1cc4ee2b02d24910fbbd367": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + 
"model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4065f75702ea4b30b25b3e0adc6d1cf4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8d4b92424bde4065b92126463c3a88cf": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "bar_color": null, + "description_width": "" + } + }, + "da3c858f56284927907e863b3f7bff15": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "17a5e81074e042208e7c41a6fe72d705": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "7638bda1ed3e4c66a4a47c3ff6a8edea": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b30de5c069a4404ab2f48af8feef79cc", + "IPY_MODEL_23894ce319a141d19310e4c02124bd38", + "IPY_MODEL_2bf99ea9822d41d39a8fbd324bb8b3e5" + ], + "layout": "IPY_MODEL_3a5456e2d86a48be853f4552be673819" + } + }, + "b30de5c069a4404ab2f48af8feef79cc": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2f8e3c66060b4ef1b57aa2ea97259b7b", + "placeholder": "​", + "style": "IPY_MODEL_daf5ce1cd8e44d47aa090b2693b83d10", + "value": "76d9dacc7921e0bd06d9383480d75e62.zip:  89%" + } + }, + "23894ce319a141d19310e4c02124bd38": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_701ad51f4d414a3fbaa9c9eb5e448785", + "max": 4688001, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5b81063a9d2f46d69a3d138704717817", + "value": 4688001 + } + }, + "2bf99ea9822d41d39a8fbd324bb8b3e5": { + "model_module": "@jupyter-widgets/controls", + "model_name": 
"HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_87ca252dfca34361ab50ea1809856703", + "placeholder": "​", + "style": "IPY_MODEL_db01f8756847498482428b5460105a6b", + "value": " 4.00M/4.47M [00:01<00:00, 4.38MB/s]" + } + }, + "3a5456e2d86a48be853f4552be673819": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": null + } + }, + "2f8e3c66060b4ef1b57aa2ea97259b7b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + 
"_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "daf5ce1cd8e44d47aa090b2693b83d10": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "701ad51f4d414a3fbaa9c9eb5e448785": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, 
+ "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5b81063a9d2f46d69a3d138704717817": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "87ca252dfca34361ab50ea1809856703": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": 
null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "db01f8756847498482428b5460105a6b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3tsuE-BfSlWm", + "outputId": "d1bd225c-8951-4b46-aa4d-29288258c515" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting cdsapi\n", + " Downloading cdsapi-0.7.6-py2.py3-none-any.whl.metadata (3.0 kB)\n", + "Collecting ecmwf-datastores-client (from cdsapi)\n", + " Downloading ecmwf_datastores_client-0.4.0-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: requests>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from cdsapi) (2.32.4)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from cdsapi) (4.67.1)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.5.0->cdsapi) (3.4.3)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages 
import getpass
import os
import pathlib
import textwrap

import cdsapi

# SECURITY: a personal CDS API token must never be hard-coded in a notebook
# that is committed to version control. Any token that has ever been
# committed should be regenerated from the CDS profile page.
# The token is taken from the CDSAPI_KEY environment variable when it is set;
# otherwise the user is prompted for it without echoing the input.
CDS_API_URL = "https://cds.climate.copernicus.eu/api"
api_key = (os.environ.get("CDSAPI_KEY") or getpass.getpass("CDS API key: ")).strip()
if not api_key:
    raise ValueError("A CDS API key is required (set CDSAPI_KEY or enter it at the prompt).")

cfg = textwrap.dedent(f"""\
    url: {CDS_API_URL}
    key: {api_key}
    """)

path = pathlib.Path("~/.cdsapirc").expanduser()
# Restrict the file to its owner (mode 600) *before* the secret is written.
# touch() only applies the mode when it creates the file, so chmod() is also
# called to cover a pre-existing file. This is plain Python, so it works
# outside IPython too (the previous `!chmod` shell escape did not).
path.touch(mode=0o600, exist_ok=True)
path.chmod(0o600)
path.write_text(cfg)

print("~/.cdsapirc 写入完成")

# Smoke test: constructing the client fails early if the configuration
# written above cannot be used.
c = cdsapi.Client()
print("CDS client OK")
# download_era5land_chunked_full_parallel_by_var.py
"""ERA5-Land downloader.

One CDS request is issued per (variable, year, month) chunk. Chunks belonging
to the same variable are downloaded sequentially inside one worker thread, and
several variables are processed concurrently.
"""
import argparse
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

DATASET = "reanalysis-era5-land"

# ---------- Time dimensions ----------
ALL_DAYS: List[str] = [f"{d:02d}" for d in range(1, 32)]
ALL_HOURS: List[str] = [f"{h:02d}:00" for h in range(0, 24)]
ALL_MONTHS: List[str] = [f"{m:02d}" for m in range(1, 13)]

# ---------- Retry settings ----------
MAX_RETRIES = 8
BASE_SLEEP = 10  # seconds

# Lower-case substrings of an error message that mark a failure as temporary.
# They are checked BEFORE the unrecoverable list because e.g.
# "503 Service Unavailable" also contains the substring "unavailable".
_TRANSIENT_SIGNALS = (
    "service unavailable",
    "temporarily unavailable",
    "bad gateway",
    "gateway timeout",
    "too many requests",
    "timed out",
)

# Lower-case substrings that mark a request as pointless to retry
# (invalid / unavailable variable, no data, malformed request).
_UNRECOVERABLE_SIGNALS = (
    "unavailable",
    "not available",
    "invalid",
    "does not match",
    "no data",
    "bad request",
    "cannot be found",
)

# ---------- Full variable set ----------
VARIABLES: List[str] = [
    # 2m/skin/soil/lake temps
    "2m_dewpoint_temperature",
    "2m_temperature",
    "skin_temperature",
    "soil_temperature_level_1",
    "soil_temperature_level_2",
    "soil_temperature_level_3",
    "soil_temperature_level_4",
    "lake_bottom_temperature",
    "lake_ice_depth",
    "lake_ice_temperature",
    "lake_mix_layer_depth",
    "lake_mix_layer_temperature",
    "lake_shape_factor",
    # Entries that used to be commented out (now enabled)
    "lake_total_layer_temperature",
    "snow_albedo",
    "snow_cover",
    "snow_density",
    "snow_depth",
    "snow_depth_water_equivalent",
    "snowfall",
    "snowmelt",
    "temperature_of_snow_layer",
    "forecast_albedo",
    "surface_latent_heat_flux",
    "surface_net_solar_radiation",
    "surface_net_thermal_radiation",
    "surface_sensible_heat_flux",
    "surface_solar_radiation_downwards",
    "surface_thermal_radiation_downwards",
    "evaporation_from_bare_soil",
    "evaporation_from_open_water_surfaces_excluding_oceans",
    "evaporation_from_the_top_of_canopy",
    "evaporation_from_vegetation_transpiration",
    "potential_evaporation",
    "runoff",
    "snow_evaporation",
    "sub_surface_runoff",
    "surface_runoff",
    "total_evaporation",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "surface_pressure",
    "total_precipitation",
    "leaf_area_index_high_vegetation",
    "leaf_area_index_low_vegetation",
    "high_vegetation_cover",
    "glacier_mask",
    "lake_cover",
    "low_vegetation_cover",
    "lake_total_depth",
    "land_sea_mask",
    "soil_type",
    "type_of_high_vegetation",
    "type_of_low_vegetation",
]


def build_request(
    variable: str,
    year: str,
    month: str,
    area_box: Tuple[float, float, float, float],
    fmt: str,
) -> dict:
    """Build the request dictionary for one (variable, year, month) chunk.

    Args:
        variable: Dataset variable name.
        year: Four-digit year as a string.
        month: Month as a string.
        area_box: Bounding box ordered (north, west, south, east).
        fmt: Value stored under the "format" key ("grib" or "netcdf").

    Returns:
        A dict covering every day (01..31) and every hour (00:00..23:00) of
        the month, with the result requested as a zip archive.
    """
    north, west, south, east = area_box
    # NOTE(review): the result format is sent under the key "format";
    # confirm against the current CDS request schema that this key is still
    # honoured for this dataset.
    return {
        "variable": variable,
        "year": year,
        "month": month,
        "day": ALL_DAYS,
        "time": ALL_HOURS,
        "area": [north, west, south, east],  # N W S E
        "format": fmt,  # "grib" | "netcdf"
        "download_format": "zip",
        # "product_type": "reanalysis",
    }


def safe_retrieve(
    client: "cdsapi.Client",
    dataset: str,
    request: dict,
    target_path: Path,
    *,
    max_retries: int = MAX_RETRIES,
    base_sleep: float = BASE_SLEEP,
    sleep=time.sleep,
) -> bool:
    """Download one chunk with exponential back-off and light jitter.

    The payload is first written to "<target>.part" and only renamed to
    ``target_path`` after the download call returned, so an interrupted run
    never leaves a truncated file under the final name.

    Args:
        client: Object exposing ``retrieve(dataset, request).download(path)``.
        dataset: Dataset name passed to ``client.retrieve``.
        request: Request dictionary passed to ``client.retrieve``.
        target_path: Final location of the downloaded file.
        max_retries: Number of retries allowed after the first failure.
        base_sleep: Base delay in seconds; doubled on every further failure.
        sleep: Callable used for waiting (injectable so tests need not wait).

    Returns:
        True when the file is in place, False when the error looked
        unrecoverable or the retries were exhausted.
    """
    target_path = Path(target_path)
    part_path = target_path.with_name(target_path.name + ".part")

    # Small random delay so concurrent workers do not all hit the service at once.
    sleep(random.uniform(0.3, 1.0))
    attempt = 0
    while True:
        try:
            client.retrieve(dataset, request).download(str(part_path))
            part_path.replace(target_path)
            return True
        except Exception as e:
            attempt += 1
            try:
                # Best-effort removal of a partial file; a failure here must
                # not hide the original download error.
                part_path.unlink(missing_ok=True)
            except OSError:
                pass

            msg = str(e).lower()
            transient = any(s in msg for s in _TRANSIENT_SIGNALS)
            # Clearly unrecoverable errors (invalid / unavailable variable,
            # no data) are skipped immediately instead of being retried.
            if not transient and any(s in msg for s in _UNRECOVERABLE_SIGNALS):
                print(f"[ERROR] Unrecoverable for {target_path.name}: {e}")
                return False

            if attempt > max_retries:
                print(f"[ERROR] Max retries exceeded for {target_path.name}: {e}")
                return False

            sleep_s = base_sleep * (2 ** (attempt - 1)) * random.uniform(0.85, 1.15)
            print(f"[WARN] Download failed (attempt {attempt}/{max_retries}): {e}")
            print(f"       Sleeping {sleep_s:.0f}s then retrying...")
            sleep(sleep_s)


def _running_in_notebook() -> bool:
    """Return True when an IPython kernel or Google Colab module is loaded."""
    return "ipykernel" in sys.modules or "google.colab" in sys.modules


def parse_args_with_defaults(argv: Optional[Sequence[str]] = None):
    """Parse downloader options; every option has a default.

    Args:
        argv: Explicit argument list. When None, an empty list is used inside
            a notebook (where ``sys.argv`` holds the kernel's own arguments)
            and ``sys.argv[1:]`` is used otherwise, so command-line flags
            take effect.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="ERA5-Land downloader (split by variable × year × month), parallel by VARIABLE"
    )
    parser.add_argument("--out_dir", type=str, default="./era5land",
                        help="输出根目录(默认 ./era5land)")
    parser.add_argument("--bbox", nargs=4, type=float,
                        default=[60.86, -6.23, 49.86, 1.75],
                        metavar=("NORTH", "WEST", "SOUTH", "EAST"),
                        help="经纬度范围:N W S E(默认 60.86 -6.23 49.86 1.75)")
    parser.add_argument("--format", default="grib", choices=["grib", "netcdf"],
                        help="文件格式(默认 grib)")
    parser.add_argument("--years", nargs="+",
                        default=[str(y) for y in range(1997, 2023)],  # 1997–2022
                        help="年份列表(默认 1997..2022)")
    parser.add_argument("--months", nargs="+", default=ALL_MONTHS,
                        help="月份列表(默认 01..12)")
    parser.add_argument("--variables", nargs="+", default=VARIABLES,
                        help="变量名列表(默认为脚本内置全集)")
    parser.add_argument("--skip_existing", action="store_true",
                        help="若目标文件已存在则跳过")
    parser.add_argument("--max_workers", type=int, default=3,
                        help="并发的变量数(建议 2–3)")

    if argv is None:
        argv = [] if _running_in_notebook() else sys.argv[1:]
    return parser.parse_args(list(argv))


def _is_complete(path: Path) -> bool:
    """Return True when ``path`` is an existing, non-empty regular file."""
    return path.is_file() and path.stat().st_size > 0


def download_one_variable(var: str, args) -> Tuple[str, int, int]:
    """Download every year × month chunk of one variable sequentially.

    Intended to run inside a single worker thread.

    Args:
        var: Variable name.
        args: Namespace providing out_dir, years, months, bbox, format and
            skip_existing.

    Returns:
        ``(var, ok, fail)`` where ok / fail count the chunks that succeeded
        (or were skipped as already present) and the chunks that failed.
    """
    # Imported here so that the pure helpers in this script can be imported
    # and tested on a machine where the cdsapi package is not installed.
    import cdsapi

    client = cdsapi.Client()  # one client per thread
    ok = 0
    fail = 0

    out_root = Path(args.out_dir)
    suffix = "grib" if args.format == "grib" else "nc"
    area_box = tuple(args.bbox)

    for year in args.years:
        subdir = out_root / var / str(year)
        subdir.mkdir(parents=True, exist_ok=True)

        for month in args.months:
            # Use the same zero-padded month for the file name and the
            # request so that "--months 1" and "--months 01" agree.
            month_str = f"{int(month):02d}"
            target_path = subdir / f"{DATASET}_{var}_{year}-{month_str}.{suffix}.zip"

            if args.skip_existing and _is_complete(target_path):
                # An existing non-empty file counts as success, which allows
                # an interrupted run to be resumed.
                ok += 1
                continue

            req = build_request(
                variable=var,
                year=str(year),
                month=month_str,
                area_box=area_box,
                fmt=args.format,
            )

            print(f"[INFO][{var}] Downloading {year}-{month_str} -> {target_path}")
            if safe_retrieve(client, DATASET, req, target_path):
                ok += 1
            else:
                fail += 1

    return var, ok, fail


def main():
    """Run the download for all selected variables and print a summary."""
    # Defaults are used inside a notebook; command-line flags override them
    # when the script is run from a shell.
    args = parse_args_with_defaults()

    variables = list(args.variables)
    if not variables:
        print("[WARN] 未提供变量列表,使用内置 VARIABLES。")
        variables = VARIABLES

    # Never start more workers than there are variables.
    max_workers = max(1, min(args.max_workers, len(variables)))
    chunks_per_var = len(args.years) * len(args.months)

    print(f"[INFO] Variables: {len(variables)} | Years: {len(args.years)} | Months: {len(args.months)}")
    print(f"[INFO] Parallel by VARIABLE with max_workers = {max_workers}")
    start = time.time()

    total_ok = 0
    total_fail = 0
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(download_one_variable, var, args): var for var in variables}
        for fut in as_completed(futures):
            try:
                var, ok, fail = fut.result()
            except Exception as e:
                # A crashed worker must not abort the remaining variables or
                # the summary; all of its chunks are reported as failed.
                var, ok, fail = futures[fut], 0, chunks_per_var
                print(f"[ERROR][{var}] Worker crashed: {e}")
            results.append((var, ok, fail))
            total_ok += ok
            total_fail += fail
            print(f"[DONE][{var}] Ok={ok}, Fail={fail}")

    elapsed = time.time() - start
    print("\n================ SUMMARY ================")
    for var, ok, fail in sorted(results):
        print(f"{var:40s}  Ok={ok:4d}  Fail={fail:3d}")
    print("-----------------------------------------")
    print(f"TOTAL  Ok={total_ok}  Fail={total_fail}  | Elapsed: {elapsed/60:.1f} min")
    print("=========================================\n")


if __name__ == "__main__":
    main()
"492b0421e1cc4ee2b02d24910fbbd367", + "4065f75702ea4b30b25b3e0adc6d1cf4", + "8d4b92424bde4065b92126463c3a88cf", + "da3c858f56284927907e863b3f7bff15", + "17a5e81074e042208e7c41a6fe72d705", + "7638bda1ed3e4c66a4a47c3ff6a8edea", + "b30de5c069a4404ab2f48af8feef79cc", + "23894ce319a141d19310e4c02124bd38", + "2bf99ea9822d41d39a8fbd324bb8b3e5", + "3a5456e2d86a48be853f4552be673819", + "2f8e3c66060b4ef1b57aa2ea97259b7b", + "daf5ce1cd8e44d47aa090b2693b83d10", + "701ad51f4d414a3fbaa9c9eb5e448785", + "5b81063a9d2f46d69a3d138704717817", + "87ca252dfca34361ab50ea1809856703", + "db01f8756847498482428b5460105a6b" + ] + }, + "id": "2acnD1HPz3l5", + "outputId": "45f323f6-0ec7-4c0f-910c-dd0ed415e62e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[INFO] Variables: 54 | Years: 26 | Months: 12\n", + "[INFO] Parallel by VARIABLE with max_workers = 3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2025-09-28 10:12:52,047 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:12:52,049 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:12:52,245 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:12:52,247 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:12:52,250 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[INFO][skin_temperature] Downloading 1997-01 -> era5land/skin_temperature/1997/reanalysis-era5-land_skin_temperature_1997-01.grib.zip\n", + "[INFO][2m_temperature] Downloading 1997-01 -> era5land/2m_temperature/1997/reanalysis-era5-land_2m_temperature_1997-01.grib.zip\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2025-09-28 10:12:52,255 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[INFO][2m_dewpoint_temperature] Downloading 1997-01 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-01.grib.zip\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + 
"2025-09-28 10:12:53,041 INFO Request ID is aebbba26-5aa4-4aca-b888-df4bc5037557\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is aebbba26-5aa4-4aca-b888-df4bc5037557\n", + "2025-09-28 10:12:53,083 INFO Request ID is 1136a560-fa8e-4f6c-863b-c2eb22028995\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 1136a560-fa8e-4f6c-863b-c2eb22028995\n", + "2025-09-28 10:12:53,255 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:12:53,327 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:12:53,394 INFO Request ID is 3ccd2bcf-f2c5-4875-bbfb-33c845fd886b\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 3ccd2bcf-f2c5-4875-bbfb-33c845fd886b\n", + "2025-09-28 10:12:53,563 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:13:07,327 INFO status has been updated to successful\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to successful\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "e5dc627d8a097bec79906e75846db42e.zip: 0%| | 0.00/4.46M [00:00 era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-02.grib.zip\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2025-09-28 10:13:11,084 INFO Request ID is c96d1a54-2476-40f3-ba1c-e360b6529ef7\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is c96d1a54-2476-40f3-ba1c-e360b6529ef7\n", + "2025-09-28 10:13:11,231 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:13:15,343 INFO status has been updated to running\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to running\n", + "2025-09-28 10:14:48,504 INFO status has been 
# download_era5land_chunked_full_month_parallel.py - month-level concurrency
# (up to 12 concurrent requests for each variable/year group)
# -*- coding: utf-8 -*-
import argparse
import random
import sys
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Tuple

import cdsapi

DATASET = "reanalysis-era5-land"

# ---------- Time dimensions ----------
ALL_DAYS: List[str] = [f"{d:02d}" for d in range(1, 32)]
ALL_HOURS: List[str] = [f"{h:02d}:00" for h in range(0, 24)]
ALL_MONTHS: List[str] = [f"{m:02d}" for m in range(1, 13)]

# ---------- Retry settings ----------
MAX_RETRIES = 8
BASE_SLEEP = 10  # seconds

# Substrings of an error message that mark a request as not worth retrying
# (invalid / unavailable variable, no data).
_UNRECOVERABLE_SIGNALS: Tuple[str, ...] = (
    "unavailable",
    "not available",
    "invalid",
    "does not match",
    "no data",
    "bad request",
    "cannot be found",
)

# ---------- Full variable list (original list plus the formerly commented-out entries) ----------
VARIABLES: List[str] = [
    # 2m/skin/soil/lake temps
    "2m_dewpoint_temperature",
    "2m_temperature",
    "skin_temperature",
    "soil_temperature_level_1",
    "soil_temperature_level_2",
    "soil_temperature_level_3",
    "soil_temperature_level_4",
    "lake_bottom_temperature",
    "lake_ice_depth",
    "lake_ice_temperature",
    "lake_mix_layer_depth",
    "lake_mix_layer_temperature",
    "lake_shape_factor",
    # Entries that used to be commented out (now enabled)
    "lake_total_layer_temperature",
    "snow_albedo",
    "snow_cover",
    "snow_density",
    "snow_depth",
    "snow_depth_water_equivalent",
    "snowfall",
    "snowmelt",
    "temperature_of_snow_layer",
    "forecast_albedo",
    "surface_latent_heat_flux",
    "surface_net_solar_radiation",
    "surface_net_thermal_radiation",
    "surface_sensible_heat_flux",
    "surface_solar_radiation_downwards",
    "surface_thermal_radiation_downwards",
    "evaporation_from_bare_soil",
    "evaporation_from_open_water_surfaces_excluding_oceans",
    "evaporation_from_the_top_of_canopy",
    "evaporation_from_vegetation_transpiration",
    "potential_evaporation",
    "runoff",
    "snow_evaporation",
    "sub_surface_runoff",
    "surface_runoff",
    "total_evaporation",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "surface_pressure",
    "total_precipitation",
    "leaf_area_index_high_vegetation",
    "leaf_area_index_low_vegetation",
    "high_vegetation_cover",
    "glacier_mask",
    "lake_cover",
    "low_vegetation_cover",
    "lake_total_depth",
    "land_sea_mask",
    "soil_type",
    "type_of_high_vegetation",
    "type_of_low_vegetation",
]


def build_request(
    variable: str,
    year: str,
    month: str,
    area_box: Tuple[float, float, float, float],
    fmt: str,
) -> dict:
    """Build the request dictionary for one variable x year x month chunk.

    Args:
        variable: Variable name.
        year: Year as a string.
        month: Month as a string.
        area_box: Bounding box ordered (north, west, south, east).
        fmt: Value stored under the "format" key ("grib" or "netcdf").

    Returns:
        A dict covering every day in ``ALL_DAYS`` and every hour in
        ``ALL_HOURS``, with "download_format" set to "zip".
    """
    north, west, south, east = area_box
    req = {
        "variable": variable,
        "year": year,
        "month": month,
        "day": ALL_DAYS,
        "time": ALL_HOURS,
        "area": [north, west, south, east],  # N W S E
        "format": fmt,  # "grib" | "netcdf"
        "download_format": "zip",
        # "product_type": "reanalysis",
    }
    return req


def safe_retrieve(client: cdsapi.Client, dataset: str, request: dict, target_path: Path) -> bool:
    """Retrieve one request with retries and exponential back-off.

    The payload is written to ``<target>.part`` and only moved onto
    ``target_path`` after the download call returns, so an interrupted or
    failed transfer never leaves a partial file under the final name (which
    ``--skip_existing`` would otherwise accept on the next run).

    Args:
        client: Client used to issue the request.
        dataset: Dataset identifier.
        request: Request dictionary, e.g. from ``build_request``.
        target_path: Final location of the downloaded archive.

    Returns:
        True on success. False when the error message matches one of
        ``_UNRECOVERABLE_SIGNALS`` or when ``MAX_RETRIES`` is exceeded.
    """
    tmp_path = target_path.with_name(target_path.name + ".part")
    attempt = 0
    # Small random delay so concurrent workers do not fire at the same instant.
    time.sleep(random.uniform(0.3, 1.0))
    while True:
        try:
            client.retrieve(dataset, request).download(str(tmp_path))
            tmp_path.replace(target_path)
            return True
        except Exception as e:
            # Drop whatever was written so far; cleanup is best effort only.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass

            attempt += 1
            msg = str(e).lower()
            # Clearly unrecoverable errors are skipped without retrying.
            if any(s in msg for s in _UNRECOVERABLE_SIGNALS):
                print(f"[ERROR] Unrecoverable for {target_path.name}: {e}")
                return False

            if attempt > MAX_RETRIES:
                print(f"[ERROR] Max retries exceeded for {target_path.name}: {e}")
                return False

            # Exponential back-off with +/-15% jitter.
            sleep_s = BASE_SLEEP * (2 ** (attempt - 1)) * random.uniform(0.85, 1.15)
            print(f"[WARN] Download failed (attempt {attempt}/{MAX_RETRIES}): {e}")
            print(f" Sleeping {sleep_s:.0f}s then retrying...")
            time.sleep(sleep_s)


def parse_args_with_defaults(argv=None):
    """Parse downloader options; every option has a default.

    Args:
        argv: Optional explicit argument list. When None, an empty list is
            parsed inside a notebook kernel (``ipykernel`` / ``google.colab``
            present in ``sys.modules``) and the process command line is
            parsed otherwise.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="ERA5-Land downloader (split by variable × year × month) — month-level parallelism"
    )
    # Every option carries a default, so nothing is mandatory.
    parser.add_argument("--out_dir", type=str, default="./era5land",
                        help="输出根目录(默认 ./era5land)")
    parser.add_argument("--bbox", nargs=4, type=float,
                        default=[60.86, -6.23, 49.86, 1.75],
                        metavar=("NORTH", "WEST", "SOUTH", "EAST"),
                        help="经纬度范围:N W S E(默认 60.86 -6.23 49.86 1.75)")
    parser.add_argument("--format", default="grib", choices=["grib", "netcdf"],
                        help="文件格式(默认 grib)")
    parser.add_argument("--years", nargs="+",
                        default=[str(y) for y in range(1997, 2023)],  # 1997-2022
                        help="年份列表(默认 1997..2022)")
    parser.add_argument("--months", nargs="+", default=ALL_MONTHS,
                        help="月份列表(默认 01..12)")
    parser.add_argument("--variables", nargs="+", default=VARIABLES,
                        help="变量名列表(默认为脚本内置全集)")
    parser.add_argument("--skip_existing", action="store_true",
                        help="若目标文件已存在则跳过")
    # NOTE(review): the run recorded in this notebook hit HTTP 429 (Too Many
    # Requests) with 12 concurrent requests; a lower value may be kinder to
    # the service - confirm against the current CDS limits.
    parser.add_argument("--month_workers", type=int, default=12,
                        help="每个 年×变量 的月份并发数(默认 12)")

    if argv is not None:
        return parser.parse_args(list(argv))

    # The old code always parsed [], which silently discarded real
    # command-line options. Only do that inside a notebook kernel, where
    # sys.argv presumably holds the kernel's own launch flags rather than
    # options meant for this script.
    in_notebook = "ipykernel" in sys.modules or "google.colab" in sys.modules
    return parser.parse_args([] if in_notebook else None)


def download_one_month(var: str, year: str, month: str, args) -> bool:
    """Worker task: download a single variable x year x month chunk.

    Args:
        var: Variable name.
        year: Year of the chunk.
        month: Month of the chunk; zero-padded here before use.
        args: Parsed options; reads ``out_dir``, ``format``, ``bbox`` and
            ``skip_existing``.

    Returns:
        True when the archive was downloaded, or already exists as a complete
        zip and ``skip_existing`` is set; False when the download failed.
    """
    subdir = Path(args.out_dir) / var / str(year)
    subdir.mkdir(parents=True, exist_ok=True)

    # Normalise once so the file name and the request always agree.
    month_str = f"{int(month):02d}"
    suffix = "grib" if args.format == "grib" else "nc"
    target_path = subdir / f"{DATASET}_{var}_{year}-{month_str}.{suffix}.zip"

    if args.skip_existing and target_path.exists():
        # Simple resume support: only a readable zip archive counts as done.
        if zipfile.is_zipfile(target_path):
            return True
        print(f"[WARN][{var}][{year}] Existing file is not a complete zip, re-downloading: {target_path}")

    req = build_request(
        variable=var,
        year=str(year),
        month=month_str,
        area_box=tuple(args.bbox),
        fmt=args.format,
    )

    print(f"[INFO][{var}][{year}] Downloading month={month_str} -> {target_path}")
    # Each task builds its own client so no client is shared across threads.
    client = cdsapi.Client()
    return safe_retrieve(client, DATASET, req, target_path)


def main():
    """Download every variable x year group, running its months concurrently."""
    # Notebook-versus-CLI handling lives in parse_args_with_defaults; the
    # former if/else here called the same helper in both branches.
    args = parse_args_with_defaults()

    total_ok = 0
    total_fail = 0
    months = list(args.months)
    max_workers = max(1, min(args.month_workers, len(months)))

    for var in args.variables:
        for year in args.years:
            print(f"\n[GROUP] var={var} year={year} | months={months} | month_workers={max_workers}")

            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                futures = {
                    ex.submit(download_one_month, var, year, month, args): month
                    for month in months
                }
                for fut in as_completed(futures):
                    try:
                        ok = fut.result()
                    except Exception as exc:
                        # fut.result() re-raises worker exceptions (e.g. a
                        # failing cdsapi.Client() constructor). Count the
                        # chunk as failed instead of aborting the run.
                        print(f"[ERROR][{var}][{year}] month={futures[fut]} crashed: {exc}")
                        ok = False
                    if ok:
                        total_ok += 1
                    else:
                        total_fail += 1

    print(f"\n[DONE] Finished. Success: {total_ok}, Failed: {total_fail}")


if __name__ == "__main__":
    main()
era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-05.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=04 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-04.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=06 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-06.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=07 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-07.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=09 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-09.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=08 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-08.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=10 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-10.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=11 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-11.grib.zip\n", + "[INFO][2m_dewpoint_temperature][1997] Downloading month=12 -> era5land/2m_dewpoint_temperature/1997/reanalysis-era5-land_2m_dewpoint_temperature_1997-12.grib.zip\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2025-09-28 10:16:14,916 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "2025-09-28 10:16:14,917 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,919 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,920 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,926 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,933 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,935 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "2025-09-28 10:16:14,939 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,942 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,935 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,936 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,936 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "2025-09-28 10:16:14,945 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,953 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,948 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,956 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,949 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,963 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,964 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,970 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). 
Thank you.\n", + "2025-09-28 10:16:14,973 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,975 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:14,974 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.\n", + "2025-09-28 10:16:14,979 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "INFO:ecmwf.datastores.legacy_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.\n", + "2025-09-28 10:16:15,818 INFO Request ID is 067dc222-585a-4744-a0ac-4570fda17814\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 067dc222-585a-4744-a0ac-4570fda17814\n", + "2025-09-28 10:16:15,831 INFO Request ID is 5f47f2e0-1aa4-47ae-829f-b05b53c69363\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 
5f47f2e0-1aa4-47ae-829f-b05b53c69363\n", + "2025-09-28 10:16:15,848 INFO Request ID is a9be8a75-74e5-49d4-9441-f4f653fc3f6b\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is a9be8a75-74e5-49d4-9441-f4f653fc3f6b\n", + "2025-09-28 10:16:15,857 INFO Request ID is ec12ac64-6a87-475e-9853-1a5c7bef1419\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is ec12ac64-6a87-475e-9853-1a5c7bef1419\n", + "2025-09-28 10:16:15,941 INFO Request ID is ae343a40-f746-4efb-97e7-9e4384ab6aa2\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is ae343a40-f746-4efb-97e7-9e4384ab6aa2\n", + "2025-09-28 10:16:15,979 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:15,982 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:15,991 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:15,991 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:15,993 INFO Request ID is 85e7d622-6f8b-4f36-be77-c728883e832b\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 85e7d622-6f8b-4f36-be77-c728883e832b\n", + "2025-09-28 10:16:16,060 INFO Request ID is ac540eec-458d-4517-b2aa-7fd0db3a88c7\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is ac540eec-458d-4517-b2aa-7fd0db3a88c7\n", + "2025-09-28 10:16:16,079 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:16,097 INFO Request ID is 5c27fe51-87b7-4304-9a1c-29dbe883e1af\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is 5c27fe51-87b7-4304-9a1c-29dbe883e1af\n", + "2025-09-28 10:16:16,098 INFO Request ID is 9101737b-5c7b-4c01-b771-ca3bb6180867\n", + 
"INFO:ecmwf.datastores.legacy_client:Request ID is 9101737b-5c7b-4c01-b771-ca3bb6180867\n", + "2025-09-28 10:16:16,132 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:16,228 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:16,234 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:16,288 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "WARNING:multiurl.http:Recovering from HTTP error [429 Too Many Requests], attempt 1 of 500\n", + "WARNING:multiurl.http:Retrying in 120 seconds\n", + "2025-09-28 10:16:16,332 INFO Request ID is f150c332-0ddd-47bb-b448-fb25fa21d8b7\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is f150c332-0ddd-47bb-b448-fb25fa21d8b7\n", + "2025-09-28 10:16:16,354 INFO Request ID is e8268ff6-f4b4-4aac-927a-5137a23ac5ad\n", + "INFO:ecmwf.datastores.legacy_client:Request ID is e8268ff6-f4b4-4aac-927a-5137a23ac5ad\n", + "2025-09-28 10:16:16,465 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n", + "2025-09-28 10:16:16,498 INFO status has been updated to accepted\n", + "INFO:ecmwf.datastores.legacy_client:status has been updated to accepted\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/jointContribution/AI_Climate_Diseases/imputer.ipynb b/jointContribution/AI_Climate_Diseases/imputer.ipynb new file mode 100644 index 0000000000..3a4a3cc565 --- /dev/null +++ b/jointContribution/AI_Climate_Diseases/imputer.ipynb @@ -0,0 +1,629 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rlAsSGECAmNA", + 
"outputId": "76610133-ea9d-4566-b7fe-bc3340db872d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[OK] Ownsend_Deprivation_Index: {'column': 'Ownsend_Deprivation_Index', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.03671', 'metric_secondary': 'r2=0.6825', 'fallback': 'none'}\n", + "[OK] Number_of_Self-Reported_Cancers: {'column': 'Number_of_Self-Reported_Cancers', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01536', 'metric_secondary': 'r2=0.8008', 'fallback': 'none'}\n", + "[OK] Operations: {'column': 'Operations', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.65087', 'metric_secondary': 'r2=0.3075', 'fallback': 'none'}\n", + "[OK] Number_of_Treatments/Medications: {'column': 'Number_of_Treatments/Medications', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2.40922', 'metric_secondary': 'r2=0.6692', 'fallback': 'none'}\n", + "[OK] Number_of_Self-Reported_Non-Cancer_Illnesses: {'column': 'Number_of_Self-Reported_Non-Cancer_Illnesses', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.53121', 'metric_secondary': 'r2=0.5638', 'fallback': 'none'}\n", + "[OK] Aidememoire_Completed: {'column': 'Aidememoire_Completed', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.15893', 'metric_secondary': 'r2=0.0430', 'fallback': 'median'}\n", + "[OK] Sexual_History: {'column': 'Sexual_History', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.05077', 'metric_secondary': 'r2=0.4014', 'fallback': 'none'}\n", + "[OK] Added_Salt: {'column': 'Added_Salt', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.22884', 'metric_secondary': 'r2=0.0754', 'fallback': 'none'}\n", + "[OK] Handedness: {'column': 'Handedness', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09320', 'metric_secondary': 'r2=0.0545', 'fallback': 'none'}\n", + "[OK] Current_Tobacco_Smoking: {'column': 'Current_Tobacco_Smoking', 'type': 
'regression', 'trained': True, 'metric_primary': 'mse=0.00001', 'metric_secondary': 'r2=0.9999', 'fallback': 'none'}\n", + "[OK] Accommodation_Type: {'column': 'Accommodation_Type', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.05787', 'metric_secondary': 'r2=0.4082', 'fallback': 'none'}\n", + "[OK] Alcohol_Intake_Frequency: {'column': 'Alcohol_Intake_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00034', 'metric_secondary': 'r2=0.9956', 'fallback': 'none'}\n", + "[OK] Milk_Type: {'column': 'Milk_Type', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19374', 'metric_secondary': 'r2=0.1529', 'fallback': 'none'}\n", + "[OK] Insomnia: {'column': 'Insomnia', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15658', 'metric_secondary': 'r2=0.1454', 'fallback': 'none'}\n", + "[OK] Glasses_Wear: {'column': 'Glasses_Wear', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08702', 'metric_secondary': 'r2=0.1445', 'fallback': 'none'}\n", + "[OK] Alcohol_Status: {'column': 'Alcohol_Status', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00001', 'metric_secondary': 'r2=0.9999', 'fallback': 'none'}\n", + "[OK] Pacemaker: {'column': 'Pacemaker', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00194', 'metric_secondary': 'r2=0.3678', 'fallback': 'none'}\n", + "[OK] Computer_Gaming: {'column': 'Computer_Gaming', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.14589', 'metric_secondary': 'r2=0.1162', 'fallback': 'none'}\n", + "[OK] Birth_Country: {'column': 'Birth_Country', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08208', 'metric_secondary': 'r2=0.5226', 'fallback': 'none'}\n", + "[OK] Day_Napping: {'column': 'Day_Napping', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20236', 'metric_secondary': 'r2=0.1787', 'fallback': 'none'}\n", + "[OK] Hair_Color: {'column': 'Hair_Color', 'type': 'regression', 
'trained': True, 'metric_primary': 'mse=0.22422', 'metric_secondary': 'r2=0.0561', 'fallback': 'none'}\n", + "[OK] Poultry: {'column': 'Poultry', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01650', 'metric_secondary': 'r2=0.6699', 'fallback': 'none'}\n", + "[OK] Waist_Circumference: {'column': 'Waist_Circumference', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=21.89398', 'metric_secondary': 'r2=0.8808', 'fallback': 'none'}\n", + "[OK] Tea: {'column': 'Tea', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=6.09476', 'metric_secondary': 'r2=0.2554', 'fallback': 'none'}\n", + "[OK] Hip_Circumference: {'column': 'Hip_Circumference', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=13.19761', 'metric_secondary': 'r2=0.8465', 'fallback': 'none'}\n", + "[OK] Processed_Meat_Intake: {'column': 'Processed_Meat_Intake', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.04308', 'metric_secondary': 'r2=0.5001', 'fallback': 'none'}\n", + "[OK] Coffee: {'column': 'Coffee', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.11790', 'metric_secondary': 'r2=0.2826', 'fallback': 'none'}\n", + "[OK] Diet_Change: {'column': 'Diet_Change', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20395', 'metric_secondary': 'r2=0.1449', 'fallback': 'none'}\n", + "[OK] Adopted: {'column': 'Adopted', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.01467', 'metric_secondary': 'r2=-0.0032', 'fallback': 'median'}\n", + "[OK] Standing_Height: {'column': 'Standing_Height', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.92470', 'metric_secondary': 'r2=0.9892', 'fallback': 'none'}\n", + "[OK] Diabetes_Diagnosis: {'column': 'Diabetes_Diagnosis', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01807', 'metric_secondary': 'r2=0.6474', 'fallback': 'none'}\n", + "[OK] Eye_Problems: {'column': 'Eye_Problems', 'type': 'regression', 'trained': True, 
'metric_primary': 'mse=0.11966', 'metric_secondary': 'r2=0.0556', 'fallback': 'none'}\n", + "[OK] Falls: {'column': 'Falls', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.14412', 'metric_secondary': 'r2=0.0938', 'fallback': 'none'}\n", + "[OK] Current_Residence_Duration: {'column': 'Current_Residence_Duration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=111.30535', 'metric_secondary': 'r2=0.2432', 'fallback': 'none'}\n", + "[OK] Cancer_Diagnosis: {'column': 'Cancer_Diagnosis', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01289', 'metric_secondary': 'r2=0.8208', 'fallback': 'none'}\n", + "[OK] Weight: {'column': 'Weight', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.38324', 'metric_secondary': 'r2=0.9985', 'fallback': 'none'}\n", + "[OK] Ethnicity: {'column': 'Ethnicity', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.05108', 'metric_secondary': 'r2=0.4994', 'fallback': 'none'}\n", + "[OK] Smoked: {'column': 'Smoked', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19598', 'metric_secondary': 'r2=0.1850', 'fallback': 'none'}\n", + "[OK] Smoking_Status: {'column': 'Smoking_Status', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00003', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Other_Prescription_Medications: {'column': 'Other_Prescription_Medications', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11891', 'metric_secondary': 'r2=0.5227', 'fallback': 'none'}\n", + "[OK] BMI: {'column': 'BMI', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.06996', 'metric_secondary': 'r2=0.9969', 'fallback': 'none'}\n", + "[OK] Cereal: {'column': 'Cereal', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=6.16983', 'metric_secondary': 'r2=0.2266', 'fallback': 'none'}\n", + "[OK] Fresh_Fruit: {'column': 'Fresh_Fruit', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2.10727', 
'metric_secondary': 'r2=0.2082', 'fallback': 'none'}\n", + "[OK] Hand_Grip_Strength_(Right): {'column': 'Hand_Grip_Strength_(Right)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=20.45542', 'metric_secondary': 'r2=0.8401', 'fallback': 'none'}\n", + "[OK] Beef_Intake: {'column': 'Beef_Intake', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.04843', 'metric_secondary': 'r2=0.5084', 'fallback': 'none'}\n", + "[OK] Hand_Grip_Strength_(Left): {'column': 'Hand_Grip_Strength_(Left)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=20.41612', 'metric_secondary': 'r2=0.8406', 'fallback': 'none'}\n", + "[OK] Overall_Health_Rating: {'column': 'Overall_Health_Rating', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02892', 'metric_secondary': 'r2=0.3373', 'fallback': 'none'}\n", + "[OK] Psychiatrist_Visits: {'column': 'Psychiatrist_Visits', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07183', 'metric_secondary': 'r2=0.2951', 'fallback': 'none'}\n", + "[OK] Non_Oily_Fish: {'column': 'Non_Oily_Fish', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02499', 'metric_secondary': 'r2=0.4516', 'fallback': 'none'}\n", + "[OK] Fractures: {'column': 'Fractures', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.08456', 'metric_secondary': 'r2=0.0343', 'fallback': 'median'}\n", + "[OK] Daytime_Dozing: {'column': 'Daytime_Dozing', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15195', 'metric_secondary': 'r2=0.1684', 'fallback': 'none'}\n", + "[OK] Spirometry_Contraindications: {'column': 'Spirometry_Contraindications', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00493', 'metric_secondary': 'r2=0.9302', 'fallback': 'none'}\n", + "[OK] Oily_Fish: {'column': 'Oily_Fish', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.06648', 'metric_secondary': 'r2=0.3064', 'fallback': 'none'}\n", + "[OK] Sleep_Duration: {'column': 
'Sleep_Duration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.14679', 'metric_secondary': 'r2=0.0885', 'fallback': 'none'}\n", + "[OK] Pork: {'column': 'Pork', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08622', 'metric_secondary': 'r2=0.4008', 'fallback': 'none'}\n", + "[OK] Lamb_Mutton: {'column': 'Lamb_Mutton', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09012', 'metric_secondary': 'r2=0.3838', 'fallback': 'none'}\n", + "[OK] Incorrect_Matches: {'column': 'Incorrect_Matches', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.19447', 'metric_secondary': 'r2=0.1079', 'fallback': 'none'}\n", + "[OK] Willingness_for_Cognitive_Tests: {'column': 'Willingness_for_Cognitive_Tests', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00715', 'metric_secondary': 'r2=0.6377', 'fallback': 'none'}\n", + "[OK] Water_Intake: {'column': 'Water_Intake', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4.02436', 'metric_secondary': 'r2=0.2531', 'fallback': 'none'}\n", + "[OK] Water_300m: {'column': 'Water_300m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=82.21683', 'metric_secondary': 'r2=0.8698', 'fallback': 'none'}\n", + "[OK] Natural_Environment_(1000m_Buffer): {'column': 'Natural_Environment_(1000m_Buffer)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=20.10810', 'metric_secondary': 'r2=0.9693', 'fallback': 'none'}\n", + "[OK] Natural_Environment_300m: {'column': 'Natural_Environment_300m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=24.21351', 'metric_secondary': 'r2=0.9687', 'fallback': 'none'}\n", + "[OK] GP_Visits_for_Mental_Health: {'column': 'GP_Visits_for_Mental_Health', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13707', 'metric_secondary': 'r2=0.3873', 'fallback': 'none'}\n", + "[OK] Home_Area_Population_Density: {'column': 'Home_Area_Population_Density', 'type': 'regression', 'trained': True, 
'metric_primary': 'mse=0.01900', 'metric_secondary': 'r2=0.8828', 'fallback': 'none'}\n", + "[OK] TV_Time: {'column': 'TV_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2.19738', 'metric_secondary': 'r2=0.2620', 'fallback': 'none'}\n", + "[OK] Disability/Mobility_Allowance: {'column': 'Disability/Mobility_Allowance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03385', 'metric_secondary': 'r2=0.4068', 'fallback': 'none'}\n", + "[OK] Wake_Up_Ease: {'column': 'Wake_Up_Ease', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03346', 'metric_secondary': 'r2=0.1082', 'fallback': 'none'}\n", + "[OK] Spread_Type: {'column': 'Spread_Type', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08664', 'metric_secondary': 'r2=0.1012', 'fallback': 'none'}\n", + "[OK] Match_Identification_Time: {'column': 'Match_Identification_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=11765.66602', 'metric_secondary': 'r2=0.1531', 'fallback': 'none'}\n", + "[OK] UV_Protection: {'column': 'UV_Protection', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07551', 'metric_secondary': 'r2=0.1674', 'fallback': 'none'}\n", + "[OK] Seated_Height: {'column': 'Seated_Height', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=9.22653', 'metric_secondary': 'r2=0.8201', 'fallback': 'none'}\n", + "[OK] Seating_Box_Height: {'column': 'Seating_Box_Height', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00028', 'metric_secondary': 'r2=0.8925', 'fallback': 'none'}\n", + "[OK] Sitting_Height: {'column': 'Sitting_Height', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.85534', 'metric_secondary': 'r2=0.9227', 'fallback': 'none'}\n", + "[OK] Hot_Drink_Temperature: {'column': 'Hot_Drink_Temperature', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.12885', 'metric_secondary': 'r2=0.0284', 'fallback': 'median'}\n", + "[OK] Chest_Pain_Discomfort: 
{'column': 'Chest_Pain_Discomfort', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11280', 'metric_secondary': 'r2=0.1704', 'fallback': 'none'}\n", + "[OK] Dried_Fruit: {'column': 'Dried_Fruit', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2.92822', 'metric_secondary': 'r2=0.0881', 'fallback': 'none'}\n", + "[OK] Diet_Variation: {'column': 'Diet_Variation', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20759', 'metric_secondary': 'r2=0.0766', 'fallback': 'none'}\n", + "[OK] Major_Road_Traffic: {'column': 'Major_Road_Traffic', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00004', 'metric_secondary': 'r2=0.8855', 'fallback': 'none'}\n", + "[OK] Nearest_Road_Distance: {'column': 'Nearest_Road_Distance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=49187815.57922', 'metric_secondary': 'r2=0.8913', 'fallback': 'none'}\n", + "[OK] PM10_Air_2007: {'column': 'PM10_Air_2007', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00837', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Nearest_Road_Traffic: {'column': 'Nearest_Road_Traffic', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00144', 'metric_secondary': 'r2=0.7221', 'fallback': 'none'}\n", + "[OK] NO2_Air_Pollution_(2005): {'column': 'NO2_Air_Pollution_(2005)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.45550', 'metric_secondary': 'r2=0.9955', 'fallback': 'none'}\n", + "[OK] NO2_Air_Pollution_(2006): {'column': 'NO2_Air_Pollution_(2006)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.37226', 'metric_secondary': 'r2=0.9956', 'fallback': 'none'}\n", + "[OK] PM2.5_to_10_Air_2010: {'column': 'PM2.5_to_10_Air_2010', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4921075.06406', 'metric_secondary': 'r2=0.7997', 'fallback': 'none'}\n", + "[OK] Major_Road_Distance: {'column': 'Major_Road_Distance', 'type': 'regression', 'trained': True, 
'metric_primary': 'mse=89375062275.11403', 'metric_secondary': 'r2=0.9296', 'fallback': 'none'}\n", + "[OK] NO2_Air_Pollution_(2010): {'column': 'NO2_Air_Pollution_(2010)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.68771', 'metric_secondary': 'r2=0.9879', 'fallback': 'none'}\n", + "[OK] Road_Traffic_Load: {'column': 'Road_Traffic_Load', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=558.05541', 'metric_secondary': 'r2=0.9083', 'fallback': 'none'}\n", + "[OK] Trunk_Fat_Percentage: {'column': 'Trunk_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=11.20265', 'metric_secondary': 'r2=0.9531', 'fallback': 'none'}\n", + "[OK] Evening_Noise_Level: {'column': 'Evening_Noise_Level', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00841', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Nighttime_Noise_Level: {'column': 'Nighttime_Noise_Level', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00860', 'metric_secondary': 'r2=0.9995', 'fallback': 'none'}\n", + "[OK] Proximity_to_Major_Road: {'column': 'Proximity_to_Major_Road', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00491', 'metric_secondary': 'r2=0.9277', 'fallback': 'none'}\n", + "[OK] Hour-16_Noise_Level: {'column': 'Hour-16_Noise_Level', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00856', 'metric_secondary': 'r2=0.9995', 'fallback': 'none'}\n", + "[OK] Daytime_Noise_Level: {'column': 'Daytime_Noise_Level', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00831', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Road_Length: {'column': 'Road_Length', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.47892', 'metric_secondary': 'r2=0.9958', 'fallback': 'none'}\n", + "[OK] Cooked_Vegetables: {'column': 'Cooked_Vegetables', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.27094', 'metric_secondary': 'r2=0.1751', 
'fallback': 'none'}\n", + "[OK] Salad_Intake: {'column': 'Salad_Intake', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.73412', 'metric_secondary': 'r2=0.2083', 'fallback': 'none'}\n", + "[OK] Social_Visits: {'column': 'Social_Visits', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01657', 'metric_secondary': 'r2=0.0511', 'fallback': 'none'}\n", + "[OK] Recent_Stress_Events: {'column': 'Recent_Stress_Events', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.22579', 'metric_secondary': 'r2=0.0891', 'fallback': 'none'}\n", + "[OK] Phone_Use_Duration: {'column': 'Phone_Use_Duration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11437', 'metric_secondary': 'r2=0.1156', 'fallback': 'none'}\n", + "[OK] Skin_Color: {'column': 'Skin_Color', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19061', 'metric_secondary': 'r2=0.1207', 'fallback': 'none'}\n", + "[OK] NO2_Air_2007: {'column': 'NO2_Air_2007', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.47331', 'metric_secondary': 'r2=0.9427', 'fallback': 'none'}\n", + "[OK] Computer_Time: {'column': 'Computer_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.83010', 'metric_secondary': 'r2=0.1338', 'fallback': 'none'}\n", + "[OK] Bowel_Screening: {'column': 'Bowel_Screening', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17538', 'metric_secondary': 'r2=0.1949', 'fallback': 'none'}\n", + "[OK] Weight_Change_(1_Year): {'column': 'Weight_Change_(1_Year)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.21421', 'metric_secondary': 'r2=0.1316', 'fallback': 'none'}\n", + "[OK] Loneliness/Isolation: {'column': 'Loneliness/Isolation', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.10941', 'metric_secondary': 'r2=0.2875', 'fallback': 'none'}\n", + "[OK] Driving_Time: {'column': 'Driving_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.24548', 
'metric_secondary': 'r2=0.2583', 'fallback': 'none'}\n", + "[OK] Whole_Body_Water_Mass: {'column': 'Whole_Body_Water_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02345', 'metric_secondary': 'r2=0.9997', 'fallback': 'none'}\n", + "[OK] Basal_Metabolic_Rate: {'column': 'Basal_Metabolic_Rate', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=469.91901', 'metric_secondary': 'r2=0.9997', 'fallback': 'none'}\n", + "[OK] Right_Leg_Impedance: {'column': 'Right_Leg_Impedance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=21.39709', 'metric_secondary': 'r2=0.9839', 'fallback': 'none'}\n", + "[OK] Left_Leg_Impedance: {'column': 'Left_Leg_Impedance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=18.76451', 'metric_secondary': 'r2=0.9853', 'fallback': 'none'}\n", + "[OK] Whole_Body_Fat-Free_Mass: {'column': 'Whole_Body_Fat-Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02022', 'metric_secondary': 'r2=0.9998', 'fallback': 'none'}\n", + "[OK] Right_Leg_Fat_Percentage: {'column': 'Right_Leg_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07534', 'metric_secondary': 'r2=0.9993', 'fallback': 'none'}\n", + "[OK] Left_Arm_Impedance: {'column': 'Left_Arm_Impedance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=30.00962', 'metric_secondary': 'r2=0.9907', 'fallback': 'none'}\n", + "[OK] Right_Leg_Fat_Mass: {'column': 'Right_Leg_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00680', 'metric_secondary': 'r2=0.9981', 'fallback': 'none'}\n", + "[OK] Right_Leg_Fat_Free_Mass: {'column': 'Right_Leg_Fat_Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00192', 'metric_secondary': 'r2=0.9995', 'fallback': 'none'}\n", + "[OK] Right_Leg_Predicted_Mass: {'column': 'Right_Leg_Predicted_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00188', 'metric_secondary': 'r2=0.9995', 
'fallback': 'none'}\n", + "[OK] Whole_Body_Impedance: {'column': 'Whole_Body_Impedance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=157.82494', 'metric_secondary': 'r2=0.9803', 'fallback': 'none'}\n", + "[OK] Left_Leg_Fat_Percentage: {'column': 'Left_Leg_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07565', 'metric_secondary': 'r2=0.9993', 'fallback': 'none'}\n", + "[OK] Right_Arm_Impedance: {'column': 'Right_Arm_Impedance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=85.94955', 'metric_secondary': 'r2=0.9723', 'fallback': 'none'}\n", + "[OK] Left_Leg_Fat_Mass: {'column': 'Left_Leg_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00534', 'metric_secondary': 'r2=0.9985', 'fallback': 'none'}\n", + "[OK] Left_Leg_Fat-Free_Mass: {'column': 'Left_Leg_Fat-Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00176', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Left_Leg_Predicted_Mass: {'column': 'Left_Leg_Predicted_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00150', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Right_Arm_Fat_Percentage: {'column': 'Right_Arm_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16739', 'metric_secondary': 'r2=0.9984', 'fallback': 'none'}\n", + "[OK] Right_Arm_Fat_Mass: {'column': 'Right_Arm_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00189', 'metric_secondary': 'r2=0.9954', 'fallback': 'none'}\n", + "[OK] Right_Arm_Fat-Free_Mass: {'column': 'Right_Arm_Fat-Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00138', 'metric_secondary': 'r2=0.9980', 'fallback': 'none'}\n", + "[OK] Right_Arm_Predicted_Mass: {'column': 'Right_Arm_Predicted_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00122', 'metric_secondary': 'r2=0.9980', 'fallback': 'none'}\n", + 
"[OK] Left_Arm_Fat_Percentage: {'column': 'Left_Arm_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15265', 'metric_secondary': 'r2=0.9986', 'fallback': 'none'}\n", + "[OK] Left_Arm_Fat_Mass: {'column': 'Left_Arm_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00200', 'metric_secondary': 'r2=0.9961', 'fallback': 'none'}\n", + "[OK] Left_Arm_Fat-Free_Mass: {'column': 'Left_Arm_Fat-Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00142', 'metric_secondary': 'r2=0.9980', 'fallback': 'none'}\n", + "[OK] Height_At_10: {'column': 'Height_At_10', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11467', 'metric_secondary': 'r2=0.2967', 'fallback': 'none'}\n", + "[OK] Left_Arm_Predicted_Mass: {'column': 'Left_Arm_Predicted_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00128', 'metric_secondary': 'r2=0.9980', 'fallback': 'none'}\n", + "[OK] Body_Fat_Percentage: {'column': 'Body_Fat_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.04620', 'metric_secondary': 'r2=0.9994', 'fallback': 'none'}\n", + "[OK] Reticulocyte_Percentage: {'column': 'Reticulocyte_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07625', 'metric_secondary': 'r2=0.9988', 'fallback': 'none'}\n", + "[OK] Trunk_Fat_Mass: {'column': 'Trunk_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03267', 'metric_secondary': 'r2=0.9988', 'fallback': 'none'}\n", + "[OK] Trunk_Fat-Free_Mass: {'column': 'Trunk_Fat-Free_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01580', 'metric_secondary': 'r2=0.9996', 'fallback': 'none'}\n", + "[OK] Trunk_Predicted_Mass: {'column': 'Trunk_Predicted_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00868', 'metric_secondary': 'r2=0.9997', 'fallback': 'none'}\n", + "[OK] Miserableness: {'column': 'Miserableness', 'type': 'regression', 
'trained': True, 'metric_primary': 'mse=0.14343', 'metric_secondary': 'r2=0.4137', 'fallback': 'none'}\n", + "[OK] Solarium_Use: {'column': 'Solarium_Use', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=18.99546', 'metric_secondary': 'r2=-0.0101', 'fallback': 'median'}\n", + "[OK] Weekly_Walking_Days: {'column': 'Weekly_Walking_Days', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00580', 'metric_secondary': 'r2=0.7195', 'fallback': 'none'}\n", + "[OK] Other_Serious_Medical_Conditions: {'column': 'Other_Serious_Medical_Conditions', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11604', 'metric_secondary': 'r2=0.2884', 'fallback': 'none'}\n", + "[OK] Bread: {'column': 'Bread', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=60.18592', 'metric_secondary': 'r2=0.2012', 'fallback': 'none'}\n", + "[OK] Whole_Body_Fat_Mass: {'column': 'Whole_Body_Fat_Mass', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.06344', 'metric_secondary': 'r2=0.9993', 'fallback': 'none'}\n", + "[OK] Body_Size_At_10: {'column': 'Body_Size_At_10', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19814', 'metric_secondary': 'r2=0.1080', 'fallback': 'none'}\n", + "[OK] Wheezing: {'column': 'Wheezing', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11852', 'metric_secondary': 'r2=0.2935', 'fallback': 'none'}\n", + "[OK] Fed-Up_Feelings: {'column': 'Fed-Up_Feelings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13092', 'metric_secondary': 'r2=0.4578', 'fallback': 'none'}\n", + "[OK] Longstanding_Illness/Disability: {'column': 'Longstanding_Illness/Disability', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.12172', 'metric_secondary': 'r2=0.4460', 'fallback': 'none'}\n", + "[OK] Mood_Swings: {'column': 'Mood_Swings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13213', 'metric_secondary': 'r2=0.4663', 'fallback': 'none'}\n", + "[OK] 
Nervous_Feelings: {'column': 'Nervous_Feelings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.10714', 'metric_secondary': 'r2=0.4019', 'fallback': 'none'}\n", + "[OK] Worrier/Anxious_Feelings: {'column': 'Worrier/Anxious_Feelings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15546', 'metric_secondary': 'r2=0.3665', 'fallback': 'none'}\n", + "[OK] Guilty_Feelings: {'column': 'Guilty_Feelings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.14841', 'metric_secondary': 'r2=0.2760', 'fallback': 'none'}\n", + "[OK] Sensitivity_to_Hurt_Feelings: {'column': 'Sensitivity_to_Hurt_Feelings', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17205', 'metric_secondary': 'r2=0.3037', 'fallback': 'none'}\n", + "[OK] Tiredness/Lethargy_Frequency: {'column': 'Tiredness/Lethargy_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17287', 'metric_secondary': 'r2=0.3053', 'fallback': 'none'}\n", + "[OK] Enzymatic_In_Urine: {'column': 'Enzymatic_In_Urine', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=12236521.12770', 'metric_secondary': 'r2=0.6405', 'fallback': 'none'}\n", + "[OK] Tanning_Ease: {'column': 'Tanning_Ease', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13398', 'metric_secondary': 'r2=0.0575', 'fallback': 'none'}\n", + "[OK] Able_to_Confide: {'column': 'Able_to_Confide', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11438', 'metric_secondary': 'r2=0.0904', 'fallback': 'none'}\n", + "[OK] Sodium_Urine: {'column': 'Sodium_Urine', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1092.01791', 'metric_secondary': 'r2=0.4494', 'fallback': 'none'}\n", + "[OK] Potassium_Urine: {'column': 'Potassium_Urine', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=462.58147', 'metric_secondary': 'r2=0.5949', 'fallback': 'none'}\n", + "[OK] Unenthusiasm/Disinterest_Frequency: {'column': 
'Unenthusiasm/Disinterest_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09237', 'metric_secondary': 'r2=0.4486', 'fallback': 'none'}\n", + "[OK] Tense/Highly_Strung: {'column': 'Tense/Highly_Strung', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09239', 'metric_secondary': 'r2=0.3666', 'fallback': 'none'}\n", + "[OK] Risk-Taking_Behavior: {'column': 'Risk-Taking_Behavior', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16802', 'metric_secondary': 'r2=0.1462', 'fallback': 'none'}\n", + "[OK] Suffering_from_Nerves: {'column': 'Suffering_from_Nerves', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.10314', 'metric_secondary': 'r2=0.3780', 'fallback': 'none'}\n", + "[OK] Tenseness/Restlessness_Frequency: {'column': 'Tenseness/Restlessness_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11827', 'metric_secondary': 'r2=0.3879', 'fallback': 'none'}\n", + "[OK] Post-Embarrassment_Worry: {'column': 'Post-Embarrassment_Worry', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16919', 'metric_secondary': 'r2=0.3221', 'fallback': 'none'}\n", + "[OK] Depressed_Mood_Frequency: {'column': 'Depressed_Mood_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08407', 'metric_secondary': 'r2=0.5371', 'fallback': 'none'}\n", + "[OK] WBC_Count: {'column': 'WBC_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00073', 'metric_secondary': 'r2=0.9958', 'fallback': 'none'}\n", + "[OK] RBC_Count: {'column': 'RBC_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03964', 'metric_secondary': 'r2=0.9969', 'fallback': 'none'}\n", + "[OK] Corpuscular_Haemoglobin_Concentration: {'column': 'Corpuscular_Haemoglobin_Concentration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.51932', 'metric_secondary': 'r2=0.4514', 'fallback': 'none'}\n", + "[OK] Haemoglobin_Concentration: {'column': 
'Haemoglobin_Concentration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00387', 'metric_secondary': 'r2=0.9975', 'fallback': 'none'}\n", + "[OK] Mean_Corpuscular_Haemoglobin: {'column': 'Mean_Corpuscular_Haemoglobin', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13281', 'metric_secondary': 'r2=0.9651', 'fallback': 'none'}\n", + "[OK] Haematocrit: {'column': 'Haematocrit', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.27617', 'metric_secondary': 'r2=0.9871', 'fallback': 'none'}\n", + "[OK] Mean_Corpuscular_Haemoglobin_Concentration: {'column': 'Mean_Corpuscular_Haemoglobin_Concentration', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11281', 'metric_secondary': 'r2=0.9101', 'fallback': 'none'}\n", + "[OK] Blood_Cell_Erythrocyte_Distribution_Width: {'column': 'Blood_Cell_Erythrocyte_Distribution_Width', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=31.01486', 'metric_secondary': 'r2=0.9913', 'fallback': 'none'}\n", + "[OK] Platelet_Count: {'column': 'Platelet_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00003', 'metric_secondary': 'r2=0.9858', 'fallback': 'none'}\n", + "[OK] Platelet_Crit: {'column': 'Platelet_Crit', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01272', 'metric_secondary': 'r2=0.9890', 'fallback': 'none'}\n", + "[OK] Platelet_Thrombocyte_Volume: {'column': 'Platelet_Thrombocyte_Volume', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19092', 'metric_secondary': 'r2=0.3061', 'fallback': 'none'}\n", + "[OK] Environment_Score: {'column': 'Environment_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.34719', 'metric_secondary': 'r2=0.9291', 'fallback': 'none'}\n", + "[OK] Irritability: {'column': 'Irritability', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.14319', 'metric_secondary': 'r2=0.2874', 'fallback': 'none'}\n", + "[OK] Hearing_Difficulty: {'column': 
'Hearing_Difficulty', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17787', 'metric_secondary': 'r2=0.0649', 'fallback': 'none'}\n", + "[OK] Red_Blood_Cell_(Count): {'column': 'Red_Blood_Cell_(Count)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16398', 'metric_secondary': 'r2=0.9785', 'fallback': 'none'}\n", + "[OK] Eosinophill_Percentage: {'column': 'Eosinophill_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02639', 'metric_secondary': 'r2=0.9219', 'fallback': 'none'}\n", + "[OK] Lymphocyte_Percentage: {'column': 'Lymphocyte_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17163', 'metric_secondary': 'r2=0.9969', 'fallback': 'none'}\n", + "[OK] Neutrophill_Percentage: {'column': 'Neutrophill_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07561', 'metric_secondary': 'r2=0.9781', 'fallback': 'none'}\n", + "[OK] Monocyte_Percentage: {'column': 'Monocyte_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.30818', 'metric_secondary': 'r2=0.9958', 'fallback': 'none'}\n", + "[OK] Eosinophill_Count: {'column': 'Eosinophill_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00057', 'metric_secondary': 'r2=0.9710', 'fallback': 'none'}\n", + "[OK] Platelet_Width: {'column': 'Platelet_Width', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.30995', 'metric_secondary': 'r2=0.7349', 'fallback': 'none'}\n", + "[OK] Neutrophill_Count: {'column': 'Neutrophill_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00027', 'metric_secondary': 'r2=0.8899', 'fallback': 'none'}\n", + "[OK] Lymphocyte_Count: {'column': 'Lymphocyte_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00223', 'metric_secondary': 'r2=0.9527', 'fallback': 'none'}\n", + "[OK] Monocyte_Count: {'column': 'Monocyte_Count', 'type': 'regression', 'trained': True, 'metric_primary': 
'mse=0.04233', 'metric_secondary': 'r2=0.9799', 'fallback': 'none'}\n", + "[OK] Basophill_Count: {'column': 'Basophill_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00009', 'metric_secondary': 'r2=0.8739', 'fallback': 'none'}\n", + "[OK] Nucleated_Red_Blood_Cell_Percentage: {'column': 'Nucleated_Red_Blood_Cell_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01141', 'metric_secondary': 'r2=0.9213', 'fallback': 'none'}\n", + "[OK] Weekly_Moderate_Activity_Days: {'column': 'Weekly_Moderate_Activity_Days', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01736', 'metric_secondary': 'r2=0.8440', 'fallback': 'none'}\n", + "[OK] Weekly_Vigorous_Activity_Days: {'column': 'Weekly_Vigorous_Activity_Days', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.04184', 'metric_secondary': 'r2=0.8211', 'fallback': 'none'}\n", + "[OK] Diastolic_BP: {'column': 'Diastolic_BP', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=38.04323', 'metric_secondary': 'r2=0.6319', 'fallback': 'none'}\n", + "[OK] Pulse_Rate: {'column': 'Pulse_Rate', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=91.16690', 'metric_secondary': 'r2=0.2794', 'fallback': 'none'}\n", + "[OK] Systolic_BP: {'column': 'Systolic_BP', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=133.03961', 'metric_secondary': 'r2=0.6182', 'fallback': 'none'}\n", + "[OK] Basophill_Percentage: {'column': 'Basophill_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09672', 'metric_secondary': 'r2=0.9015', 'fallback': 'none'}\n", + "[OK] Reticulocyte_Count: {'column': 'Reticulocyte_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00004', 'metric_secondary': 'r2=0.9773', 'fallback': 'none'}\n", + "[OK] High_Light_Scatter_Reticulocyte_Count: {'column': 'High_Light_Scatter_Reticulocyte_Count', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00000', 
'metric_secondary': 'r2=0.9821', 'fallback': 'none'}\n", + "[OK] Sphered_Cell_Volume: {'column': 'Sphered_Cell_Volume', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=7.33223', 'metric_secondary': 'r2=0.7418', 'fallback': 'none'}\n", + "[OK] Light_Scatter_Reticulocyte_Percentage: {'column': 'Light_Scatter_Reticulocyte_Percentage', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.12540', 'metric_secondary': 'r2=0.3288', 'fallback': 'none'}\n", + "[OK] Immature_Fraction: {'column': 'Immature_Fraction', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00002', 'metric_secondary': 'r2=0.9941', 'fallback': 'none'}\n", + "[OK] Mean_Reticulocyte_Volume: {'column': 'Mean_Reticulocyte_Volume', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=19.21475', 'metric_secondary': 'r2=0.6895', 'fallback': 'none'}\n", + "[OK] Alkaline_Phosphatase: {'column': 'Alkaline_Phosphatase', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=507.44511', 'metric_secondary': 'r2=0.2811', 'fallback': 'none'}\n", + "[OK] Cholesterol: {'column': 'Cholesterol', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02822', 'metric_secondary': 'r2=0.9782', 'fallback': 'none'}\n", + "[OK] Cystatin_C: {'column': 'Cystatin_C', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01332', 'metric_secondary': 'r2=0.6313', 'fallback': 'none'}\n", + "[OK] Alanine_Aminotransferase: {'column': 'Alanine_Aminotransferase', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=70.46189', 'metric_secondary': 'r2=0.6845', 'fallback': 'none'}\n", + "[OK] Gamma_Glutamyltransferase: {'column': 'Gamma_Glutamyltransferase', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=923.75392', 'metric_secondary': 'r2=0.4866', 'fallback': 'none'}\n", + "[OK] Creatinine_Creatinine: {'column': 'Creatinine_Creatinine', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=112.35011', 'metric_secondary': 'r2=0.6777', 
'fallback': 'none'}\n", + "[OK] Urea_Urea: {'column': 'Urea_Urea', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.15129', 'metric_secondary': 'r2=0.4171', 'fallback': 'none'}\n", + "[OK] Triglycerides_Triglycerides: {'column': 'Triglycerides_Triglycerides', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.27655', 'metric_secondary': 'r2=0.7399', 'fallback': 'none'}\n", + "[OK] Urate_Urate: {'column': 'Urate_Urate', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2878.89471', 'metric_secondary': 'r2=0.5535', 'fallback': 'none'}\n", + "[OK] LDL_Cholesterol: {'column': 'LDL_Cholesterol', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01279', 'metric_secondary': 'r2=0.9833', 'fallback': 'none'}\n", + "[OK] C_Protein: {'column': 'C_Protein', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=12.72532', 'metric_secondary': 'r2=0.2974', 'fallback': 'none'}\n", + "[OK] Summer_Outdoors_Time: {'column': 'Summer_Outdoors_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=2.91764', 'metric_secondary': 'r2=0.5099', 'fallback': 'none'}\n", + "[OK] Winter_Outdoors_Time: {'column': 'Winter_Outdoors_Time', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.55315', 'metric_secondary': 'r2=0.5726', 'fallback': 'none'}\n", + "[OK] Aspartate_Aminotransferase: {'column': 'Aspartate_Aminotransferase', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=42.55151', 'metric_secondary': 'r2=0.6339', 'fallback': 'none'}\n", + "[OK] Total_Bilirubin: {'column': 'Total_Bilirubin', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.91629', 'metric_secondary': 'r2=0.9018', 'fallback': 'none'}\n", + "[OK] Apolipoprotein_B: {'column': 'Apolipoprotein_B', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00312', 'metric_secondary': 'r2=0.9448', 'fallback': 'none'}\n", + "[OK] Igf_1: {'column': 'Igf_1', 'type': 'regression', 'trained': True, 'metric_primary': 
'mse=24.56654', 'metric_secondary': 'r2=0.2554', 'fallback': 'none'}\n", + "[OK] Haemoglobin_(Hba1C): {'column': 'Haemoglobin_(Hba1C)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=16.14193', 'metric_secondary': 'r2=0.6392', 'fallback': 'none'}\n", + "[OK] Snoring: {'column': 'Snoring', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20626', 'metric_secondary': 'r2=0.1223', 'fallback': 'none'}\n", + "[OK] NOx_Air_2010: {'column': 'NOx_Air_2010', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20355', 'metric_secondary': 'r2=0.9439', 'fallback': 'none'}\n", + "[OK] PM2.5_Absorbance_2010: {'column': 'PM2.5_Absorbance_2010', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01840', 'metric_secondary': 'r2=0.9772', 'fallback': 'none'}\n", + "[OK] PM10_Air_2010: {'column': 'PM10_Air_2010', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02771', 'metric_secondary': 'r2=0.9750', 'fallback': 'none'}\n", + "[OK] PM2.5_Air_2010: {'column': 'PM2.5_Air_2010', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00660', 'metric_secondary': 'r2=0.9099', 'fallback': 'none'}\n", + "[OK] Caffeine_Drink: {'column': 'Caffeine_Drink', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.02196', 'metric_secondary': 'r2=0.0244', 'fallback': 'median'}\n", + "[OK] Inhaler_Use: {'column': 'Inhaler_Use', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.00735', 'metric_secondary': 'r2=0.0481', 'fallback': 'median'}\n", + "[OK] Facial_Ageing: {'column': 'Facial_Ageing', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.18129', 'metric_secondary': 'r2=0.0603', 'fallback': 'none'}\n", + "[OK] FVC: {'column': 'FVC', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08923', 'metric_secondary': 'r2=0.9174', 'fallback': 'none'}\n", + "[OK] PEF: {'column': 'PEF', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4913.93520', 'metric_secondary': 
'r2=0.7044', 'fallback': 'none'}\n", + "[OK] FEV1: {'column': 'FEV1', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02576', 'metric_secondary': 'r2=0.9588', 'fallback': 'none'}\n", + "[OK] Over_Speed_Driving: {'column': 'Over_Speed_Driving', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19525', 'metric_secondary': 'r2=0.2133', 'fallback': 'none'}\n", + "[OK] Chronotype: {'column': 'Chronotype', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.22260', 'metric_secondary': 'r2=0.0284', 'fallback': 'median'}\n", + "[OK] Garden_1000m: {'column': 'Garden_1000m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.81392', 'metric_secondary': 'r2=0.8715', 'fallback': 'none'}\n", + "[OK] Water_1000m: {'column': 'Water_1000m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=15.67364', 'metric_secondary': 'r2=0.9709', 'fallback': 'none'}\n", + "[OK] Garden_300m: {'column': 'Garden_300m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1.14739', 'metric_secondary': 'r2=0.8690', 'fallback': 'none'}\n", + "[OK] Noise_Level: {'column': 'Noise_Level', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4.63175', 'metric_secondary': 'r2=0.9901', 'fallback': 'none'}\n", + "[OK] Greenspace_300m: {'column': 'Greenspace_300m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=12.93868', 'metric_secondary': 'r2=0.9407', 'fallback': 'none'}\n", + "[OK] Greenspace_1000m: {'column': 'Greenspace_1000m', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=5.34329', 'metric_secondary': 'r2=0.9586', 'fallback': 'none'}\n", + "[OK] First_Sexual_Intercourse: {'column': 'First_Sexual_Intercourse', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=11.86485', 'metric_secondary': 'r2=0.1890', 'fallback': 'none'}\n", + "[OK] Crime_Score: {'column': 'Crime_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=26.45542', 'metric_secondary': 'r2=0.8875', 
'fallback': 'none'}\n", + "[OK] Education_Score: {'column': 'Education_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=14.24937', 'metric_secondary': 'r2=0.8606', 'fallback': 'none'}\n", + "[OK] Housing_Score: {'column': 'Housing_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.07317', 'metric_secondary': 'r2=0.8805', 'fallback': 'none'}\n", + "[OK] Income_Score: {'column': 'Income_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00006', 'metric_secondary': 'r2=0.9832', 'fallback': 'none'}\n", + "[OK] Coast_Distance: {'column': 'Coast_Distance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00019', 'metric_secondary': 'r2=0.9801', 'fallback': 'none'}\n", + "[OK] Index_of_Multiple_Deprivation: {'column': 'Index_of_Multiple_Deprivation', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.52451', 'metric_secondary': 'r2=0.9973', 'fallback': 'none'}\n", + "[OK] Employment_Score: {'column': 'Employment_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=15.87962', 'metric_secondary': 'r2=0.9384', 'fallback': 'none'}\n", + "[OK] Albumin_Albumin: {'column': 'Albumin_Albumin', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=3.54718', 'metric_secondary': 'r2=0.4850', 'fallback': 'none'}\n", + "[OK] Calcium_Calcium: {'column': 'Calcium_Calcium', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00511', 'metric_secondary': 'r2=0.4311', 'fallback': 'none'}\n", + "[OK] HDL_Cholesterol: {'column': 'HDL_Cholesterol', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00749', 'metric_secondary': 'r2=0.9493', 'fallback': 'none'}\n", + "[OK] Total_Protein: {'column': 'Total_Protein', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=9.22529', 'metric_secondary': 'r2=0.4502', 'fallback': 'none'}\n", + "[OK] Glucose_Glucose: {'column': 'Glucose_Glucose', 'type': 'regression', 'trained': True, 'metric_primary': 
'mse=0.66192', 'metric_secondary': 'r2=0.6126', 'fallback': 'none'}\n", + "[OK] Phosphate_Phosphate: {'column': 'Phosphate_Phosphate', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01922', 'metric_secondary': 'r2=0.2563', 'fallback': 'none'}\n", + "[OK] Apolipoprotein_A: {'column': 'Apolipoprotein_A', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00739', 'metric_secondary': 'r2=0.8998', 'fallback': 'none'}\n", + "[OK] Shbg: {'column': 'Shbg', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=360.79332', 'metric_secondary': 'r2=0.5249', 'fallback': 'none'}\n", + "[OK] Testosterone: {'column': 'Testosterone', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4.26155', 'metric_secondary': 'r2=0.8839', 'fallback': 'none'}\n", + "[OK] Direct_Bilirubin: {'column': 'Direct_Bilirubin', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.05971', 'metric_secondary': 'r2=0.9132', 'fallback': 'none'}\n", + "[OK] MET_Moderate: {'column': 'MET_Moderate', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=5086.36089', 'metric_secondary': 'r2=0.9965', 'fallback': 'none'}\n", + "[OK] Activity_Days: {'column': 'Activity_Days', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.29166', 'metric_secondary': 'r2=0.9875', 'fallback': 'none'}\n", + "[OK] Activity_Recommendation: {'column': 'Activity_Recommendation', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00002', 'metric_secondary': 'r2=0.9999', 'fallback': 'none'}\n", + "[OK] MET_Walking: {'column': 'MET_Walking', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=5936.64890', 'metric_secondary': 'r2=0.9948', 'fallback': 'none'}\n", + "[OK] MET_Vigorous: {'column': 'MET_Vigorous', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=6285.74163', 'metric_secondary': 'r2=0.9954', 'fallback': 'none'}\n", + "[OK] IPAQ_Activity_Group: {'column': 'IPAQ_Activity_Group', 'type': 'regression', 'trained': 
True, 'metric_primary': 'mse=0.00048', 'metric_secondary': 'r2=0.9968', 'fallback': 'none'}\n", + "[OK] Activity_Minutes: {'column': 'Activity_Minutes', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=104.87300', 'metric_secondary': 'r2=0.9896', 'fallback': 'none'}\n", + "[OK] Walking_Recommendation_Compliance: {'column': 'Walking_Recommendation_Compliance', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00087', 'metric_secondary': 'r2=0.9941', 'fallback': 'none'}\n", + "[OK] Total_MET_Minutes_per_Week: {'column': 'Total_MET_Minutes_per_Week', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=6335.64322', 'metric_secondary': 'r2=0.9991', 'fallback': 'none'}\n", + "[OK] Breastfed_Baby: {'column': 'Breastfed_Baby', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.18390', 'metric_secondary': 'r2=0.0789', 'fallback': 'none'}\n", + "[OK] FEV1_Z_Score: {'column': 'FEV1_Z_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00375', 'metric_secondary': 'r2=0.9970', 'fallback': 'none'}\n", + "[OK] FEV1_FVC_Ratio: {'column': 'FEV1_FVC_Ratio', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00927', 'metric_secondary': 'r2=0.9884', 'fallback': 'none'}\n", + "[OK] FVC_Z_Score: {'column': 'FVC_Z_Score', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00535', 'metric_secondary': 'r2=0.9952', 'fallback': 'none'}\n", + "[OK] Spirometry_Quality: {'column': 'Spirometry_Quality', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08186', 'metric_secondary': 'r2=0.0990', 'fallback': 'none'}\n", + "[OK] Sunburns: {'column': 'Sunburns', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.21410', 'metric_secondary': 'r2=0.1374', 'fallback': 'none'}\n", + "[OK] Lipoprotein_A: {'column': 'Lipoprotein_A', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=2356.15944', 'metric_secondary': 'r2=0.0215', 'fallback': 'median'}\n", + "[OK] 
FEV1_(Best_Measure): {'column': 'FEV1_(Best_Measure)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.00642', 'metric_secondary': 'r2=0.9895', 'fallback': 'none'}\n", + "[OK] FVC_(Best_Measure): {'column': 'FVC_(Best_Measure)', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.01083', 'metric_secondary': 'r2=0.9889', 'fallback': 'none'}\n", + "[OK] Email_Access: {'column': 'Email_Access', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.06467', 'metric_secondary': 'r2=0.0258', 'fallback': 'median'}\n", + "[OK] Bowel_Open_Min: {'column': 'Bowel_Open_Min', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=7295.04911', 'metric_secondary': 'r2=0.2187', 'fallback': 'none'}\n", + "[OK] Bowel_Open_Max: {'column': 'Bowel_Open_Max', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=1907.53442', 'metric_secondary': 'r2=0.4941', 'fallback': 'none'}\n", + "[OK] Bowel_Open_Average: {'column': 'Bowel_Open_Average', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=4068.44912', 'metric_secondary': 'r2=0.3667', 'fallback': 'none'}\n", + "[OK] Headache: {'column': 'Headache', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.18840', 'metric_secondary': 'r2=0.2088', 'fallback': 'none'}\n", + "[OK] Breath_Shortness: {'column': 'Breath_Shortness', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16005', 'metric_secondary': 'r2=0.2642', 'fallback': 'none'}\n", + "[OK] Heart_Pounding: {'column': 'Heart_Pounding', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.17364', 'metric_secondary': 'r2=0.1805', 'fallback': 'none'}\n", + "[OK] Back_Pain: {'column': 'Back_Pain', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20411', 'metric_secondary': 'r2=0.1495', 'fallback': 'none'}\n", + "[OK] Limb_Joint_Pain: {'column': 'Limb_Joint_Pain', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16754', 'metric_secondary': 'r2=0.1635', 'fallback': 
'none'}\n", + "[OK] Tiredness: {'column': 'Tiredness', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15552', 'metric_secondary': 'r2=0.3162', 'fallback': 'none'}\n", + "[OK] Sleep_Trouble: {'column': 'Sleep_Trouble', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16397', 'metric_secondary': 'r2=0.2664', 'fallback': 'none'}\n", + "[OK] Dizziness: {'column': 'Dizziness', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.16016', 'metric_secondary': 'r2=0.1864', 'fallback': 'none'}\n", + "[OK] Nausea: {'column': 'Nausea', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.10358', 'metric_secondary': 'r2=0.1875', 'fallback': 'none'}\n", + "[OK] Chest_Pain: {'column': 'Chest_Pain', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09515', 'metric_secondary': 'r2=0.2064', 'fallback': 'none'}\n", + "[OK] Fainting: {'column': 'Fainting', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.02884', 'metric_secondary': 'r2=0.0934', 'fallback': 'none'}\n", + "[OK] Loose_Stools_Frequency: {'column': 'Loose_Stools_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.19391', 'metric_secondary': 'r2=0.1692', 'fallback': 'none'}\n", + "[OK] Hard_Stools_Frequency: {'column': 'Hard_Stools_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.20094', 'metric_secondary': 'r2=0.1928', 'fallback': 'none'}\n", + "[OK] Urinary_Frequency: {'column': 'Urinary_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.22031', 'metric_secondary': 'r2=0.1083', 'fallback': 'none'}\n", + "[OK] Abdomen_Pain_Frequency: {'column': 'Abdomen_Pain_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.13654', 'metric_secondary': 'r2=0.4499', 'fallback': 'none'}\n", + "[OK] Recent_Abdominal_Pain: {'column': 'Recent_Abdominal_Pain', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11318', 'metric_secondary': 'r2=0.4575', 
'fallback': 'none'}\n", + "[OK] Abdominal_Distension: {'column': 'Abdominal_Distension', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.12470', 'metric_secondary': 'r2=0.3120', 'fallback': 'none'}\n", + "[OK] Bowel_Satisfaction: {'column': 'Bowel_Satisfaction', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11267', 'metric_secondary': 'r2=0.2667', 'fallback': 'none'}\n", + "[OK] Bowel_Interference: {'column': 'Bowel_Interference', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.15387', 'metric_secondary': 'r2=0.3840', 'fallback': 'none'}\n", + "[OK] Coeliac_Disease: {'column': 'Coeliac_Disease', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.01618', 'metric_secondary': 'r2=0.0477', 'fallback': 'median'}\n", + "[OK] IBS: {'column': 'IBS', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.08246', 'metric_secondary': 'r2=0.2479', 'fallback': 'none'}\n", + "[OK] Caesarian_Born: {'column': 'Caesarian_Born', 'type': 'regression', 'trained': False, 'metric_primary': 'mse=0.02584', 'metric_secondary': 'r2=0.0132', 'fallback': 'median'}\n", + "[OK] Sensitive_Stomach: {'column': 'Sensitive_Stomach', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.12573', 'metric_secondary': 'r2=0.2929', 'fallback': 'none'}\n", + "[OK] Childhood_Antibiotics: {'column': 'Childhood_Antibiotics', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.11160', 'metric_secondary': 'r2=0.0798', 'fallback': 'none'}\n", + "[OK] Alcohol-Related_Injury: {'column': 'Alcohol-Related_Injury', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03446', 'metric_secondary': 'r2=0.0933', 'fallback': 'none'}\n", + "[OK] Alcohol_Drinking_Frequency: {'column': 'Alcohol_Drinking_Frequency', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.04356', 'metric_secondary': 'r2=0.4509', 'fallback': 'none'}\n", + "[OK] Serious_Accident: {'column': 'Serious_Accident', 'type': 'regression', 
'trained': True, 'metric_primary': 'mse=0.08273', 'metric_secondary': 'r2=0.0628', 'fallback': 'none'}\n", + "[OK] Combat_War_Exposure: {'column': 'Combat_War_Exposure', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.03397', 'metric_secondary': 'r2=0.0647', 'fallback': 'none'}\n", + "[OK] Appetite_Changes: {'column': 'Appetite_Changes', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.10221', 'metric_secondary': 'r2=0.3431', 'fallback': 'none'}\n", + "[OK] Stressful_Thoughts: {'column': 'Stressful_Thoughts', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.12213', 'metric_secondary': 'r2=0.3859', 'fallback': 'none'}\n", + "[OK] Concentration_Issues: {'column': 'Concentration_Issues', 'type': 'regression', 'trained': True, 'metric_primary': 'mse=0.09228', 'metric_secondary': 'r2=0.3689', 'fallback': 'none'}\n" + ] + } + ], + "source": [ + "# =======================\n", + "# Notebook 一体化:XGBoost 逐列机器学习插补(兼容旧版 xgboost,无 early_stopping_rounds)\n", + "# =======================\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, r2_score\n", + "\n", + "# xgboost 基本导入\n", + "try:\n", + " import xgboost as xgb\n", + " from xgboost import XGBClassifier, XGBRegressor\n", + "except Exception as e:\n", + " raise RuntimeError(\"需要已安装 xgboost。请先在该环境安装:pip install xgboost\") from e\n", + "\n", + "# ============ Config(按需修改) ============\n", + "INPUT_CSV = \"/content/drive/MyDrive/demo_200000.csv\"\n", + "OUTPUT_CSV = \"demo_200000_imputed.csv\"\n", + "REPORT_CSV = \"demo_200000_impute_report.csv\"\n", + "\n", + "# 不参与作为特征/目标的列(如ID/标签)\n", + "EXCLUDE_COLS = [\"ID\", 
\"DR\",\"DR_time\",\"AMD_time\",\"AMD\",\"glaucoma_time\",\"glaucoma\",\"cataract_time\",\"cataract\"]\n", + "\n", + "# 类别压帽(最多保留前N个高频类别,其余合并为 OTHER)\n", + "MAX_CATEGORIES = 50\n", + "\n", + "# 训练与评估\n", + "TEST_SIZE = 0.2\n", + "RANDOM_STATE = 42\n", + "N_ESTIMATORS = 300 # 如果耗时长,可先降到 200\n", + "LEARNING_RATE = 0.05\n", + "MAX_DEPTH = 6\n", + "SUBSAMPLE = 0.8\n", + "COLSAMPLE_BYTREE = 0.8\n", + "\n", + "# 早停轮数(老版本不支持 fit 参数,我们用回调;若回调也不可用就自动跳过)\n", + "EARLY_STOPPING_ROUNDS = 50\n", + "\n", + "# 分类任务:最低可接受准确率(低于则回退到众数)\n", + "EVAL_ACC_THRESHOLD = 0.65\n", + "# 回归任务:模型MSE需 <= baseline*MSE_RATIO 才接受(即至少优于均值/中位数约5%)\n", + "EVAL_MSE_RATIO = 0.95\n", + "\n", + "# XGBoost tree_method(可切到 \"gpu_hist\" 如果你的环境支持 GPU)\n", + "XGB_TREE_METHOD = \"hist\" # \"hist\" | \"approx\" | \"auto\" | \"gpu_hist\"\n", + "\n", + "\n", + "# ============ Helpers ============\n", + "def is_numeric_series(s: pd.Series) -> bool:\n", + " return pd.api.types.is_integer_dtype(s) or pd.api.types.is_float_dtype(s)\n", + "\n", + "def cap_categories(series: pd.Series, max_categories: int = 50):\n", + " \"\"\"保留前N高频类别,其余 -> 'OTHER'\"\"\"\n", + " vc = series.value_counts(dropna=False)\n", + " top = set(vc.head(max_categories).index.tolist())\n", + " return series.apply(lambda x: x if x in top else \"OTHER\")\n", + "\n", + "def one_hot_fit_transform(df: pd.DataFrame, categorical_cols, max_categories: int):\n", + " \"\"\"拟合并独热编码,返回:编码后DF、meta(每列类别集合)、最终列名列表\"\"\"\n", + " df = df.copy()\n", + " meta = {}\n", + " for col in categorical_cols:\n", + " s = df[col].astype(str).fillna(\"UNKNOWN\")\n", + " s = cap_categories(s, max_categories=max_categories)\n", + " df[col] = s\n", + " meta[col] = sorted(df[col].unique().tolist())\n", + " dummied = pd.get_dummies(df, columns=categorical_cols, dummy_na=False)\n", + " return dummied, meta, dummied.columns.tolist()\n", + "\n", + "def one_hot_transform_with_meta(df: pd.DataFrame, categorical_cols, meta, all_cols):\n", + " \"\"\"用拟合阶段的 meta 做独热,并对齐列\"\"\"\n", + " df = 
df.copy()\n", + " for col in categorical_cols:\n", + " s = df[col].astype(str).fillna(\"UNKNOWN\")\n", + " df[col] = s.apply(lambda x: x if x in meta[col] else \"OTHER\")\n", + " dummied = pd.get_dummies(df, columns=categorical_cols, dummy_na=False)\n", + " for c in all_cols:\n", + " if c not in dummied.columns:\n", + " dummied[c] = 0\n", + " dummied = dummied[all_cols]\n", + " return dummied\n", + "\n", + "def evaluate_classifier(y_true, y_pred):\n", + " acc = accuracy_score(y_true, y_pred)\n", + " f1m = f1_score(y_true, y_pred, average=\"macro\")\n", + " return {\"accuracy\": acc, \"f1_macro\": f1m}\n", + "\n", + "def evaluate_regressor(y_true, y_pred):\n", + " mse = mean_squared_error(y_true, y_pred)\n", + " r2 = r2_score(y_true, y_pred)\n", + " return {\"mse\": mse, \"r2\": r2}\n", + "\n", + "def _fit_with_optional_early_stopping(model, X_tr, y_tr, X_va, y_va):\n", + " \"\"\"\n", + " 兼容不同 xgboost 版本的早停:\n", + " - 优先使用 xgboost.callback.EarlyStopping\n", + " - 如果不可用,直接不做早停\n", + " \"\"\"\n", + " callbacks = []\n", + " eval_set = [(X_va, y_va)]\n", + "\n", + " # 优先用官方回调(老版本也支持)\n", + " try:\n", + " cb = xgb.callback.EarlyStopping(\n", + " rounds=EARLY_STOPPING_ROUNDS,\n", + " save_best=True,\n", + " maximize=False # 回归/分类默认都是最小化损失\n", + " )\n", + " callbacks.append(cb)\n", + " model.fit(X_tr, y_tr, eval_set=eval_set, callbacks=callbacks, verbose=False)\n", + " return model\n", + " except Exception:\n", + " # 无法使用回调则直接无早停训练\n", + " model.fit(X_tr, y_tr, eval_set=eval_set, verbose=False)\n", + " return model\n", + "\n", + "def xgb_impute_column(\n", + " df: pd.DataFrame,\n", + " target_col: str,\n", + " exclude_cols: list,\n", + " max_categories: int = 50,\n", + " test_size: float = 0.2,\n", + " random_state: int = 42,\n", + " n_estimators: int = 300,\n", + " learning_rate: float = 0.05,\n", + " max_depth: int = 6,\n", + " subsample: float = 0.8,\n", + " colsample_bytree: float = 0.8,\n", + " eval_acc_threshold: float = 0.65,\n", + " eval_mse_ratio: float = 
0.95,\n", + " tree_method: str = \"hist\",\n", + "):\n", + " \"\"\"对单列进行 XGB 插补:返回填补后的 Series 与一条报告 dict\"\"\"\n", + " y = df[target_col]\n", + " notnull_mask = y.notna()\n", + " null_mask = ~notnull_mask\n", + " if null_mask.sum() == 0:\n", + " return y, {\"column\": target_col, \"type\": \"skip_no_missing\", \"trained\": False,\n", + " \"metric_primary\": None, \"metric_secondary\": None, \"fallback\": \"none\"}\n", + "\n", + " feature_cols = [c for c in df.columns if c != target_col and c not in exclude_cols]\n", + " X = df[feature_cols].copy()\n", + "\n", + " # 特征侧预填(不改原 df)\n", + " num_cols = [c for c in feature_cols if is_numeric_series(X[c])]\n", + " cat_cols = [c for c in feature_cols if not is_numeric_series(X[c])]\n", + " for c in num_cols:\n", + " X[c] = pd.to_numeric(X[c], errors=\"coerce\").fillna(X[c].median())\n", + " for c in cat_cols:\n", + " X[c] = X[c].astype(str).fillna(\"UNKNOWN\")\n", + "\n", + " X_train_full = X.loc[notnull_mask].copy()\n", + " y_train_full = y.loc[notnull_mask].copy()\n", + " X_null = X.loc[null_mask].copy()\n", + "\n", + " X_train_oh, meta, all_cols = one_hot_fit_transform(X_train_full, cat_cols, max_categories=max_categories)\n", + " X_null_oh = one_hot_transform_with_meta(X_null, cat_cols, meta, all_cols)\n", + "\n", + " # 回归还是分类由目标列类型决定\n", + " if is_numeric_series(y_train_full):\n", + " task = \"regression\"\n", + " y_vec = pd.to_numeric(y_train_full, errors=\"coerce\")\n", + " valid = y_vec.notna()\n", + " X_train_oh = X_train_oh.loc[valid]; y_vec = y_vec.loc[valid]\n", + "\n", + " X_tr, X_te, y_tr, y_te = train_test_split(X_train_oh, y_vec, test_size=test_size, random_state=random_state)\n", + " model = XGBRegressor(\n", + " n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth,\n", + " subsample=subsample, colsample_bytree=colsample_bytree, objective=\"reg:squarederror\",\n", + " tree_method=tree_method, random_state=random_state, n_jobs=-1\n", + " )\n", + " model = 
_fit_with_optional_early_stopping(model, X_tr, y_tr, X_te, y_te)\n", + " y_pred = model.predict(X_te)\n", + " m = evaluate_regressor(y_te, y_pred)\n", + "\n", + " # 基线:均值/中位数\n", + " mse_model = m[\"mse\"]\n", + " mse_mean = mean_squared_error(y_te, np.full_like(y_te, y_tr.mean(), dtype=float))\n", + " mse_median= mean_squared_error(y_te, np.full_like(y_te, float(np.median(y_tr)), dtype=float))\n", + " best_bl = min(mse_mean, mse_median)\n", + "\n", + " if mse_model <= eval_mse_ratio * best_bl:\n", + " y_null_pred = model.predict(X_null_oh)\n", + " filled = y.copy(); filled.loc[null_mask] = y_null_pred\n", + " rep = {\"column\": target_col, \"type\": task, \"trained\": True,\n", + " \"metric_primary\": f\"mse={m['mse']:.5f}\", \"metric_secondary\": f\"r2={m['r2']:.4f}\",\n", + " \"fallback\": \"none\"}\n", + " return filled, rep\n", + " else:\n", + " filled = y.fillna(y.median())\n", + " rep = {\"column\": target_col, \"type\": task, \"trained\": False,\n", + " \"metric_primary\": f\"mse={m['mse']:.5f}\", \"metric_secondary\": f\"r2={m['r2']:.4f}\",\n", + " \"fallback\": \"median\"}\n", + " return filled, rep\n", + "\n", + " else:\n", + " task = \"classification\"\n", + " y_str = y_train_full.astype(str)\n", + " enc = LabelEncoder(); y_enc = enc.fit_transform(y_str)\n", + "\n", + " X_tr, X_te, y_tr, y_te = train_test_split(\n", + " X_train_oh, y_enc, test_size=test_size, random_state=random_state, stratify=y_enc\n", + " )\n", + " clf = XGBClassifier(\n", + " n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth,\n", + " subsample=subsample, colsample_bytree=colsample_bytree,\n", + " objective=\"multi:softprob\" if len(np.unique(y_enc))>2 else \"binary:logistic\",\n", + " num_class=len(np.unique(y_enc)) if len(np.unique(y_enc))>2 else None,\n", + " tree_method=tree_method, random_state=random_state, n_jobs=-1, use_label_encoder=False\n", + " )\n", + " clf = _fit_with_optional_early_stopping(clf, X_tr, y_tr, X_te, y_te)\n", + " y_pred = 
clf.predict(X_te)\n", + " m = evaluate_classifier(y_te, y_pred)\n", + "\n", + " if m[\"accuracy\"] >= eval_acc_threshold:\n", + " y_null_pred_enc = clf.predict(X_null_oh)\n", + " y_null_pred = enc.inverse_transform(y_null_pred_enc)\n", + " filled = y.copy(); filled.loc[null_mask] = y_null_pred\n", + " rep = {\"column\": target_col, \"type\": task, \"trained\": True,\n", + " \"metric_primary\": f\"acc={m['accuracy']:.4f}\", \"metric_secondary\": f\"f1_macro={m['f1_macro']:.4f}\",\n", + " \"fallback\": \"none\"}\n", + " return filled, rep\n", + " else:\n", + " mode_val = y_str.mode().iloc[0] if y_str.mode().shape[0] else \"UNKNOWN\"\n", + " filled = y.fillna(mode_val)\n", + " rep = {\"column\": target_col, \"type\": task, \"trained\": False,\n", + " \"metric_primary\": f\"acc={m['accuracy']:.4f}\", \"metric_secondary\": f\"f1_macro={m['f1_macro']:.4f}\",\n", + " \"fallback\": f\"mode({mode_val})\"}\n", + " return filled, rep\n", + "\n", + "\n", + "def impute_dataframe(df: pd.DataFrame, exclude_cols=None, max_categories=50):\n", + " dfw = df.copy()\n", + " exclude_cols = exclude_cols or []\n", + "\n", + " miss_counts = dfw.isna().sum()\n", + " miss_cols = miss_counts[miss_counts > 0].sort_values(ascending=True).index.tolist()\n", + "\n", + " reports = []\n", + " for col in miss_cols:\n", + " if col in exclude_cols:\n", + " reports.append({\"column\": col, \"type\": \"skipped_excluded\", \"trained\": False,\n", + " \"metric_primary\": None, \"metric_secondary\": None, \"fallback\": \"none\"})\n", + " continue\n", + "\n", + " filled_col, rep = xgb_impute_column(\n", + " df=dfw, target_col=col, exclude_cols=exclude_cols,\n", + " max_categories=max_categories, test_size=TEST_SIZE, random_state=RANDOM_STATE,\n", + " n_estimators=N_ESTIMATORS, learning_rate=LEARNING_RATE, max_depth=MAX_DEPTH,\n", + " subsample=SUBSAMPLE, colsample_bytree=COLSAMPLE_BYTREE,\n", + " eval_acc_threshold=EVAL_ACC_THRESHOLD, eval_mse_ratio=EVAL_MSE_RATIO,\n", + " tree_method=XGB_TREE_METHOD\n", + 
" )\n", + " dfw[col] = filled_col\n", + " reports.append(rep)\n", + " print(f\"[OK] {col}: {rep}\")\n", + "\n", + " report_df = pd.DataFrame(reports, columns=[\"column\",\"type\",\"trained\",\"metric_primary\",\"metric_secondary\",\"fallback\"])\n", + " return dfw, report_df\n", + "\n", + "\n", + "# ============ RUN ============\n", + "df_in = pd.read_csv(INPUT_CSV)\n", + "imputed_df, report_df = impute_dataframe(df_in, exclude_cols=EXCLUDE_COLS, max_categories=MAX_CATEGORIES)\n", + "\n", + "imputed_df.to_csv(OUTPUT_CSV, index=False)\n", + "report_df.to_csv(REPORT_CSV, index=False)\n", + "\n", + "print(\"插补完成。\")\n", + "print(\"Imputed CSV ->\", OUTPUT_CSV)\n", + "print(\"Report CSV ->\", REPORT_CSV)\n", + "\n", + "# 预览\n", + "imputed_df.head(), report_df.head(20)\n" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/jointContribution/AI_Climate_Diseases/paddle_test_example.ipynb b/jointContribution/AI_Climate_Diseases/paddle_test_example.ipynb new file mode 100644 index 0000000000..1543d69f30 --- /dev/null +++ b/jointContribution/AI_Climate_Diseases/paddle_test_example.ipynb @@ -0,0 +1,5735 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "A100" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "!pip -q install paddlepaddle -i https://pypi.tuna.tsinghua.edu.cn/simple\n", + "\n", + "# 验证\n", + "import paddle\n", + "paddle.utils.run_check()\n", + "print(\"Paddle version:\", paddle.__version__)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "awQrZ0Q-NF7S", + "outputId": 
"c8c84f16-eccb-4a9d-af47-60a1188792ad" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m189.0/189.0 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.5/65.5 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/paddle/utils/cpp_extension/extension_utils.py:718: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md\n", + " warnings.warn(warning_message)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Running verify PaddlePaddle program ... \n", + "PaddlePaddle works well on 1 CPU.\n", + "PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now.\n", + "Paddle version: 3.2.0\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/paddle/pir/math_op_patch.py:219: UserWarning: Value do not have 'place' interface for pir graph mode, try not to use it. 
None will be returned.\n", + " warnings.warn(\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import math\n", + "from typing import Literal, Optional\n", + "\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "\n", + "# ---------- 通用初始化 ------------------------------------------------------\n", + "def init_rsqrt_uniform_(w: paddle.Tensor) -> paddle.Tensor:\n", + " bound = 1.0 / math.sqrt(w.shape[-1])\n", + " noise = paddle.uniform(w.shape, min=-bound, max=bound, dtype=w.dtype)\n", + " w.set_value(noise)\n", + " return w\n", + "\n", + "def init_random_signs_(w: paddle.Tensor) -> paddle.Tensor:\n", + " # 0/1 伯努利 -> *2 -1 => {-1, +1}\n", + " with paddle.no_grad():\n", + " p = paddle.full(w.shape, 0.5, dtype='float32')\n", + " s = paddle.bernoulli(p) * 2.0 - 1.0\n", + " s = paddle.cast(s, w.dtype)\n", + " w.set_value(s)\n", + " return w\n", + "\n", + "# ---------- 基础层 ----------------------------------------------------------\n", + "class NLinear(nn.Layer):\n", + " \"\"\"PackedEnsemble: K 份 Linear 打包 → 输入 (B,K,D), 权重布局 (K, I, O)\"\"\"\n", + " def __init__(self, k: int, in_f: int, out_f: int, bias: bool = True):\n", + " super().__init__()\n", + " self.k = k\n", + " self.in_f = in_f\n", + " self.out_f = out_f\n", + " # 按 Paddle 线性层布局 [I, O]\n", + " self.weight = self.create_parameter(shape=[k, in_f, out_f])\n", + " self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " init_rsqrt_uniform_(self.weight)\n", + " if self.bias_e is not None:\n", + " init_rsqrt_uniform_(self.bias_e)\n", + "\n", + " def forward(self, x): # x: (B,K,D=I)\n", + " # 转成 (K,B,D) 与 batched matmul 对齐\n", + " xk = paddle.transpose(x, [1, 0, 2]) # (K,B,I)\n", + " # (K,B,I) @ (K,I,O) = (K,B,O)\n", + " yk = paddle.bmm(xk, self.weight) # (K,B,O)\n", + " y = paddle.transpose(yk, [1, 0, 2]) # (B,K,O)\n", + " if self.bias_e is not None:\n", + " y = y 
+ self.bias_e # 广播到 (B,K,O)\n", + " return y\n", + "\n", + "class ScaleEnsemble(nn.Layer):\n", + " \"\"\"Mini-Ensemble:每层一个 rank-1 缩放向量\"\"\"\n", + " def __init__(self, k: int, d: int, init='ones'):\n", + " super().__init__()\n", + " self.k = k\n", + " self.d = d\n", + " self.init = init\n", + " self.weight = self.create_parameter(shape=[k, d])\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " if self.init == 'ones':\n", + " self.weight.set_value(paddle.ones_like(self.weight))\n", + " else:\n", + " init_random_signs_(self.weight)\n", + "\n", + " def forward(self, x): # (B,K,D)\n", + " return x * self.weight # 广播到 (B,K,D)\n", + "\n", + "class LinearBE(nn.Layer):\n", + " \"\"\"\n", + " BatchEnsemble Linear(Paddle 布局):\n", + " 权重 W: [I, O];前向 y_e = ((x * r_e) @ W) * s_e + b_e\n", + " 输入: x (B,K,I)\n", + " 输出: y (B,K,O)\n", + " \"\"\"\n", + " def __init__(self, in_f: int, out_f: int, k: int,\n", + " scale_init='ones', bias: bool = True):\n", + " super().__init__()\n", + " self.k = k\n", + " self.in_f = in_f\n", + " self.out_f = out_f\n", + " # 显式属性名,避免冲突;按 Paddle 线性层布局 [I, O]\n", + " self.weight = self.create_parameter(shape=[in_f, out_f])\n", + " self.r = self.create_parameter(shape=[k, in_f])\n", + " self.s = self.create_parameter(shape=[k, out_f])\n", + " self.use_bias = bias\n", + " self.bias_e = self.create_parameter(shape=[k, out_f]) if bias else None\n", + " self.scale_init = scale_init\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " init_rsqrt_uniform_(self.weight)\n", + " if self.scale_init == 'ones':\n", + " self.r.set_value(paddle.ones_like(self.r))\n", + " self.s.set_value(paddle.ones_like(self.s))\n", + " else:\n", + " init_random_signs_(self.r)\n", + " init_random_signs_(self.s)\n", + " if self.use_bias:\n", + " init_rsqrt_uniform_(self.bias_e)\n", + "\n", + " def forward(self, x): # (B,K,I)\n", + " xr = x * self.r # (B,K,I)\n", + " # (B,K,I) @ (I,O) = (B,K,O)\n", + " y = 
paddle.matmul(xr, self.weight) # (B,K,O)\n", + " y = y * self.s # (B,K,O)\n", + " if self.use_bias:\n", + " y = y + self.bias_e\n", + " return y\n", + "\n", + "# ---------- Backbone MLP -----------------------------------------------------\n", + "class MLPBlock(nn.Layer):\n", + " def __init__(self, d_in, d_hid, dropout, act='ReLU'):\n", + " super().__init__()\n", + " Act = getattr(nn, act)\n", + " self.net = nn.Sequential(\n", + " nn.Linear(d_in, d_hid), # Paddle: weight [d_in, d_hid]\n", + " Act(),\n", + " nn.Dropout(dropout),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " # 允许 (B,K,D) 或 (B,D);Linear 会在最后一维上工作\n", + " return self.net(x)\n", + "\n", + "class BackboneMLP(nn.Layer):\n", + " def __init__(self, n_blocks: int, d_in: int, d_hidden: int, dropout: float):\n", + " super().__init__()\n", + " blocks = []\n", + " for i in range(n_blocks):\n", + " blocks.append(\n", + " MLPBlock(d_in if i == 0 else d_hidden, d_hidden, dropout)\n", + " )\n", + " self.blocks = nn.LayerList(blocks)\n", + "\n", + " def forward(self, x):\n", + " for blk in self.blocks:\n", + " x = blk(x)\n", + " return x\n", + "\n", + "# ---------- 工具:递归替换 Linear 为 BE / Packed ---------------------------\n", + "def _get_parent_by_path(root: nn.Layer, path_list):\n", + " \"\"\"根据命名路径拿到父层(最后一个名是子层名)\"\"\"\n", + " cur = root\n", + " for p in path_list:\n", + " if hasattr(cur, p):\n", + " cur = getattr(cur, p)\n", + " else:\n", + " sub_layers = getattr(cur, \"_sub_layers\", None)\n", + " if sub_layers is None or p not in sub_layers:\n", + " raise AttributeError(f\"Cannot locate sublayer '{p}' under '{type(cur).__name__}'\")\n", + " cur = sub_layers[p]\n", + " return cur\n", + "\n", + "def _replace_linear(module: nn.Layer, k: int, mode: Literal['be', 'packed']):\n", + " \"\"\"\n", + " 遍历 module 的子层,把 nn.Linear 替换为 LinearBE 或 NLinear\n", + " 注意:Paddle Linear 的 weight 形状为 [in_features, out_features]\n", + " \"\"\"\n", + " to_replace = []\n", + "\n", + " for full_name, layer in 
module.named_sublayers(include_self=False):\n", + " if isinstance(layer, nn.Linear):\n", + " parts = full_name.split('.')\n", + " parent_path, child_name = parts[:-1], parts[-1]\n", + " parent = _get_parent_by_path(module, parent_path) if parent_path else module\n", + "\n", + " in_f = layer.weight.shape[0] # I\n", + " out_f = layer.weight.shape[1] # O\n", + "\n", + " if mode == 'be':\n", + " new_layer = LinearBE(in_f, out_f, k)\n", + " with paddle.no_grad():\n", + " # 拷贝共享主权重([I,O])与偏置([O])\n", + " assert list(new_layer.weight.shape) == list(layer.weight.shape), \\\n", + " f\"weight shape mismatch: {new_layer.weight.shape} vs {layer.weight.shape}\"\n", + " new_layer.weight.set_value(layer.weight.clone())\n", + " if layer.bias is not None and new_layer.bias_e is not None:\n", + " b = layer.bias.reshape([1, -1]).tile([k, 1]) # (K, O)\n", + " assert list(new_layer.bias_e.shape) == list(b.shape), \\\n", + " f\"bias shape mismatch: {new_layer.bias_e.shape} vs {b.shape}\"\n", + " new_layer.bias_e.set_value(b)\n", + " else: # 'packed'\n", + " new_layer = NLinear(k, in_f, out_f, bias=layer.bias is not None)\n", + " with paddle.no_grad():\n", + " # 每个 pack 共享同一权重初值: 原 (I,O) -> (K,I,O)\n", + " w = layer.weight.unsqueeze(0).tile([k, 1, 1]) # (K,I,O)\n", + " assert list(new_layer.weight.shape) == list(w.shape), \\\n", + " f\"packed weight shape mismatch: {new_layer.weight.shape} vs {w.shape}\"\n", + " new_layer.weight.set_value(w)\n", + " if layer.bias is not None and new_layer.bias_e is not None:\n", + " b = layer.bias.unsqueeze(0).tile([k, 1]) # (K,O)\n", + " assert list(new_layer.bias_e.shape) == list(b.shape), \\\n", + " f\"packed bias shape mismatch: {new_layer.bias_e.shape} vs {b.shape}\"\n", + " new_layer.bias_e.set_value(b)\n", + "\n", + " to_replace.append((parent, child_name, new_layer))\n", + "\n", + " # 正式替换\n", + " for parent, child_name, new_layer in to_replace:\n", + " if hasattr(parent, child_name):\n", + " setattr(parent, child_name, new_layer)\n", + " 
else:\n", + " sub_layers = getattr(parent, \"_sub_layers\", None)\n", + " if sub_layers is None or child_name not in sub_layers:\n", + " raise AttributeError(f\"Cannot set sublayer '{child_name}' under '{type(parent).__name__}'\")\n", + " parent._sub_layers[child_name] = new_layer\n", + "\n", + "# ---------- TabM 特征提取器 --------------------------------------------------\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " \"\"\"\n", + " arch_type: 'plain' | 'tabm' | 'tabm-mini' | 'tabm-packed'\n", + " 返回:\n", + " - reduce=True → (B,H)\n", + " - reduce=False → (B,K,H)\n", + " \"\"\"\n", + " def __init__(self,\n", + " num_features: int,\n", + " arch_type: Literal['plain', 'tabm', 'tabm-mini', 'tabm-packed']='tabm',\n", + " k: int = 32,\n", + " backbone_cfg: Optional[dict] = None,\n", + " reduce: bool = True):\n", + " super().__init__()\n", + " if arch_type == 'plain':\n", + " k = 1\n", + " self.k = k\n", + " self.reduce = reduce\n", + " cfg = backbone_cfg or dict(n_blocks=3, d_hidden=512, dropout=0.1)\n", + " self.backbone = BackboneMLP(**cfg, d_in=num_features)\n", + "\n", + " # --- 插入 Ensemble 逻辑 ---\n", + " if arch_type == 'tabm':\n", + " _replace_linear(self.backbone, k, mode='be')\n", + " self.min_adapter = None\n", + " elif arch_type == 'tabm-mini':\n", + " self.min_adapter = ScaleEnsemble(k, num_features, init='random-signs')\n", + " elif arch_type == 'tabm-packed':\n", + " _replace_linear(self.backbone, k, mode='packed')\n", + " self.min_adapter = None\n", + " else: # plain\n", + " self.min_adapter = None\n", + "\n", + " def forward(self, x_num: paddle.Tensor):\n", + " \"\"\"\n", + " x_num : (B, num_features) – 已完成数值化/标准化\n", + " \"\"\"\n", + " if self.k > 1:\n", + " x = x_num.unsqueeze(1).tile([1, self.k, 1]) # (B,K,D)\n", + " else:\n", + " x = x_num.unsqueeze(1) # (B,1,D)\n", + "\n", + " if self.min_adapter is not None:\n", + " x = self.min_adapter(x) # (B,K,D)\n", + "\n", + " features = self.backbone(x) # (B,K,H)\n", + " if self.reduce:\n", + " return 
features.mean(axis=1) # (B,H)\n", + " return features # (B,K,H)\n", + "\n", + "# ---------------- Quick check ----------------\n", + "if __name__ == '__main__':\n", + " paddle.seed(123)\n", + " B, D = 8, 30\n", + " x = paddle.randn([B, D])\n", + " # 1) 标准 TabM(BatchEnsemble 替换)\n", + " fe1 = TabMFeatureExtractor(D, arch_type='tabm', k=16, reduce=True)\n", + " out1 = fe1(x)\n", + " print('TabM-BE features:', list(out1.shape)) # (B, H)\n", + "\n", + " # 2) tabm-mini(只做 rank-1 缩放)\n", + " fe2 = TabMFeatureExtractor(D, arch_type='tabm-mini', k=16, reduce=False)\n", + " out2 = fe2(x)\n", + " print('TabM-mini features:', list(out2.shape)) # (B, K, H)\n", + "\n", + " # 3) tabm-packed(Packed NLinear)\n", + " fe3 = TabMFeatureExtractor(D, arch_type='tabm-packed', k=8, reduce=True)\n", + " out3 = fe3(x)\n", + " print('TabM-packed features:', list(out3.shape)) # (B, H)\n", + "\n", + " # 4) plain(无集成基线)\n", + " fe4 = TabMFeatureExtractor(D, arch_type='plain', k=1, reduce=True)\n", + " out4 = fe4(x)\n", + " print('Plain features:', list(out4.shape)) # (B, H)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TCF5hoiWF-N4", + "outputId": "df4ea66c-1716-4357-fa58-fb5f8f57d183" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "TabM-BE features: [8, 512]\n", + "TabM-mini features: [8, 16, 512]\n", + "TabM-packed features: [8, 512]\n", + "Plain features: [8, 512]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Literal, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "\n", + "\n", + "# ====== 数据集(示例:合成数据)========================================\n", + "class ToyMultiLabelDataset(Dataset):\n", + " \"\"\"\n", + " 返回:\n", + " x_num: float32, 形状 (D,)\n", + " y: float32, 
形状 (4,) —— 多标签 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, d: int, seed: int = 123):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.X = rng.normal(size=(n, d)).astype('float32')\n", + " # 随机生成 4 个线性规则 + 噪声,得到多标签\n", + " W = rng.normal(size=(d, 4))\n", + " logits = self.X @ W + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.Y = (probs > 0.5).astype('float32')\n", + "\n", + " def __getitem__(self, idx: int):\n", + " return self.X[idx], self.Y[idx]\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.X)\n", + "\n", + "# ====== 模型:特征抽取 + 多标签头 ======================================\n", + "class MultiLabelClassifier(nn.Layer):\n", + " def __init__(self, num_features: int, num_labels: int = 4,\n", + " arch_type: str = 'tabm', k: int = 16,\n", + " backbone_cfg: Optional[dict] = None):\n", + " super().__init__()\n", + " self.fe = TabMFeatureExtractor(\n", + " num_features=num_features,\n", + " arch_type=arch_type,\n", + " k=k,\n", + " backbone_cfg=backbone_cfg,\n", + " reduce=True\n", + " )\n", + " # 推断隐藏维度(若你的 TabM 有属性可读,直接使用;否则手动传入)\n", + " d_hidden = getattr(self.fe, \"d_hidden\", (backbone_cfg or dict(d_hidden=512))[\"d_hidden\"])\n", + " self.head = nn.Linear(d_hidden, num_labels)\n", + "\n", + " def forward(self, x_num: paddle.Tensor) -> paddle.Tensor:\n", + " # x_num: (B, D)\n", + " h = self.fe(x_num) # (B, H)\n", + " logits = self.head(h) # (B, 4)\n", + " return logits\n", + "\n", + "# ====== 评价指标:F1、AP 等 ==============================================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " \"\"\"\n", + " y_true: (N, C) 0/1\n", + " y_pred: (N, C) 0/1\n", + " 返回: per_class F1, macro-F1, micro-F1\n", + " \"\"\"\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + "\n", + " # per-class\n", + " for c in 
range(C):\n", + " yt = y_true[:, c]\n", + " yp = y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + "\n", + " macro_f1 = float(np.mean(f1_c))\n", + "\n", + " # micro\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " \"\"\"\n", + " 简易版 micro-AP(AUCPR):在 0~1 阈值上扫一遍,近似计算 PR 曲线下面积\n", + " \"\"\"\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " # 按 recall 升序进行梯形积分\n", + " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " auc_pr = np.trapz(precision, recall)\n", + " return float(auc_pr)\n", + "\n", + "# ====== 训练/验证循环 =====================================================\n", + "def train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None,\n", + " device: str = 'gpu' if paddle.is_compiled_with_cuda() else 'cpu'):\n", + " model.train()\n", + " total_loss = 0.0\n", + " total_batches = 0\n", + " for x, y in 
loader:\n", + " x = x.astype('float32')\n", + " y = y.astype('float32')\n", + " logits = model(x)\n", + " # BCE with logits(支持 pos_weight)\n", + " if pos_weight is not None:\n", + " loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=pos_weight)\n", + " else:\n", + " loss = F.binary_cross_entropy_with_logits(logits, y)\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss)\n", + " total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5):\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + " for x, y in loader:\n", + " x = x.astype('float32'); y = y.astype('float32')\n", + " logits = model(x) # (B,4)\n", + " loss = F.binary_cross_entropy_with_logits(logits, y)\n", + " prob = F.sigmoid(logits).numpy() # (B,4)\n", + " ys.append(y.numpy())\n", + " ps.append(prob)\n", + " total_loss += float(loss)\n", + " total_batches += 1\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + "\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " avg_loss = total_loss / max(1, total_batches)\n", + " metrics = {\n", + " \"loss\": avg_loss,\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + " return metrics\n", + "\n", + "# ====== 主函数:跑通一个最小示例 ===========================================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + " # 配置\n", + " D = 30 # 数值特征维度\n", + " C = 4 # 多标签数\n", + " N_train, N_val = 5000, 1000\n", + " batch_size = 128\n", + " epochs = 
5\n", + " lr = 3e-4\n", + "\n", + " # 数据\n", + " train_ds = ToyMultiLabelDataset(N_train, D, seed=42)\n", + " val_ds = ToyMultiLabelDataset(N_val, D, seed=233)\n", + " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)\n", + " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)\n", + "\n", + " # 类别不平衡(可选):按训练集估计每个标签的正例比例,构造 pos_weight\n", + " y_train = np.vstack([y for _, y in train_ds])\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3) # (4,)\n", + " # 经典做法:pos_weight = (N_neg / N_pos) = (1-p)/p\n", + " pos_weight_np = (1.0 - pos_ratio) / pos_ratio\n", + " pos_weight = paddle.to_tensor(pos_weight_np.astype('float32')) # (4,)\n", + "\n", + " # 模型\n", + " backbone_cfg = dict(n_blocks=3, d_hidden=512, dropout=0.1)\n", + " model = MultiLabelClassifier(num_features=D, num_labels=C,\n", + " arch_type='tabm', k=16,\n", + " backbone_cfg=backbone_cfg)\n", + "\n", + " optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())\n", + "\n", + " # 训练\n", + " best_macro_f1, best_state = -1.0, None\n", + " for ep in range(1, epochs + 1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " # 记录最佳\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best_state = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best_state is not None:\n", + " model.set_state_dict(best_state)\n", + " print(f\"Loaded best state with 
macro_f1={best_macro_f1:.4f}\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qOJBZjSyGpSA", + "outputId": "c23c05d7-0f63-406f-fbcc-66ea0c3065fc" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/tmp/ipython-input-1539235308.py:108: DeprecationWarning: `trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.\n", + " auc_pr = np.trapz(precision, recall)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[Epoch 01] train_loss=0.4427 | val_loss=1.5854 | macro_f1=0.4728 | micro_f1=0.4734 | per_class_f1=[0.5263158082962036, 0.5373737215995789, 0.4471057951450348, 0.38046795129776] | micro_AP=0.4584\n", + "[Epoch 02] train_loss=0.1538 | val_loss=2.9504 | macro_f1=0.4818 | micro_f1=0.4837 | per_class_f1=[0.553903341293335, 0.5406504273414612, 0.44742268323898315, 0.38532111048698425] | micro_AP=0.4769\n", + "[Epoch 03] train_loss=0.1057 | val_loss=3.7486 | macro_f1=0.4919 | micro_f1=0.4941 | per_class_f1=[0.5516605377197266, 0.568965494632721, 0.4606299102306366, 0.38624873757362366] | micro_AP=0.4722\n", + "[Epoch 04] train_loss=0.0917 | val_loss=4.2100 | macro_f1=0.4762 | micro_f1=0.4774 | per_class_f1=[0.5183752179145813, 0.557729959487915, 0.43551796674728394, 0.3932472765445709] | micro_AP=0.4652\n", + "[Epoch 05] train_loss=0.0732 | val_loss=4.7243 | macro_f1=0.4810 | micro_f1=0.4820 | per_class_f1=[0.5422138571739197, 0.5443425178527832, 0.4589178264141083, 0.3785425126552582] | micro_AP=0.4507\n", + "Loaded best state with macro_f1=0.4919\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Literal, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "from 
paddle.vision.models import resnet18\n", + "\n", + "# ====================== 工具:正弦位置编码 ======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 2048):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== 简化版 TabM(占位,可换成你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " \"\"\"占位实现:MLP → (B, H)。可直接替换为你修好的 TabM。\"\"\"\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + "\n", + " def forward(self, x_num: paddle.Tensor): # (B, 424)\n", + " return self.net(x_num) # (B, H)\n", + "\n", + "# ====================== ResNet18 特征抽取(逐帧) ======================\n", + "class ResNet18FrameEncoder(nn.Layer):\n", + " \"\"\"将 ResNet18 改为 20 通道输入;输出每帧 512 维特征。\"\"\"\n", + " def __init__(self, in_channels: int = 20):\n", + " super().__init__()\n", + " self.backbone = resnet18(pretrained=False)\n", + " # 改首层卷积为 20 通道\n", + " self.backbone.conv1 = nn.Conv2D(in_channels, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)\n", + " # 去掉分类头 fc,保留到 avgpool\n", + " self.avgpool = self.backbone.avgpool # AdaptiveAvgPool2D(1)\n", + " # 记录下游维度\n", + " self.out_dim = 
512\n", + "\n", + " def forward(self, x): # x: (B*T, C=20, H=20, W=20)\n", + " m = self.backbone\n", + " x = m.conv1(x); x = m.bn1(x); x = F.relu(x); x = m.maxpool(x)\n", + " x = m.layer1(x); x = m.layer2(x); x = m.layer3(x); x = m.layer4(x)\n", + " x = self.avgpool(x) # (B*T, 512, 1, 1)\n", + " x = paddle.flatten(x, 1) # (B*T, 512)\n", + " return x\n", + "\n", + "# ====================== 时序 Transformer 编码器 ======================\n", + "class TemporalTransformer(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.1, max_len=1024):\n", + " super().__init__()\n", + " enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,\n", + " dim_feedforward=dim_feedforward,\n", + " dropout=dropout, activation='relu')\n", + " self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " x = self.pos(x)\n", + " # Paddle 的 Transformer 期望 (T, B, D)\n", + " x = paddle.transpose(x, [1, 0, 2]) # (T,B,D)\n", + " z = self.encoder(x) # (T,B,D)\n", + " z = paddle.transpose(z, [1, 0, 2]) # (B,T,D)\n", + " return z\n", + "\n", + "# ====================== 多头注意力(支持 q from A, kv from B) ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + " self.d_head = d_model // nhead\n", + " self.Wq = nn.Linear(d_model, d_model)\n", + " self.Wk = nn.Linear(d_model, d_model)\n", + " self.Wv = nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.ln = nn.LayerNorm(d_model)\n", + "\n", + " def forward(self, q, kv):\n", + " \"\"\"\n", + " q: (B, Nq, D)\n", + " kv: (B, Nk, D)\n", + " return: (B, Nq, D) # 残差 + 
LN\n", + " \"\"\"\n", + " B, Nq, D = q.shape\n", + " Nk = kv.shape[1]\n", + "\n", + " q_lin = self.Wq(q) # (B,Nq,D)\n", + " k_lin = self.Wk(kv) # (B,Nk,D)\n", + " v_lin = self.Wv(kv) # (B,Nk,D)\n", + "\n", + " def split_heads(t): # (B,N,Heads,dh)\n", + " return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])\n", + "\n", + " qh = split_heads(q_lin) # (B,H,Nq,dh)\n", + " kh = split_heads(k_lin) # (B,H,Nk,dh)\n", + " vh = split_heads(v_lin) # (B,H,Nk,dh)\n", + "\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head) # (B,H,Nq,Nk)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, vh) # (B,H,Nq,dh)\n", + "\n", + " ctx = ctx.transpose([0, 2, 1, 3]).reshape([B, Nq, D]) # (B,Nq,D)\n", + " out = self.proj(ctx)\n", + " out = self.drop(out)\n", + " # 残差 + LN\n", + " return self.ln(out + q)\n", + "\n", + "# ====================== 融合头(双向 Cross-Attn) ======================\n", + "class BiModalCrossFusion(nn.Layer):\n", + " \"\"\"\n", + " 输入:\n", + " video_seq: (B, T, D) —— Transformer 后的视频序列\n", + " tabm_tok: (B, D) —— TabM token\n", + " 过程:\n", + " v_token = mean(video_seq)\n", + " v' = CrossAttn(q=v_token[1], kv=tabm_token[1])\n", + " t' = CrossAttn(q=tabm_token[1], kv=video_seq[T])\n", + " fuse = concat([v', t']) → MLP\n", + " \"\"\"\n", + " def __init__(self, d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(\n", + " nn.Linear(2 * d_model, fuse_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " )\n", + " self.out_dim = fuse_hidden\n", + "\n", + " def forward(self, video_seq, tabm_tok):\n", + " B, T, D = video_seq.shape\n", + " # 池化出视频 token\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + "\n", + " v_upd = 
self.ca_v_from_t(v_tok, t_tok) # (B,1,D)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq) # (B,1,D)\n", + "\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1) # (B,1,2D)\n", + " fused = fused.squeeze(1) # (B,2D)\n", + " return self.fuse(fused) # (B, F)\n", + "\n", + "# ====================== 总模型 ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=8, n_trans_layers=4, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4):\n", + " super().__init__()\n", + " # A: 逐帧 ResNet18\n", + " self.frame_encoder = ResNet18FrameEncoder(in_channels=vid_channels) # (B*T,512)\n", + " # A: 时序 Transformer\n", + " self.temporal = TemporalTransformer(d_model=d_model,\n", + " nhead=nhead,\n", + " num_layers=n_trans_layers,\n", + " dim_feedforward=trans_ff,\n", + " dropout=dropout,\n", + " max_len=vid_frames)\n", + " # B: TabM(或替换为你的 TabM)\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model) # 对齐到 d_model\n", + " # 融合:双向 Cross-Attention\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " def forward(self, x_video, x_vec):\n", + " \"\"\"\n", + " x_video: (B, T, C=20, H=20, W=20)\n", + " x_vec: (B, 424)\n", + " \"\"\"\n", + " B, T, C, H, W = x_video.shape\n", + " # ---- A: 逐帧 ResNet ----\n", + " xvt = x_video.reshape([B * T, C, H, W]) # (B*T, C, H, W)\n", + " f_frame = self.frame_encoder(xvt) # (B*T, 512)\n", + " f_seq = f_frame.reshape([B, T, -1]) # (B, T, 512)\n", + " # ---- A: 时序 Transformer ----\n", + " z_vid = self.temporal(f_seq) # (B, T, 512)\n", + " # ---- B: TabM 特征 ----\n", + " z_tabm = self.tabm(x_vec) # (B, H_tabm)\n", + " 
z_tabm = self.tabm_proj(z_tabm) # (B, 512)\n", + " # ---- Cross-Attention 融合 ----\n", + " fused = self.fusion(z_vid, z_tabm) # (B, 512)\n", + " # ---- 分类 ----\n", + " logits = self.head(fused) # (B, 4)\n", + " return logits\n", + "\n", + "# ====================== 指标与训练循环(与前一致) ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " return float(np.trapz(precision, recall))\n", + "\n", + "def 
train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None):\n", + " model.train()\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " if pos_weight is not None:\n", + " loss = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " else:\n", + " loss = F.binary_cross_entropy_with_logits(logits, y.astype('float32'))\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss); total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5):\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " loss = F.binary_cross_entropy_with_logits(logits, y.astype('float32'))\n", + " prob = F.sigmoid(logits).numpy()\n", + " ys.append(y.numpy()); ps.append(prob)\n", + " total_loss += float(loss); total_batches += 1\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\n", + " \"loss\": total_loss / max(1, total_batches),\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + "\n", + "# ====================== 合成数据集(可替换为真实数据) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " \"\"\"\n", + " 
返回:\n", + " x_video: (T=365, C=20, H=20, W=20)\n", + " x_vec: (424,)\n", + " y: (4,) 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, seed: int = 0):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n\n", + " # 按 (n, T, C, H, W)\n", + " self.video = rng.normal(size=(n, 36, 20, 20, 20)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + "\n", + " # 造标签:对视频先在 H/W 上均值,再在 T 上均值 → (n, C=20)\n", + " vid_hw = self.video.mean(axis=(3, 4)) # (n, T, C)\n", + " vid_avg = vid_hw.mean(axis=1) # (n, C)\n", + "\n", + " # 线性映射到 4 个标签\n", + " Wv = rng.normal(size=(20, 4)) # C→4\n", + " Wt = rng.normal(size=(424, 4)) # 424→4\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + "\n", + " def __getitem__(self, idx: int):\n", + " x_vid = self.video[idx] # (T,C,H,W)\n", + " x_vec = self.vec[idx] # (424,)\n", + " y = self.y[idx] # (4,)\n", + " return x_vid, x_vec, y\n", + "\n", + " def __len__(self):\n", + " return self.n\n", + "\n", + "# ====================== 训练入口(可直接运行) ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + " # 数据\n", + " train_ds = ToyTwoModalDataset(n=64, seed=42) # 注意:真实训练建议更大数据与多卡\n", + " val_ds = ToyTwoModalDataset(n=32, seed=233)\n", + " # 自定义 collate:让视频变成 (B,T,C,H,W)\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)), # (B,T,C,H,W)\n", + " paddle.to_tensor(np.stack(vecs, 0)), # (B,424)\n", + " paddle.to_tensor(np.stack(ys, 0))) # (B,4)\n", + " train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " 
pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32')) # (4,)\n", + "\n", + " # 模型\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " vec_dim=424,\n", + " d_model=512, nhead=2, n_trans_layers=2, trans_ff=1024, # 可调\n", + " tabm_hidden=512, dropout=0.1,\n", + " num_labels=4\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用:小 epoch)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 3+1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best is not None:\n", + " model.set_state_dict(best)\n", + " print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "X7-O1-LZLHB2", + "outputId": "124d09eb-3e79-4622-f0b3-b2bf7eb11d60" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/tmp/ipython-input-3157620546.py:248: DeprecationWarning: `trapz` is deprecated. 
Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.\n", + " return float(np.trapz(precision, recall))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[Epoch 01] train_loss=1.1304 | val_loss=1.0465 | macro_f1=0.3765 | micro_f1=0.5079 | per_class_f1=[0.5, 0.7450980544090271, 0.260869562625885, 0.0] | micro_AP=0.5393\n", + "[Epoch 02] train_loss=0.7661 | val_loss=0.9369 | macro_f1=0.5812 | micro_f1=0.6173 | per_class_f1=[0.5454545617103577, 0.7599999904632568, 0.6382978558540344, 0.380952388048172] | micro_AP=0.4165\n", + "[Epoch 03] train_loss=0.4050 | val_loss=1.9395 | macro_f1=0.6107 | micro_f1=0.6303 | per_class_f1=[0.5454545617103577, 0.7234042286872864, 0.6938775777816772, 0.47999998927116394] | micro_AP=0.3836\n", + "Loaded best state with macro_f1=0.6107\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "from paddle.vision.models import resnet18\n", + "\n", + "# ====================== 工具:正弦位置编码 ======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 2048):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== 简化版 
TabM(占位,可换成你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " \"\"\"占位实现:MLP → (B, H)。可直接替换为你修好的 TabM。\"\"\"\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + "\n", + " def forward(self, x_num: paddle.Tensor): # (B, 424)\n", + " return self.net(x_num) # (B, H)\n", + "\n", + "# ====================== ResNet18 特征抽取(逐帧) ======================\n", + "class ResNet18FrameEncoder(nn.Layer):\n", + " \"\"\"将 ResNet18 改为 20 通道输入;输出每帧 512 维特征。\"\"\"\n", + " def __init__(self, in_channels: int = 20):\n", + " super().__init__()\n", + " self.backbone = resnet18(pretrained=False)\n", + " # 改首层卷积为 20 通道\n", + " self.backbone.conv1 = nn.Conv2D(in_channels, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)\n", + " # 去掉分类头 fc,保留到 avgpool\n", + " self.avgpool = self.backbone.avgpool # AdaptiveAvgPool2D(1)\n", + " self.out_dim = 512\n", + "\n", + " def forward(self, x): # x: (B*T, C=20, H=20, W=20)\n", + " m = self.backbone\n", + " x = m.conv1(x); x = m.bn1(x); x = F.relu(x); x = m.maxpool(x)\n", + " x = m.layer1(x); x = m.layer2(x); x = m.layer3(x); x = m.layer4(x)\n", + " x = self.avgpool(x) # (B*T, 512, 1, 1)\n", + " x = paddle.flatten(x, 1) # (B*T, 512)\n", + " return x\n", + "\n", + "# ====================== MoE 基础实现(Top-k,可开关;使用 gather_nd 修复) ======================\n", + "class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1, act='relu'):\n", + " super().__init__()\n", + " Act = getattr(F, act) if isinstance(act, str) else act\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.act = Act\n", + " def forward(self, x):\n", + " return 
self.fc2(self.drop(self.act(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self,\n", + " n_experts=8,\n", + " top_k=1,\n", + " d_ff=2048,\n", + " dropout=0.1,\n", + " router_temp=0.5,\n", + " balance_loss_w=0.005,\n", + " entropy_reg_w=-0.005, # 负值→更尖锐\n", + " diversity_w=1e-3,\n", + " sticky_w=0.0,\n", + " sup_router_w=0.0,\n", + " use_gumbel=True):\n", + " self.n_experts = n_experts\n", + " self.top_k = top_k\n", + " self.d_ff = d_ff\n", + " self.dropout = dropout\n", + " self.router_temp = router_temp\n", + " self.balance_loss_w = balance_loss_w\n", + " self.entropy_reg_w = entropy_reg_w\n", + " self.diversity_w = diversity_w\n", + " self.sticky_w = sticky_w\n", + " self.sup_router_w = sup_router_w\n", + " self.use_gumbel = use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " \"\"\"forward(x, domain_id=None) → (y, aux_loss),支持 (B,T,D) 或 (N,D)\"\"\"\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(cfg.dropout)\n", + "\n", + " def _router_probs(self, logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u))\n", + " logits = logits + g\n", + " return F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + "\n", + " def forward(self, x, domain_id=None):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3:\n", + " B, T, D = orig_shape\n", + " X = x.reshape([B*T, D])\n", + " else:\n", + " X = x\n", + " N, D = X.shape\n", + "\n", + " logits = self.router(X) # (N,E)\n", + " probs = self._router_probs(logits) # (N,E)\n", + " topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1) # (N,k)\n", + "\n", + " # 专家并行输出\n", 
+ " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + "\n", + " # === 使用 gather_nd 逐样本选择 top-k 专家 ===\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list = []\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64') # (N,)\n", + " idx_nd = paddle.stack([arangeN, idx_i], axis=1) # (N,2) [sample, expert]\n", + " picked_i = paddle.gather_nd(all_out, idx_nd) # (N,D)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + "\n", + " # 归一化权重并加权\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (N,k)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1) # (N,D)\n", + "\n", + " Y = self.drop(Y)\n", + " Y = self.ln(Y + X)\n", + "\n", + " # aux loss\n", + " aux = 0.0\n", + " if self.cfg.balance_loss_w > 0:\n", + " mean_prob = probs.mean(axis=0)\n", + " target = paddle.full_like(mean_prob, 1.0 / self.cfg.n_experts)\n", + " aux = aux + self.cfg.balance_loss_w * F.mse_loss(mean_prob, target)\n", + " if self.cfg.entropy_reg_w != 0.0:\n", + " ent = -paddle.sum(probs * (paddle.log(probs + 1e-9)), axis=1).mean()\n", + " aux = aux + self.cfg.entropy_reg_w * ent\n", + " if (domain_id is not None) and (self.cfg.sup_router_w > 0):\n", + " dom = domain_id.reshape([-1])[:N] % self.cfg.n_experts\n", + " aux = aux + self.cfg.sup_router_w * F.cross_entropy(logits, dom)\n", + " if self.cfg.diversity_w > 0 and self.cfg.n_experts > 1:\n", + " # 用 top-1 硬选择近似每个专家接收的样本\n", + " chosen = F.one_hot(topk_idx[:, 0], num_classes=self.cfg.n_experts).astype('float32') # (N,E)\n", + " denom = chosen.sum(axis=0).clip(min=1.0).unsqueeze(-1)\n", + " means = (all_out * chosen.unsqueeze(-1)).sum(axis=0) / denom # (E,D)\n", + " sims = []\n", + " for i in range(self.cfg.n_experts):\n", + " for j in range(i+1, self.cfg.n_experts):\n", + " si = F.normalize(means[i:i+1], axis=-1)\n", + " sj = F.normalize(means[j:j+1], axis=-1)\n", + " 
sims.append((si*sj).sum())\n", + " if sims:\n", + " aux = aux + self.cfg.diversity_w * paddle.stack(sims).mean()\n", + "\n", + " if len(orig_shape) == 3:\n", + " Y = Y.reshape([B, T, D])\n", + " return Y, aux\n", + "\n", + "class MoEHead(nn.Layer):\n", + " \"\"\"单 token MoE 头,用于 fused/tabm 投影后的 (B, D)\"\"\"\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " def forward(self, tok, domain_id=None):\n", + " y, aux = self.moe(tok.unsqueeze(1), domain_id=domain_id) # (B,1,D)\n", + " return y.squeeze(1), aux\n", + "\n", + "# ====================== 自定义 Transformer Encoder(FFN 可替换为 MoE) ======================\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.use_moe = use_moe\n", + " self.self_attn = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.ln1 = nn.LayerNorm(d_model)\n", + " self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model),\n", + " )\n", + " self.do2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " # Self-Attention (pre-norm) —— Paddle MHA 期望 (T,B,D)\n", + " h = self.ln1(x)\n", + " h = paddle.transpose(h, [1, 0, 2]) # (T,B,D)\n", + " sa = self.self_attn(h, h, h) # (T,B,D)\n", + " sa = paddle.transpose(sa, [1, 0, 2]) # (B,T,D)\n", + " x = x + self.do1(sa)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id) # 残差+LN 在 MoE 内部\n", + " else:\n", + " x = x + self.do2(self.ffn(x)) # 残差在这里\n", + " return x, 
aux\n", + "\n", + "class TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=2, d_ff=1024, dropout=0.1,\n", + " max_len=1024, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + " self.layers = nn.LayerList([\n", + " TransformerEncoderLayerMoE(d_model, nhead, d_ff, dropout,\n", + " use_moe=use_moe, moe_cfg=moe_cfg)\n", + " for _ in range(num_layers)\n", + " ])\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " x = self.pos(x)\n", + " aux_total = 0.0\n", + " for layer in self.layers:\n", + " x, aux = layer(x, domain_id=domain_id)\n", + " aux_total = aux_total + aux\n", + " return x, aux_total\n", + "\n", + "# ====================== 多头注意力(支持 q from A, kv from B) ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + " self.d_head = d_model // nhead\n", + " self.Wq = nn.Linear(d_model, d_model)\n", + " self.Wk = nn.Linear(d_model, d_model)\n", + " self.Wv = nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.ln = nn.LayerNorm(d_model)\n", + "\n", + " def forward(self, q, kv):\n", + " B, Nq, D = q.shape\n", + " Nk = kv.shape[1]\n", + " q_lin = self.Wq(q); k_lin = self.Wk(kv); v_lin = self.Wv(kv)\n", + " def split_heads(t):\n", + " return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])\n", + " qh = split_heads(q_lin); kh = split_heads(k_lin); vh = split_heads(v_lin)\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, vh)\n", + " ctx = ctx.transpose([0, 2, 1, 3]).reshape([B, Nq, D])\n", + 
" out = self.proj(ctx)\n", + " out = self.drop(out)\n", + " return self.ln(out + q)\n", + "\n", + "# ====================== 融合头(双向 Cross-Attn) ======================\n", + "class BiModalCrossFusion(nn.Layer):\n", + " \"\"\"\n", + " 输入:\n", + " video_seq: (B, T, D) —— Transformer 后的视频序列\n", + " tabm_tok: (B, D) —— TabM token\n", + " \"\"\"\n", + " def __init__(self, d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(\n", + " nn.Linear(2 * d_model, fuse_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " )\n", + " self.out_dim = fuse_hidden\n", + "\n", + " def forward(self, video_seq, tabm_tok):\n", + " B, T, D = video_seq.shape\n", + " # 池化视频时间维得到 token\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok) # (B,1,D)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq) # (B,1,D)\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1) # (B,1,2D)\n", + " fused = fused.squeeze(1) # (B,2D)\n", + " return self.fuse(fused) # (B, F)\n", + "\n", + "# ====================== 总模型(带三个 MoE 开关) ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=2, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " # ===== MoE 开关 =====\n", + " moe_temporal: bool = True, # 时序 Transformer 的 FFN 位置\n", + " moe_fused: bool = False, # 融合 token 上的小型 MoE 头\n", + " moe_tabm: bool = False, # TabM 投影后\n", + " # ===== MoE 超参(可传入自定义) =====\n", + " moe_cfg_temporal: MoEConfig = None,\n", + " moe_cfg_fused: MoEConfig = None,\n", + " moe_cfg_tabm: 
MoEConfig = None):\n", + " super().__init__()\n", + " # A: 逐帧 ResNet18\n", + " self.frame_encoder = ResNet18FrameEncoder(in_channels=vid_channels)\n", + " # A: 时序 Transformer(可开/关 MoE)\n", + " self.temporal = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, num_layers=n_trans_layers,\n", + " d_ff=trans_ff, dropout=dropout, max_len=vid_frames,\n", + " use_moe=moe_temporal,\n", + " moe_cfg=moe_cfg_temporal or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(trans_ff, 2048), router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # B: TabM(或你的 TabM)\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model)\n", + "\n", + " # 可选:TabM 分支 MoE 头\n", + " self.moe_tabm = moe_tabm\n", + " if moe_tabm:\n", + " self.tabm_moe = MoEHead(d_model=d_model, cfg=moe_cfg_tabm or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 融合:双向 Cross-Attention\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + "\n", + " # 可选:融合 token MoE 头\n", + " self.moe_fused = moe_fused\n", + " if moe_fused:\n", + " self.fused_moe = MoEHead(d_model=d_model, cfg=moe_cfg_fused or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " def forward(self, x_video, x_vec, domain_id=None):\n", + " \"\"\"\n", + " x_video: (B, T, C=20, H=20, W=20)\n", + " x_vec: (B, 424)\n", + " domain_id: (B,) 或 None —— 若有域/季节/站点标签,可传入以做监督路由(可选)\n", + " \"\"\"\n", + " B, T, C, H, W = x_video.shape\n", + " # ---- A: 逐帧 ResNet ----\n", + " xvt = x_video.reshape([B * T, C, H, W]) # (B*T, C, H, W)\n", + " 
f_frame = self.frame_encoder(xvt) # (B*T, 512)\n", + " f_seq = f_frame.reshape([B, T, -1]) # (B, T, 512)\n", + " # ---- A: 时序 Transformer (可含 MoE) ----\n", + " z_vid, aux_total = self.temporal(f_seq, domain_id=domain_id) # (B,T,512), aux\n", + "\n", + " # ---- B: TabM 特征 ----\n", + " z_tabm = self.tabm(x_vec) # (B, H_tabm)\n", + " z_tabm = self.tabm_proj(z_tabm) # (B, 512)\n", + " if self.moe_tabm:\n", + " z_tabm, aux_t = self.tabm_moe(z_tabm, domain_id=domain_id) # (B,512)\n", + " aux_total = aux_total + aux_t\n", + "\n", + " # ---- Cross-Attention 融合 ----\n", + " fused = self.fusion(z_vid, z_tabm) # (B, 512)\n", + "\n", + " # ---- 融合 MoE 头(可选) ----\n", + " if self.moe_fused:\n", + " fused, aux_f = self.fused_moe(fused, domain_id=domain_id) # (B,512)\n", + " aux_total = aux_total + aux_f\n", + "\n", + " # ---- 分类 ----\n", + " logits = self.head(fused) # (B, 4)\n", + " return logits, aux_total\n", + "\n", + "# ====================== 指标与训练循环(兼容 aux_loss) ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: 
np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " return float(np.trapz(precision, recall))\n", + "\n", + "LAMBDA_MOE = 0.01 # MoE 辅助损失系数\n", + "\n", + "def train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None):\n", + " model.train()\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, aux = model(x_vid.astype('float32'), x_vec.astype('float32')) # ← 接收 aux\n", + " if pos_weight is not None:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " else:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'))\n", + " loss = cls + LAMBDA_MOE * aux\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss); total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5):\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, aux = model(x_vid.astype('float32'), x_vec.astype('float32')) # ← 接收 aux\n", + " cls = F.binary_cross_entropy_with_logits(logits, 
y.astype('float32'))\n", + " loss = cls + LAMBDA_MOE * aux\n", + " prob = F.sigmoid(logits).numpy()\n", + " ys.append(y.numpy()); ps.append(prob)\n", + " total_loss += float(loss); total_batches += 1\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\n", + " \"loss\": total_loss / max(1, total_batches),\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + "\n", + "# ====================== 合成数据集(可替换为真实数据) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " \"\"\"\n", + " 返回:\n", + " x_video: (T=36, C=20, H=20, W=20)\n", + " x_vec: (424,)\n", + " y: (4,) 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, seed: int = 0):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n\n", + " # 按 (n, T, C, H, W)\n", + " self.video = rng.normal(size=(n, 36, 20, 20, 20)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + "\n", + " # 造标签:对视频先在 H/W 上均值,再在 T 上均值 → (n, C=20)\n", + " vid_hw = self.video.mean(axis=(3, 4)) # (n, T, C)\n", + " vid_avg = vid_hw.mean(axis=1) # (n, C)\n", + "\n", + " # 线性映射到 4 个标签\n", + " Wv = rng.normal(size=(20, 4)) # C→4\n", + " Wt = rng.normal(size=(424, 4)) # 424→4\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + "\n", + " def __getitem__(self, idx: int):\n", + " x_vid = self.video[idx] # (T,C,H,W)\n", + " x_vec = self.vec[idx] # (424,)\n", + " y = self.y[idx] # (4,)\n", + " return x_vid, x_vec, y\n", + "\n", + " def __len__(self):\n", + " return self.n\n", + "\n", + "# ====================== 
训练入口(可直接运行) ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + " # 数据\n", + " train_ds = ToyTwoModalDataset(n=64, seed=42)\n", + " val_ds = ToyTwoModalDataset(n=32, seed=233)\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)), # (B,T,C,H,W)\n", + " paddle.to_tensor(np.stack(vecs, 0)), # (B,424)\n", + " paddle.to_tensor(np.stack(ys, 0))) # (B,4)\n", + " train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32')) # (4,)\n", + "\n", + " # === 构建模型:三处 MoE 开关(默认只开时序 MoE) ===\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " vec_dim=424,\n", + " d_model=512, nhead=2, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1,\n", + " num_labels=4,\n", + " moe_temporal=True, # 开:时序 Transformer 的 FFN 位置\n", + " moe_fused=False, # 关:融合 token MoE 头\n", + " moe_tabm=False # 关:TabM 投影后 MoE\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用:小 epoch)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 3+1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " 
f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best is not None:\n", + " model.set_state_dict(best)\n", + " print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K5On1aO5QtFw", + "outputId": "aeecef88-dd57-4133-8156-f3dd9b6e0c90" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/tmp/ipython-input-3896222078.py:427: DeprecationWarning: `trapz` is deprecated. Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.\n", + " return float(np.trapz(precision, recall))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[Epoch 01] train_loss=1.0045 | val_loss=0.8427 | macro_f1=0.1364 | micro_f1=0.2609 | per_class_f1=[0.0, 0.0, 0.0, 0.5454545617103577] | micro_AP=0.4593\n", + "[Epoch 02] train_loss=0.6558 | val_loss=1.0130 | macro_f1=0.4563 | micro_f1=0.5342 | per_class_f1=[0.5945945978164673, 0.0, 0.6938775777816772, 0.5365853905677795] | micro_AP=0.5376\n", + "[Epoch 03] train_loss=0.2265 | val_loss=1.3452 | macro_f1=0.4934 | micro_f1=0.5625 | per_class_f1=[0.4000000059604645, 0.7234042286872864, 0.6000000238418579, 0.25] | micro_AP=0.5076\n", + "Loaded best state with macro_f1=0.4934\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "from paddle.vision.models import resnet18\n", + "\n", + "# ====================== 工具:正弦位置编码 
======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 2048):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== 简化版 TabM(占位,可换成你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " \"\"\"占位实现:MLP → (B, H)。可直接替换为你修好的 TabM。\"\"\"\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + "\n", + " def forward(self, x_num: paddle.Tensor): # (B, 424)\n", + " return self.net(x_num) # (B, H)\n", + "\n", + "# ====================== ResNet18 特征抽取(逐帧) ======================\n", + "class ResNet18FrameEncoder(nn.Layer):\n", + " \"\"\"将 ResNet18 改为 20 通道输入;输出每帧 512 维特征。\"\"\"\n", + " def __init__(self, in_channels: int = 20):\n", + " super().__init__()\n", + " self.backbone = resnet18(pretrained=False)\n", + " # 改首层卷积为 20 通道\n", + " self.backbone.conv1 = nn.Conv2D(in_channels, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)\n", + " # 去掉分类头 fc,保留到 avgpool\n", + " self.avgpool = self.backbone.avgpool # AdaptiveAvgPool2D(1)\n", + " self.out_dim = 512\n", + "\n", + " def forward(self, x): # x: (B*T, C=20, H=20, W=20)\n", + " m = self.backbone\n", + " x = 
m.conv1(x); x = m.bn1(x); x = F.relu(x); x = m.maxpool(x)\n", + " x = m.layer1(x); x = m.layer2(x); x = m.layer3(x); x = m.layer4(x)\n", + " x = self.avgpool(x) # (B*T, 512, 1, 1)\n", + " x = paddle.flatten(x, 1) # (B*T, 512)\n", + " return x\n", + "\n", + "# ====================== MoE 基础实现(Top-k,可开关;使用 gather_nd 修复) ======================\n", + "class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1, act='relu'):\n", + " super().__init__()\n", + " Act = getattr(F, act) if isinstance(act, str) else act\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.act = Act\n", + " def forward(self, x):\n", + " return self.fc2(self.drop(self.act(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self,\n", + " n_experts=8,\n", + " top_k=1,\n", + " d_ff=2048,\n", + " dropout=0.1,\n", + " router_temp=0.5,\n", + " balance_loss_w=0.005,\n", + " entropy_reg_w=-0.005, # 负值→更尖锐\n", + " diversity_w=1e-3,\n", + " sticky_w=0.0,\n", + " sup_router_w=0.0,\n", + " use_gumbel=True):\n", + " self.n_experts = n_experts\n", + " self.top_k = top_k\n", + " self.d_ff = d_ff\n", + " self.dropout = dropout\n", + " self.router_temp = router_temp\n", + " self.balance_loss_w = balance_loss_w\n", + " self.entropy_reg_w = entropy_reg_w\n", + " self.diversity_w = diversity_w\n", + " self.sticky_w = sticky_w\n", + " self.sup_router_w = sup_router_w\n", + " self.use_gumbel = use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " \"\"\"forward(x, domain_id=None) → (y, aux_loss),支持 (B,T,D) 或 (N,D)\"\"\"\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(cfg.dropout)\n", + "\n", + " def 
_router_probs(self, logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u))\n", + " logits = logits + g\n", + " return F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + "\n", + " def forward(self, x, domain_id=None):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3:\n", + " B, T, D = orig_shape\n", + " X = x.reshape([B*T, D])\n", + " else:\n", + " X = x\n", + " N, D = X.shape\n", + "\n", + " logits = self.router(X) # (N,E)\n", + " probs = self._router_probs(logits) # (N,E)\n", + " topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1) # (N,k)\n", + "\n", + " # 专家并行输出\n", + " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + "\n", + " # === 使用 gather_nd 逐样本选择 top-k 专家 ===\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list = []\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64') # (N,)\n", + " idx_nd = paddle.stack([arangeN, idx_i], axis=1) # (N,2) [sample, expert]\n", + " picked_i = paddle.gather_nd(all_out, idx_nd) # (N,D)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + "\n", + " # 归一化权重并加权\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (N,k)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1) # (N,D)\n", + "\n", + " Y = self.drop(Y)\n", + " Y = self.ln(Y + X)\n", + "\n", + " # aux loss\n", + " aux = 0.0\n", + " if self.cfg.balance_loss_w > 0:\n", + " mean_prob = probs.mean(axis=0)\n", + " target = paddle.full_like(mean_prob, 1.0 / self.cfg.n_experts)\n", + " aux = aux + self.cfg.balance_loss_w * F.mse_loss(mean_prob, target)\n", + " if self.cfg.entropy_reg_w != 0.0:\n", + " ent = -paddle.sum(probs * (paddle.log(probs + 1e-9)), axis=1).mean()\n", + " aux = aux + self.cfg.entropy_reg_w * ent\n", + " if (domain_id is not None) and 
(self.cfg.sup_router_w > 0):\n", + " dom = domain_id.reshape([-1])[:N] % self.cfg.n_experts\n", + " aux = aux + self.cfg.sup_router_w * F.cross_entropy(logits, dom)\n", + " if self.cfg.diversity_w > 0 and self.cfg.n_experts > 1:\n", + " # 用 top-1 硬选择近似每个专家接收的样本\n", + " chosen = F.one_hot(topk_idx[:, 0], num_classes=self.cfg.n_experts).astype('float32') # (N,E)\n", + " denom = chosen.sum(axis=0).clip(min=1.0).unsqueeze(-1)\n", + " means = (all_out * chosen.unsqueeze(-1)).sum(axis=0) / denom # (E,D)\n", + " sims = []\n", + " for i in range(self.cfg.n_experts):\n", + " for j in range(i+1, self.cfg.n_experts):\n", + " si = F.normalize(means[i:i+1], axis=-1)\n", + " sj = F.normalize(means[j:j+1], axis=-1)\n", + " sims.append((si*sj).sum())\n", + " if sims:\n", + " aux = aux + self.cfg.diversity_w * paddle.stack(sims).mean()\n", + "\n", + " if len(orig_shape) == 3:\n", + " Y = Y.reshape([B, T, D])\n", + " return Y, aux\n", + "\n", + "class MoEHead(nn.Layer):\n", + " \"\"\"单 token MoE 头,用于 fused/tabm 投影后的 (B, D)\"\"\"\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " def forward(self, tok, domain_id=None):\n", + " y, aux = self.moe(tok.unsqueeze(1), domain_id=domain_id) # (B,1,D)\n", + " return y.squeeze(1), aux\n", + "\n", + "# ====================== 自定义 Transformer Encoder(FFN 可替换为 MoE) ======================\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.use_moe = use_moe\n", + " self.self_attn = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.ln1 = nn.LayerNorm(d_model)\n", + " self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = 
nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model),\n", + " )\n", + " self.do2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " # Self-Attention (pre-norm) —— Paddle MHA 期望 (T,B,D)\n", + " h = self.ln1(x)\n", + " h = paddle.transpose(h, [1, 0, 2]) # (T,B,D)\n", + " sa = self.self_attn(h, h, h) # (T,B,D)\n", + " sa = paddle.transpose(sa, [1, 0, 2]) # (B,T,D)\n", + " x = x + self.do1(sa)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id) # 残差+LN 在 MoE 内部\n", + " else:\n", + " x = x + self.do2(self.ffn(x)) # 残差在这里\n", + " return x, aux\n", + "\n", + "class TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=2, d_ff=1024, dropout=0.1,\n", + " max_len=1024, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + " self.layers = nn.LayerList([\n", + " TransformerEncoderLayerMoE(d_model, nhead, d_ff, dropout,\n", + " use_moe=use_moe, moe_cfg=moe_cfg)\n", + " for _ in range(num_layers)\n", + " ])\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " x = self.pos(x)\n", + " aux_total = 0.0\n", + " for layer in self.layers:\n", + " x, aux = layer(x, domain_id=domain_id)\n", + " aux_total = aux_total + aux\n", + " return x, aux_total\n", + "\n", + "# ====================== 多头注意力(支持 q from A, kv from B) ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + " self.d_head = d_model // nhead\n", + " self.Wq = nn.Linear(d_model, d_model)\n", + " self.Wk = nn.Linear(d_model, d_model)\n", + " self.Wv = 
nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.ln = nn.LayerNorm(d_model)\n", + "\n", + " def forward(self, q, kv):\n", + " B, Nq, D = q.shape\n", + " Nk = kv.shape[1]\n", + " q_lin = self.Wq(q); k_lin = self.Wk(kv); v_lin = self.Wv(kv)\n", + " def split_heads(t):\n", + " return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])\n", + " qh = split_heads(q_lin); kh = split_heads(k_lin); vh = split_heads(v_lin)\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, vh)\n", + " ctx = ctx.transpose([0, 2, 1, 3]).reshape([B, Nq, D])\n", + " out = self.proj(ctx)\n", + " out = self.drop(out)\n", + " return self.ln(out + q)\n", + "\n", + "# ====================== 融合头(双向 Cross-Attn) ======================\n", + "class BiModalCrossFusion(nn.Layer):\n", + " \"\"\"\n", + " 输入:\n", + " video_seq: (B, T, D) —— Transformer 后的视频序列\n", + " tabm_tok: (B, D) —— TabM token\n", + " \"\"\"\n", + " def __init__(self, d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(\n", + " nn.Linear(2 * d_model, fuse_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " )\n", + " self.out_dim = fuse_hidden\n", + "\n", + " def forward(self, video_seq, tabm_tok):\n", + " B, T, D = video_seq.shape\n", + " # 池化视频时间维得到 token\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok) # (B,1,D)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq) # (B,1,D)\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1) # (B,1,2D)\n", + " fused = fused.squeeze(1) # (B,2D)\n", + " return self.fuse(fused) 
# (B, F)\n", + "\n", + "# ====================== 总模型(带三个 MoE 开关) ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=2, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " # ===== MoE 开关 =====\n", + " moe_temporal: bool = True, # 时序 Transformer 的 FFN 位置\n", + " moe_fused: bool = False, # 融合 token 上的小型 MoE 头\n", + " moe_tabm: bool = False, # TabM 投影后\n", + " # ===== MoE 超参(可传入自定义) =====\n", + " moe_cfg_temporal: MoEConfig = None,\n", + " moe_cfg_fused: MoEConfig = None,\n", + " moe_cfg_tabm: MoEConfig = None):\n", + " super().__init__()\n", + " # A: 逐帧 ResNet18\n", + " self.frame_encoder = ResNet18FrameEncoder(in_channels=vid_channels)\n", + " # A: 时序 Transformer(可开/关 MoE)\n", + " self.temporal = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, num_layers=n_trans_layers,\n", + " d_ff=trans_ff, dropout=dropout, max_len=vid_frames,\n", + " use_moe=moe_temporal,\n", + " moe_cfg=moe_cfg_temporal or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(trans_ff, 2048), router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # B: TabM(或你的 TabM)\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model)\n", + "\n", + " # 可选:TabM 分支 MoE 头\n", + " self.moe_tabm = moe_tabm\n", + " if moe_tabm:\n", + " self.tabm_moe = MoEHead(d_model=d_model, cfg=moe_cfg_tabm or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 融合:双向 Cross-Attention\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + "\n", + " # 可选:融合 token 
MoE 头\n", + " self.moe_fused = moe_fused\n", + " if moe_fused:\n", + " self.fused_moe = MoEHead(d_model=d_model, cfg=moe_cfg_fused or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " # --- 新增:编码函数,导出融合前的 512 维特征(用于检索库) ---\n", + " def encode(self, x_video, x_vec, domain_id=None):\n", + " \"\"\"返回融合后的 512 维 token(分类头之前的表示),不经过最终 Linear。\"\"\"\n", + " B, T, C, H, W = x_video.shape\n", + " xvt = x_video.reshape([B * T, C, H, W])\n", + " f_frame = self.frame_encoder(xvt) # (B*T, 512)\n", + " f_seq = f_frame.reshape([B, T, -1]) # (B, T, 512)\n", + " z_vid, _ = self.temporal(f_seq, domain_id=domain_id) # (B,T,512)\n", + " z_tabm = self.tabm(x_vec)\n", + " z_tabm = self.tabm_proj(z_tabm) # (B,512)\n", + " if self.moe_tabm:\n", + " z_tabm, _ = self.tabm_moe(z_tabm, domain_id=domain_id)\n", + " fused = self.fusion(z_vid, z_tabm) # (B,512)\n", + " if self.moe_fused:\n", + " fused, _ = self.fused_moe(fused, domain_id=domain_id)\n", + " return fused # (B,512)\n", + "\n", + " def forward(self, x_video, x_vec, domain_id=None):\n", + " fused = self.encode(x_video, x_vec, domain_id=domain_id) # (B,512)\n", + " logits = self.head(fused) # (B,4)\n", + " # 为了兼容旧接口,这里返回的 aux 为 0(MoE 的 aux 已在 temporal/tabm_moe/fused_moe 内部求和并丢弃)\n", + " # 如果你想把 MoE 的 aux 在训练里也加入,可把 encode 拆回 forward 的各步并返回累积 aux。\n", + " aux_placeholder = paddle.to_tensor(0.0, dtype='float32')\n", + " return logits, aux_placeholder\n", + "\n", + "# ====================== 指标与训练循环(兼容 aux_loss) ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + 
" tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " return float(np.trapz(precision, recall))\n", + "\n", + "LAMBDA_MOE = 0.0 # 这里的 forward 返回 aux=0(如需把 MoE aux 算进去,可按上一版做法)\n", + "\n", + "def train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None):\n", + " model.train()\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " if pos_weight is not None:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " else:\n", + " cls = F.binary_cross_entropy_with_logits(logits, 
y.astype('float32'))\n", + " loss = cls # + LAMBDA_MOE * aux # 此版本不叠加 MoE aux\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss); total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "# ====================== 检索增强:构建训练库 & 测试时融合 ======================\n", + "class Retriever:\n", + " \"\"\"\n", + " 检索库:\n", + " - keys: 训练集的特征 (N,D) —— 取模型 encode() 的融合特征\n", + " - labels: 训练集的标签 (N,C)\n", + " 推理:\n", + " - 给定测试特征 (B,D),计算与 keys 的相似度,取 top-k\n", + " - 得到邻居标签的加权均值 p_knn,按 alpha 融合到模型概率\n", + " \"\"\"\n", + " def __init__(self, sim_metric: str = 'cos', k: int = 8, alpha: float = 0.3, tau: float = 1.0):\n", + " \"\"\"\n", + " sim_metric: 'cos' or 'l2'\n", + " k: 近邻数\n", + " alpha: 融合系数,p_final = (1-alpha)*p_model + alpha*p_knn\n", + " tau: 温度(用于 l2 的 softmax(-d/tau) 或 cos 的 softmax(sim/tau))\n", + " \"\"\"\n", + " assert sim_metric in ['cos', 'l2']\n", + " self.sim_metric = sim_metric\n", + " self.k = k\n", + " self.alpha = alpha\n", + " self.tau = tau\n", + " self.keys = None # (N,D)\n", + " self.labels = None # (N,C)\n", + "\n", + " @paddle.no_grad()\n", + " def build(self, model: nn.Layer, loader: DataLoader):\n", + " model.eval()\n", + " feats, labs = [], []\n", + " for x_vid, x_vec, y in loader:\n", + " f = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,512)\n", + " feats.append(f.numpy())\n", + " labs.append(y.numpy())\n", + " self.keys = paddle.to_tensor(np.concatenate(feats, axis=0)).astype('float32') # (N,D)\n", + " self.labels = paddle.to_tensor(np.concatenate(labs, axis=0)).astype('float32') # (N,C)\n", + " # 预归一化(cos 相似度更快;l2 也可复用)\n", + " self.keys_norm = F.normalize(self.keys, axis=-1)\n", + "\n", + " @paddle.no_grad()\n", + " def query_and_fuse(self, model_probs: paddle.Tensor, test_feat: paddle.Tensor) -> paddle.Tensor:\n", + " 
\"\"\"\n", + " model_probs: (B,C) —— 模型自身概率(sigmoid后的)\n", + " test_feat: (B,D) —— 模型 encode 导出的融合特征\n", + " return: (B,C) —— 融合后的概率\n", + " \"\"\"\n", + " assert self.keys is not None, \"Call build() before query.\"\n", + " B, D = test_feat.shape\n", + " # 相似度\n", + " if self.sim_metric == 'cos':\n", + " q = F.normalize(test_feat, axis=-1) # (B,D)\n", + " sim = paddle.matmul(q, self.keys_norm, transpose_y=True) # (B,N)\n", + " w = F.softmax(sim / self.tau, axis=-1) # (B,N)\n", + " else: # 'l2'\n", + " # ||q-k||^2 = q^2 + k^2 - 2 q·k\n", + " q2 = paddle.sum(test_feat * test_feat, axis=-1, keepdim=True) # (B,1)\n", + " k2 = paddle.sum(self.keys * self.keys, axis=-1, keepdim=True).transpose([1,0]) # (1,N)\n", + " dot = paddle.matmul(test_feat, self.keys, transpose_y=True) # (B,N)\n", + " dist2 = q2 + k2 - 2.0 * dot # (B,N)\n", + " w = F.softmax(-dist2 / self.tau, axis=-1) # (B,N)\n", + "\n", + " # 取 top-k(可选:先 topk 再归一化,避免长尾干扰)\n", + " topk_val, topk_idx = paddle.topk(w, k=min(self.k, w.shape[1]), axis=-1) # (B,k)\n", + " # gather labels\n", + " N, C = self.labels.shape\n", + " b_idx = paddle.arange(B, dtype='int64').unsqueeze(-1).tile([1, topk_val.shape[1]]) # (B,k)\n", + " # 先 gather 权重对应的 labels\n", + " picked_labels = paddle.gather(self.labels, topk_idx.reshape([-1]), axis=0) # (B*k, C)\n", + " picked_labels = picked_labels.reshape([B, -1, C]) # (B,k,C)\n", + " w_norm = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (B,k)\n", + " p_knn = paddle.sum(picked_labels * w_norm.unsqueeze(-1), axis=1) # (B,C)\n", + "\n", + " # 概率融合\n", + " p_final = (1.0 - self.alpha) * model_probs + self.alpha * p_knn\n", + " return p_final.clip(1e-6, 1-1e-6)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5,\n", + " retriever: Optional[Retriever] = None):\n", + " \"\"\"\n", + " 若 retriever 不为 None:在测试时做 kNN 检索并与模型概率融合。\n", + " \"\"\"\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + 
" for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " prob = F.sigmoid(logits) # (B,C)\n", + "\n", + " if retriever is not None:\n", + " feat = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,D)\n", + " prob = retriever.query_and_fuse(prob, feat) # (B,C)\n", + "\n", + " loss = F.binary_cross_entropy(prob, y.astype('float32')) # 用概率计算 eval loss\n", + " ys.append(y.numpy()); ps.append(prob.numpy())\n", + " total_loss += float(loss); total_batches += 1\n", + "\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\n", + " \"loss\": total_loss / max(1, total_batches),\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + "\n", + "# ====================== 合成数据集(可替换为真实数据) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " \"\"\"\n", + " 返回:\n", + " x_video: (T=36, C=20, H=20, W=20)\n", + " x_vec: (424,)\n", + " y: (4,) 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, seed: int = 0):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n\n", + " self.video = rng.normal(size=(n, 36, 20, 20, 20)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + "\n", + " # 造标签:对视频先在 H/W 上均值,再在 T 上均值 → (n, C=20)\n", + " vid_hw = self.video.mean(axis=(3, 4)) # (n, T, C)\n", + " vid_avg = vid_hw.mean(axis=1) # (n, C)\n", + " Wv = rng.normal(size=(20, 4)) # C→4\n", + " Wt = rng.normal(size=(424, 4)) # 424→4\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + 
"\n", + " def __getitem__(self, idx: int):\n", + " return self.video[idx], self.vec[idx], self.y[idx]\n", + "\n", + " def __len__(self):\n", + " return self.n\n", + "\n", + "# ====================== 训练入口(可直接运行) ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + " # 数据\n", + " train_ds = ToyTwoModalDataset(n=128, seed=42)\n", + " val_ds = ToyTwoModalDataset(n=32, seed=233)\n", + "\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)), # (B,T,C,H,W)\n", + " paddle.to_tensor(np.stack(vecs, 0)), # (B,424)\n", + " paddle.to_tensor(np.stack(ys, 0))) # (B,4)\n", + "\n", + " train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32')) # (4,)\n", + "\n", + " # === 模型(MoE 开关按需) ===\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=36,\n", + " vec_dim=424,\n", + " d_model=512, nhead=2, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1,\n", + " num_labels=4,\n", + " moe_temporal=True, # FFN 位置 MoE\n", + " moe_fused=False, # 融合处 MoE\n", + " moe_tabm=False # TabM 处 MoE\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 3+1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5, retriever=None)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | 
\"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best is not None:\n", + " model.set_state_dict(best)\n", + " print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n", + "\n", + " # === 构建检索库(使用训练集) ===\n", + " retr = Retriever(sim_metric='cos', k=8, alpha=0.3, tau=0.5) # 可改 'l2'\n", + " retr.build(model, DataLoader(train_ds, batch_size=8, shuffle=False, collate_fn=collate_fn))\n", + "\n", + " # === 测试时启用检索增强 ===\n", + " val_metrics_knn = evaluate(model, val_loader, threshold=0.5, retriever=retr)\n", + " print(f\"[RkNN] val_loss={val_metrics_knn['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics_knn['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics_knn['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics_knn['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics_knn['micro_AP']:.4f}\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "u6m8tTnwSSYC", + "outputId": "b7bfaa2f-71fa-4945-c8b4-b4f586fd0fdf" + }, + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/paddle/nn/layer/norm.py:818: UserWarning: When training, we now always track global mean and variance.\n", + " warnings.warn(\n", + "/tmp/ipython-input-2448755935.py:419: DeprecationWarning: `trapz` is deprecated. 
Use `trapezoid` instead, or one of the numerical integration functions in `scipy.integrate`.\n", + " return float(np.trapz(precision, recall))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[Epoch 01] train_loss=0.8949 | val_loss=0.7680 | macro_f1=0.4306 | micro_f1=0.5255 | per_class_f1=[0.5454545617103577, 0.5405405163764954, 0.6363636255264282, 0.0] | micro_AP=0.4448\n", + "[Epoch 02] train_loss=0.5957 | val_loss=0.9367 | macro_f1=0.4620 | micro_f1=0.5414 | per_class_f1=[0.4444444477558136, 0.7450980544090271, 0.5333333611488342, 0.125] | micro_AP=0.5133\n", + "[Epoch 03] train_loss=0.3141 | val_loss=1.1780 | macro_f1=0.3715 | micro_f1=0.4158 | per_class_f1=[0.125, 0.5, 0.6000000238418579, 0.260869562625885] | micro_AP=0.5041\n", + "Loaded best state with macro_f1=0.4620\n", + "[RkNN] val_loss=0.8189 | macro_f1=0.4900 | micro_f1=0.5390 | per_class_f1=[0.4000000059604645, 0.7599999904632568, 0.5142857432365417, 0.2857142984867096] | micro_AP=0.4978\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "\n", + "# ====================== 工具:正弦位置编码 ======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 4096):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " 
T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== 简化版 TabM(占位,可换你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + " def forward(self, x_num: paddle.Tensor): # (B, 424)\n", + " return self.net(x_num) # (B, H)\n", + "\n", + "# ====================== 3D ResNet-18 体数据特征抽取 ======================\n", + "class BasicBlock3D(nn.Layer):\n", + " expansion = 1\n", + " def __init__(self, in_planes, planes, stride=(1,1,1), downsample=None):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv3D(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(planes)\n", + " self.relu = nn.ReLU()\n", + " self.conv2 = nn.Conv3D(planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)\n", + " self.bn2 = nn.BatchNorm3D(planes)\n", + " self.downsample = downsample\n", + " def forward(self, x):\n", + " identity = x\n", + " out = self.relu(self.bn1(self.conv1(x)))\n", + " out = self.bn2(self.conv2(out))\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + " out = self.relu(out + identity)\n", + " return out\n", + "\n", + "class ResNet3D(nn.Layer):\n", + " def __init__(self, block, layers, in_channels=20, base_width=64):\n", + " super().__init__()\n", + " self.in_planes = base_width\n", + " # 只空间下采样,保留较细 D 维\n", + " self.conv1 = nn.Conv3D(in_channels, self.in_planes,\n", + " kernel_size=(3,7,7), stride=(1,2,2),\n", + " padding=(1,3,3), bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(self.in_planes)\n", + " self.relu = nn.ReLU()\n", + " self.maxpool = nn.MaxPool3D(kernel_size=(1,3,3), 
stride=(1,2,2), padding=(0,1,1))\n", + " self.layer1 = self._make_layer(block, base_width, layers[0], stride=(1,1,1))\n", + " self.layer2 = self._make_layer(block, base_width*2, layers[1], stride=(2,2,2)) # D/H/W /2\n", + " self.layer3 = self._make_layer(block, base_width*4, layers[2], stride=(2,2,2))\n", + " self.layer4 = self._make_layer(block, base_width*8, layers[3], stride=(2,2,2))\n", + " self.out_dim = base_width*8 # 512\n", + " self.pool = nn.AdaptiveAvgPool3D(output_size=1)\n", + "\n", + " def _make_layer(self, block, planes, blocks, stride=(1,1,1)):\n", + " downsample = None\n", + " if stride != (1,1,1) or self.in_planes != planes * block.expansion:\n", + " downsample = nn.Sequential(\n", + " nn.Conv3D(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias_attr=False),\n", + " nn.BatchNorm3D(planes * block.expansion),\n", + " )\n", + " layers = [block(self.in_planes, planes, stride=stride, downsample=downsample)]\n", + " self.in_planes = planes * block.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(block(self.in_planes, planes))\n", + " return nn.Sequential(*layers)\n", + "\n", + " def forward(self, x): # x: (B, C, D, H, W)\n", + " x = self.relu(self.bn1(self.conv1(x)))\n", + " x = self.maxpool(x)\n", + " x = self.layer1(x) # (B, 64, D, H/4, W/4)\n", + " x = self.layer2(x) # (B, 128, D/2, H/8, W/8)\n", + " x = self.layer3(x) # (B, 256, D/4, H/16, W/16)\n", + " x = self.layer4(x) # (B, 512, D/8, H/32, W/32)\n", + " x = self.pool(x) # (B, 512, 1,1,1)\n", + " x = paddle.flatten(x, 1) # (B, 512)\n", + " return x\n", + "\n", + "class Volume3DEncoder(nn.Layer):\n", + " \"\"\"\n", + " 3D ResNet-18 over (D,H,W) for each time step.\n", + " 输入单帧体数据: (B, C=20, D=24, H=20, W=20) → 输出 (B, 512)\n", + " \"\"\"\n", + " def __init__(self, in_channels: int = 20, base: int = 64, dropout: float = 0.0):\n", + " super().__init__()\n", + " self.backbone = ResNet3D(BasicBlock3D, layers=[2,2,2,2], in_channels=in_channels, 
base_width=base)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.out_dim = self.backbone.out_dim # 512\n", + " def forward(self, x): # x: (B, C, D, H, W)\n", + " x = self.backbone(x) # (B,512)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "# ====================== MoE(Top-k;gather_nd 选择专家) ======================\n", + "class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1, act='relu'):\n", + " super().__init__()\n", + " Act = getattr(F, act) if isinstance(act, str) else act\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.act = Act\n", + " def forward(self, x):\n", + " return self.fc2(self.drop(self.act(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self,\n", + " n_experts=8,\n", + " top_k=1,\n", + " d_ff=2048,\n", + " dropout=0.1,\n", + " router_temp=0.5,\n", + " balance_loss_w=0.005,\n", + " entropy_reg_w=-0.005,\n", + " diversity_w=1e-3,\n", + " sticky_w=0.0,\n", + " sup_router_w=0.0,\n", + " use_gumbel=True):\n", + " self.n_experts = n_experts\n", + " self.top_k = top_k\n", + " self.d_ff = d_ff\n", + " self.dropout = dropout\n", + " self.router_temp = router_temp\n", + " self.balance_loss_w = balance_loss_w\n", + " self.entropy_reg_w = entropy_reg_w\n", + " self.diversity_w = diversity_w\n", + " self.sticky_w = sticky_w\n", + " self.sup_router_w = sup_router_w\n", + " self.use_gumbel = use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " \"\"\"forward(x, domain_id=None) → (y, aux_loss),支持 (B,T,D) 或 (N,D)\"\"\"\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(cfg.dropout)\n", + "\n", + " def _router_probs(self, 
logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u))\n", + " logits = logits + g\n", + " return F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + "\n", + " def forward(self, x, domain_id=None):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3:\n", + " B, T, D = orig_shape\n", + " X = x.reshape([B*T, D])\n", + " else:\n", + " X = x\n", + " N, D = X.shape\n", + "\n", + " logits = self.router(X) # (N,E)\n", + " probs = self._router_probs(logits) # (N,E)\n", + " topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1) # (N,k)\n", + "\n", + " # 并行专家\n", + " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + "\n", + " # gather_nd 逐样本取 top-k\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list = []\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64') # (N,)\n", + " idx_nd = paddle.stack([arangeN, idx_i], axis=1) # (N,2)\n", + " picked_i = paddle.gather_nd(all_out, idx_nd) # (N,D)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + "\n", + " # 加权\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (N,k)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1) # (N,D)\n", + "\n", + " # 残差+归一\n", + " Y = self.drop(Y)\n", + " Y = self.ln(Y + X)\n", + "\n", + " # 辅助损失\n", + " aux = 0.0\n", + " if self.cfg.balance_loss_w > 0:\n", + " mean_prob = probs.mean(axis=0)\n", + " target = paddle.full_like(mean_prob, 1.0 / self.cfg.n_experts)\n", + " aux = aux + self.cfg.balance_loss_w * F.mse_loss(mean_prob, target)\n", + " if self.cfg.entropy_reg_w != 0.0:\n", + " ent = -paddle.sum(probs * (paddle.log(probs + 1e-9)), axis=1).mean()\n", + " aux = aux + self.cfg.entropy_reg_w * ent\n", + " if (domain_id is not None) and (self.cfg.sup_router_w > 0):\n", + " dom = 
domain_id.reshape([-1])[:N] % self.cfg.n_experts\n", + " aux = aux + self.cfg.sup_router_w * F.cross_entropy(logits, dom)\n", + " if self.cfg.diversity_w > 0 and self.cfg.n_experts > 1:\n", + " chosen = F.one_hot(topk_idx[:, 0], num_classes=self.cfg.n_experts).astype('float32') # (N,E)\n", + " denom = chosen.sum(axis=0).clip(min=1.0).unsqueeze(-1)\n", + " means = (all_out * chosen.unsqueeze(-1)).sum(axis=0) / denom # (E,D)\n", + " sims = []\n", + " for i in range(self.cfg.n_experts):\n", + " for j in range(i+1, self.cfg.n_experts):\n", + " si = F.normalize(means[i:i+1], axis=-1)\n", + " sj = F.normalize(means[j:j+1], axis=-1)\n", + " sims.append((si*sj).sum())\n", + " if sims:\n", + " aux = aux + self.cfg.diversity_w * paddle.stack(sims).mean()\n", + "\n", + " if len(orig_shape) == 3:\n", + " Y = Y.reshape([B, T, D])\n", + " return Y, aux\n", + "\n", + "class MoEHead(nn.Layer):\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " def forward(self, tok, domain_id=None):\n", + " y, aux = self.moe(tok.unsqueeze(1), domain_id=domain_id) # (B,1,D)\n", + " return y.squeeze(1), aux\n", + "\n", + "# ====================== 自定义 Transformer Encoder(FFN 可替换 MoE) ======================\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.use_moe = use_moe\n", + " self.self_attn = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.ln1 = nn.LayerNorm(d_model)\n", + " self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_ff, 
d_model),\n", + " )\n", + " self.do2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, domain_id=None): # (B,T,D)\n", + " # paddle.nn.MultiHeadAttention is batch-first: it expects (batch, seq_len, embed_dim).\n", + " # Feeding (T,B,D) would make attention mix samples across the batch instead of\n", + " # time steps, so the input is passed as (B,T,D) without any transpose.\n", + " h = self.ln1(x) # pre-norm, (B,T,D)\n", + " sa = self.self_attn(h, h, h) # (B,T,D)\n", + " x = x + self.do1(sa)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id)\n", + " else:\n", + " x = x + self.do2(self.ffn(x))\n", + " return x, aux\n", + "\n", + "class TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=2, d_ff=1024, dropout=0.1,\n", + " max_len=4096, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + " self.layers = nn.LayerList([\n", + " TransformerEncoderLayerMoE(d_model, nhead, d_ff, dropout,\n", + " use_moe=use_moe, moe_cfg=moe_cfg)\n", + " for _ in range(num_layers)\n", + " ])\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " x = self.pos(x)\n", + " aux_total = 0.0\n", + " for layer in self.layers:\n", + " x, aux = layer(x, domain_id=domain_id)\n", + " aux_total = aux_total + aux\n", + " return x, aux_total\n", + "\n", + "# ====================== Cross-Attention 融合 ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + " self.d_head = d_model // nhead\n", + " self.Wq = nn.Linear(d_model, d_model)\n", + " self.Wk = nn.Linear(d_model, d_model)\n", + " self.Wv = nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.ln = nn.LayerNorm(d_model)\n", + "\n", + " def forward(self, q, kv):\n", + " B, Nq, D = q.shape\n", + " q_lin = self.Wq(q); k_lin
= self.Wk(kv); v_lin = self.Wv(kv)\n", + " def split_heads(t):\n", + " return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])\n", + " qh = split_heads(q_lin); kh = split_heads(k_lin); vh = split_heads(v_lin)\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, vh)\n", + " ctx = ctx.transpose([0, 2, 1, 3]).reshape([B, Nq, D])\n", + " out = self.proj(ctx)\n", + " out = self.drop(out)\n", + " return self.ln(out + q)\n", + "\n", + "class BiModalCrossFusion(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(\n", + " nn.Linear(2 * d_model, fuse_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " )\n", + " self.out_dim = fuse_hidden\n", + "\n", + " def forward(self, video_seq, tabm_tok):\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok) # (B,1,D)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq) # (B,1,D)\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1) # (B,1,2D)\n", + " fused = fused.squeeze(1) # (B,2D)\n", + " return self.fuse(fused) # (B, F)\n", + "\n", + "# ====================== 总模型 ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=365, depth_n=24,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " # MoE 开关\n", + " moe_temporal: bool = True,\n", + " moe_fused: bool = False,\n", + " moe_tabm: bool = False,\n", + " # MoE 超参\n", + " 
moe_cfg_temporal: MoEConfig = None,\n", + " moe_cfg_fused: MoEConfig = None,\n", + " moe_cfg_tabm: MoEConfig = None):\n", + " super().__init__()\n", + " # A: 逐帧 3D ResNet18\n", + " self.vol_encoder = Volume3DEncoder(in_channels=vid_channels, dropout=dropout) # (B*T,512)\n", + " # A: 时序 Transformer(可 MoE)\n", + " self.temporal = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, num_layers=n_trans_layers,\n", + " d_ff=trans_ff, dropout=dropout, max_len=vid_frames,\n", + " use_moe=moe_temporal,\n", + " moe_cfg=moe_cfg_temporal or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(2048, trans_ff), router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # B: TabM\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model)\n", + "\n", + " # 可选:TabM 分支 MoE 头\n", + " self.moe_tabm = moe_tabm\n", + " if moe_tabm:\n", + " self.tabm_moe = MoEHead(d_model=d_model, cfg=moe_cfg_tabm or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 融合\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + "\n", + " # 可选:融合 token MoE 头\n", + " self.moe_fused = moe_fused\n", + " if moe_fused:\n", + " self.fused_moe = MoEHead(d_model=d_model, cfg=moe_cfg_fused or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " self.vid_frames = vid_frames\n", + " self.depth_n = depth_n\n", + "\n", + " # 导出融合前 512 表示(用于检索库)\n", + " def encode(self, x_video, x_vec, domain_id=None):\n", + " \"\"\"\n", + " x_video: (B, T, C=20, H=20, W=20, N=24)\n", + " x_vec: (B, 424)\n", + " \"\"\"\n", 
+ " B, T, C, H, W, N = x_video.shape\n", + " assert N == self.depth_n, f\"N mismatch: got {N}, expect {self.depth_n}\"\n", + " # 逐帧 3D 编码: (B*T, C, D=N, H, W)\n", + " xvt = x_video.transpose([0,1,2,5,3,4]).reshape([B*T, C, N, H, W])\n", + " f_frame = self.vol_encoder(xvt) # (B*T, 512)\n", + " f_seq = f_frame.reshape([B, T, -1]) # (B, T, 512)\n", + " z_vid, _ = self.temporal(f_seq, domain_id=domain_id) # (B,T,512)\n", + " z_tabm = self.tabm(x_vec)\n", + " z_tabm = self.tabm_proj(z_tabm) # (B,512)\n", + " if self.moe_tabm:\n", + " z_tabm, _ = self.tabm_moe(z_tabm, domain_id=domain_id)\n", + " fused = self.fusion(z_vid, z_tabm) # (B,512)\n", + " if self.moe_fused:\n", + " fused, _ = self.fused_moe(fused, domain_id=domain_id)\n", + " return fused\n", + "\n", + " def forward(self, x_video, x_vec, domain_id=None):\n", + " fused = self.encode(x_video, x_vec, domain_id=domain_id) # (B,512)\n", + " logits = self.head(fused) # (B,4)\n", + " aux_placeholder = paddle.to_tensor(0.0, dtype='float32')\n", + " return logits, aux_placeholder\n", + "\n", + "# ====================== 指标与训练循环 ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", 
+ " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " return float(np.trapz(precision, recall))\n", + "\n", + "def train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None):\n", + " model.train()\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " if pos_weight is not None:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " else:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'))\n", + " loss = cls\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss); total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "# ====================== 检索增强(cos / l2;k 邻居软加权;概率融合) ======================\n", + "class Retriever:\n", + " def __init__(self, sim_metric: str = 'cos', k: int = 8, alpha: float = 0.3, tau: float = 0.5):\n", + " assert sim_metric in ['cos', 'l2']\n", + " self.sim_metric = sim_metric\n", + " self.k = k\n", + " 
self.alpha = alpha\n", + " self.tau = tau\n", + " self.keys = None # (N,D)\n", + " self.labels = None # (N,C)\n", + "\n", + " @paddle.no_grad()\n", + " def build(self, model: nn.Layer, loader: DataLoader):\n", + " model.eval()\n", + " feats, labs = [], []\n", + " for x_vid, x_vec, y in loader:\n", + " f = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,512)\n", + " feats.append(f.numpy())\n", + " labs.append(y.numpy())\n", + " self.keys = paddle.to_tensor(np.concatenate(feats, axis=0)).astype('float32') # (N,D)\n", + " self.labels = paddle.to_tensor(np.concatenate(labs, axis=0)).astype('float32') # (N,C)\n", + " self.keys_norm = F.normalize(self.keys, axis=-1)\n", + "\n", + " @paddle.no_grad()\n", + " def query_and_fuse(self, model_probs: paddle.Tensor, test_feat: paddle.Tensor) -> paddle.Tensor:\n", + " assert self.keys is not None, \"build() must be called first.\"\n", + " B, D = test_feat.shape\n", + " if self.sim_metric == 'cos':\n", + " q = F.normalize(test_feat, axis=-1)\n", + " sim = paddle.matmul(q, self.keys_norm, transpose_y=True) # (B,N)\n", + " w = F.softmax(sim / self.tau, axis=-1)\n", + " else:\n", + " q2 = paddle.sum(test_feat * test_feat, axis=-1, keepdim=True) # (B,1)\n", + " k2 = paddle.sum(self.keys * self.keys, axis=-1, keepdim=True).transpose([1,0]) # (1,N)\n", + " dot = paddle.matmul(test_feat, self.keys, transpose_y=True) # (B,N)\n", + " dist2 = q2 + k2 - 2.0 * dot # (B,N)\n", + " w = F.softmax(-dist2 / self.tau, axis=-1)\n", + "\n", + " topk_val, topk_idx = paddle.topk(w, k=min(self.k, w.shape[1]), axis=-1) # (B,k)\n", + " picked_labels = paddle.gather(self.labels, topk_idx.reshape([-1]), axis=0) # (B*k, C)\n", + " C = self.labels.shape[1]\n", + " picked_labels = picked_labels.reshape([B, -1, C]) # (B,k,C)\n", + " w_norm = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (B,k)\n", + " p_knn = paddle.sum(picked_labels * w_norm.unsqueeze(-1), axis=1) # (B,C)\n", + "\n", + " p_final = (1.0 - self.alpha) * 
model_probs + self.alpha * p_knn\n", + " return p_final.clip(1e-6, 1-1e-6)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5,\n", + " retriever: Optional[Retriever] = None):\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " prob = F.sigmoid(logits) # (B,C)\n", + " if retriever is not None:\n", + " feat = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,512)\n", + " prob = retriever.query_and_fuse(prob, feat)\n", + " loss = F.binary_cross_entropy(prob, y.astype('float32'))\n", + " ys.append(y.numpy()); ps.append(prob.numpy())\n", + " total_loss += float(loss); total_batches += 1\n", + "\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\n", + " \"loss\": total_loss / max(1, total_batches),\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + "\n", + "# ====================== ToyDataset(T=365, N=24) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " \"\"\"\n", + " 返回:\n", + " x_video: (T=365, C=20, H=20, W=20, N=24)\n", + " x_vec: (424,)\n", + " y: (4,) 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, seed: int = 0, T: int = 365, C: int = 20, H: int = 20, W: int = 20, N: int = 24):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n\n", + " self.T, self.C, self.H, self.W, self.N = T, C, H, W, N\n", + " # (n, T, C, H, W, N)\n", + " self.video = rng.normal(size=(n, T, C, H, W, N)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 
424)).astype('float32')\n", + "\n", + " # 造标签:对视频先在 H/W/N 上均值,再在 T 上均值 → (n, C)\n", + " vid_hwn = self.video.mean(axis=(3, 4, 5)) # (n, T, C)\n", + " vid_avg = vid_hwn.mean(axis=1) # (n, C)\n", + "\n", + " Wv = rng.normal(size=(C, 4))\n", + " Wt = rng.normal(size=(424, 4))\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + "\n", + " def __getitem__(self, idx: int):\n", + " return self.video[idx], self.vec[idx], self.y[idx]\n", + " def __len__(self):\n", + " return self.n\n", + "\n", + "# ====================== 训练入口 ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + "\n", + " # 数据\n", + " T, C, H, W, N = 365, 20, 20, 20, 24\n", + " train_ds = ToyTwoModalDataset(n=32, seed=42, T=T, C=C, H=H, W=W, N=N)\n", + " val_ds = ToyTwoModalDataset(n=16, seed=233, T=T, C=C, H=H, W=W, N=N)\n", + "\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)), # (B,T,C,H,W,N)\n", + " paddle.to_tensor(np.stack(vecs, 0)), # (B,424)\n", + " paddle.to_tensor(np.stack(ys, 0))) # (B,4)\n", + "\n", + " # T=365 + 3D 卷积较吃内存,示例用小 batch\n", + " train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32')) # (4,)\n", + "\n", + " # 模型\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=C, vid_h=H, vid_w=W, vid_frames=T, depth_n=N,\n", + " vec_dim=424,\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1,\n", + " num_labels=4,\n", + 
" moe_temporal=True, # 推荐开启(FFN 位置 MoE)\n", + " moe_fused=False,\n", + " moe_tabm=False\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 2+1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5, retriever=None)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best is not None:\n", + " model.set_state_dict(best)\n", + " print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n", + "\n", + " # === 构建检索库(用训练集) ===\n", + " retr = Retriever(sim_metric='cos', k=8, alpha=0.3, tau=0.5) # 可改 'l2'\n", + " retr.build(model, DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=collate_fn))\n", + "\n", + " # === 测试时启用检索增强 ===\n", + " val_metrics_knn = evaluate(model, val_loader, threshold=0.5, retriever=retr)\n", + " print(f\"[RkNN] val_loss={val_metrics_knn['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics_knn['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics_knn['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics_knn['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics_knn['micro_AP']:.4f}\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a85yLOn8jaJU", + "outputId": "acabc284-ffc9-4f25-ca2b-46e7971e2234" + }, + "execution_count": null, + "outputs": [ + 
{ + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/paddle/utils/cpp_extension/extension_utils.py:718: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md\n", + " warnings.warn(warning_message)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "\n", + "# ====================== 工具:正弦位置编码 ======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 4096):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + "\n", + " def forward(self, x): # x: (B, T, D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== 简化版 TabM(占位,可换你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + " def forward(self, x_num: paddle.Tensor): # (B, 424)\n", + " return 
self.net(x_num) # (B, H)\n", + "\n", + "# ====================== 3D ResNet-18 体数据特征抽取 ======================\n", + "class BasicBlock3D(nn.Layer):\n", + " expansion = 1\n", + " def __init__(self, in_planes, planes, stride=(1,1,1), downsample=None):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv3D(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(planes)\n", + " self.relu = nn.ReLU()\n", + " self.conv2 = nn.Conv3D(planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)\n", + " self.bn2 = nn.BatchNorm3D(planes)\n", + " self.downsample = downsample\n", + " def forward(self, x):\n", + " identity = x\n", + " out = self.relu(self.bn1(self.conv1(x)))\n", + " out = self.bn2(self.conv2(out))\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + " out = self.relu(out + identity)\n", + " return out\n", + "\n", + "class ResNet3D(nn.Layer):\n", + " def __init__(self, block, layers, in_channels=20, base_width=64):\n", + " super().__init__()\n", + " self.in_planes = base_width\n", + " # 只空间下采样,保留较细 D 维\n", + " self.conv1 = nn.Conv3D(in_channels, self.in_planes,\n", + " kernel_size=(3,7,7), stride=(1,2,2),\n", + " padding=(1,3,3), bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(self.in_planes)\n", + " self.relu = nn.ReLU()\n", + " self.maxpool = nn.MaxPool3D(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))\n", + " self.layer1 = self._make_layer(block, base_width, layers[0], stride=(1,1,1))\n", + " self.layer2 = self._make_layer(block, base_width*2, layers[1], stride=(2,2,2)) # D/H/W /2\n", + " self.layer3 = self._make_layer(block, base_width*4, layers[2], stride=(2,2,2))\n", + " self.layer4 = self._make_layer(block, base_width*8, layers[3], stride=(2,2,2))\n", + " self.out_dim = base_width*8 # 512\n", + " self.pool = nn.AdaptiveAvgPool3D(output_size=1)\n", + "\n", + " def _make_layer(self, block, planes, blocks, stride=(1,1,1)):\n", + " downsample = None\n", + " 
if stride != (1,1,1) or self.in_planes != planes * block.expansion:\n", + " downsample = nn.Sequential(\n", + " nn.Conv3D(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias_attr=False),\n", + " nn.BatchNorm3D(planes * block.expansion),\n", + " )\n", + " layers = [block(self.in_planes, planes, stride=stride, downsample=downsample)]\n", + " self.in_planes = planes * block.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(block(self.in_planes, planes))\n", + " return nn.Sequential(*layers)\n", + "\n", + " def forward(self, x): # x: (B, C, D, H, W)\n", + " x = self.relu(self.bn1(self.conv1(x)))\n", + " x = self.maxpool(x)\n", + " x = self.layer1(x) # (B, 64, D, H/4, W/4)\n", + " x = self.layer2(x) # (B, 128, D/2, H/8, W/8)\n", + " x = self.layer3(x) # (B, 256, D/4, H/16, W/16)\n", + " x = self.layer4(x) # (B, 512, D/8, H/32, W/32)\n", + " x = self.pool(x) # (B, 512, 1,1,1)\n", + " x = paddle.flatten(x, 1) # (B, 512)\n", + " return x\n", + "\n", + "class Volume3DEncoder(nn.Layer):\n", + " \"\"\"\n", + " 3D ResNet-18 over (D,H,W) for each time step.\n", + " 输入单帧体数据: (B, C=20, D=24, H=20, W=20) → 输出 (B, 512)\n", + " \"\"\"\n", + " def __init__(self, in_channels: int = 20, base: int = 64, dropout: float = 0.0):\n", + " super().__init__()\n", + " self.backbone = ResNet3D(BasicBlock3D, layers=[2,2,2,2], in_channels=in_channels, base_width=base)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.out_dim = self.backbone.out_dim # 512\n", + " def forward(self, x): # x: (B, C, D, H, W)\n", + " x = self.backbone(x) # (B,512)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "# ====================== MoE(Top-k;gather_nd 选择专家) ======================\n", + "class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1, act='relu'):\n", + " super().__init__()\n", + " Act = getattr(F, act) if isinstance(act, str) else act\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " 
self.drop = nn.Dropout(dropout)\n", + " self.act = Act\n", + " def forward(self, x):\n", + " return self.fc2(self.drop(self.act(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self,\n", + " n_experts=8,\n", + " top_k=1,\n", + " d_ff=2048,\n", + " dropout=0.1,\n", + " router_temp=0.5,\n", + " balance_loss_w=0.005,\n", + " entropy_reg_w=-0.005,\n", + " diversity_w=1e-3,\n", + " sticky_w=0.0,\n", + " sup_router_w=0.0,\n", + " use_gumbel=True):\n", + " self.n_experts = n_experts\n", + " self.top_k = top_k\n", + " self.d_ff = d_ff\n", + " self.dropout = dropout\n", + " self.router_temp = router_temp\n", + " self.balance_loss_w = balance_loss_w\n", + " self.entropy_reg_w = entropy_reg_w\n", + " self.diversity_w = diversity_w\n", + " self.sticky_w = sticky_w\n", + " self.sup_router_w = sup_router_w\n", + " self.use_gumbel = use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " \"\"\"forward(x, domain_id=None) → (y, aux_loss),支持 (B,T,D) 或 (N,D)\"\"\"\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(cfg.dropout)\n", + "\n", + " def _router_probs(self, logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u))\n", + " logits = logits + g\n", + " return F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + "\n", + " def forward(self, x, domain_id=None):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3:\n", + " B, T, D = orig_shape\n", + " X = x.reshape([B*T, D])\n", + " else:\n", + " X = x\n", + " N, D = X.shape\n", + "\n", + " logits = self.router(X) # (N,E)\n", + " probs = self._router_probs(logits) # (N,E)\n", + " topk_val, 
topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1) # (N,k)\n", + "\n", + " # 并行专家\n", + " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + "\n", + " # gather_nd 逐样本取 top-k\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list = []\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64') # (N,)\n", + " idx_nd = paddle.stack([arangeN, idx_i], axis=1) # (N,2)\n", + " picked_i = paddle.gather_nd(all_out, idx_nd) # (N,D)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + "\n", + " # 加权\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (N,k)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1) # (N,D)\n", + "\n", + " # 残差+归一\n", + " Y = self.drop(Y)\n", + " Y = self.ln(Y + X)\n", + "\n", + " # 辅助损失\n", + " aux = 0.0\n", + " if self.cfg.balance_loss_w > 0:\n", + " mean_prob = probs.mean(axis=0)\n", + " target = paddle.full_like(mean_prob, 1.0 / self.cfg.n_experts)\n", + " aux = aux + self.cfg.balance_loss_w * F.mse_loss(mean_prob, target)\n", + " if self.cfg.entropy_reg_w != 0.0:\n", + " ent = -paddle.sum(probs * (paddle.log(probs + 1e-9)), axis=1).mean()\n", + " aux = aux + self.cfg.entropy_reg_w * ent\n", + " if (domain_id is not None) and (self.cfg.sup_router_w > 0):\n", + " dom = domain_id.reshape([-1])[:N] % self.cfg.n_experts\n", + " aux = aux + self.cfg.sup_router_w * F.cross_entropy(logits, dom)\n", + " if self.cfg.diversity_w > 0 and self.cfg.n_experts > 1:\n", + " chosen = F.one_hot(topk_idx[:, 0], num_classes=self.cfg.n_experts).astype('float32') # (N,E)\n", + " denom = chosen.sum(axis=0).clip(min=1.0).unsqueeze(-1)\n", + " means = (all_out * chosen.unsqueeze(-1)).sum(axis=0) / denom # (E,D)\n", + " sims = []\n", + " for i in range(self.cfg.n_experts):\n", + " for j in range(i+1, self.cfg.n_experts):\n", + " si = F.normalize(means[i:i+1], axis=-1)\n", + " sj = F.normalize(means[j:j+1], axis=-1)\n", 
+ " sims.append((si*sj).sum())\n", + " if sims:\n", + " aux = aux + self.cfg.diversity_w * paddle.stack(sims).mean()\n", + "\n", + " if len(orig_shape) == 3:\n", + " Y = Y.reshape([B, T, D])\n", + " return Y, aux\n", + "\n", + "class MoEHead(nn.Layer):\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " def forward(self, tok, domain_id=None):\n", + " y, aux = self.moe(tok.unsqueeze(1), domain_id=domain_id) # (B,1,D)\n", + " return y.squeeze(1), aux\n", + "\n", + "# ====================== 自定义 Transformer Encoder(FFN 可替换 MoE) ======================\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.use_moe = use_moe\n", + " self.self_attn = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.ln1 = nn.LayerNorm(d_model)\n", + " self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model),\n", + " )\n", + " self.do2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, domain_id=None): # (B,T,D)\n", + " h = self.ln1(x)\n", + " h = paddle.transpose(h, [1, 0, 2]) # (T,B,D)\n", + " sa = self.self_attn(h, h, h) # (T,B,D)\n", + " sa = paddle.transpose(sa, [1, 0, 2]) # (B,T,D)\n", + " x = x + self.do1(sa)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id)\n", + " else:\n", + " x = x + self.do2(self.ffn(x))\n", + " return x, aux\n", + "\n", + "class TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=2, d_ff=1024, 
dropout=0.1,\n", + " max_len=4096, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + " self.layers = nn.LayerList([\n", + " TransformerEncoderLayerMoE(d_model, nhead, d_ff, dropout,\n", + " use_moe=use_moe, moe_cfg=moe_cfg)\n", + " for _ in range(num_layers)\n", + " ])\n", + " def forward(self, x, domain_id=None): # x: (B,T,D)\n", + " x = self.pos(x)\n", + " aux_total = 0.0\n", + " for layer in self.layers:\n", + " x, aux = layer(x, domain_id=domain_id)\n", + " aux_total = aux_total + aux\n", + " return x, aux_total\n", + "\n", + "# ====================== Cross-Attention 融合 ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_model = d_model\n", + " self.nhead = nhead\n", + " self.d_head = d_model // nhead\n", + " self.Wq = nn.Linear(d_model, d_model)\n", + " self.Wk = nn.Linear(d_model, d_model)\n", + " self.Wv = nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.ln = nn.LayerNorm(d_model)\n", + "\n", + " def forward(self, q, kv):\n", + " B, Nq, D = q.shape\n", + " q_lin = self.Wq(q); k_lin = self.Wk(kv); v_lin = self.Wv(kv)\n", + " def split_heads(t):\n", + " return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])\n", + " qh = split_heads(q_lin); kh = split_heads(k_lin); vh = split_heads(v_lin)\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, vh)\n", + " ctx = ctx.transpose([0, 2, 1, 3]).reshape([B, Nq, D])\n", + " out = self.proj(ctx)\n", + " out = self.drop(out)\n", + " return self.ln(out + q)\n", + "\n", + "class BiModalCrossFusion(nn.Layer):\n", + " def __init__(self, 
d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(\n", + " nn.Linear(2 * d_model, fuse_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " )\n", + " self.out_dim = fuse_hidden\n", + "\n", + " def forward(self, video_seq, tabm_tok):\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok) # (B,1,D)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq) # (B,1,D)\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1) # (B,1,2D)\n", + " fused = fused.squeeze(1) # (B,2D)\n", + " return self.fuse(fused) # (B, F)\n", + "\n", + "# ====================== 总模型 ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=365, depth_n=24,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " # MoE 开关\n", + " moe_temporal: bool = True,\n", + " moe_fused: bool = False,\n", + " moe_tabm: bool = False,\n", + " # MoE 超参\n", + " moe_cfg_temporal: MoEConfig = None,\n", + " moe_cfg_fused: MoEConfig = None,\n", + " moe_cfg_tabm: MoEConfig = None):\n", + " super().__init__()\n", + " # A: 逐帧 3D ResNet18\n", + " self.vol_encoder = Volume3DEncoder(in_channels=vid_channels, dropout=dropout) # (B*T,512)\n", + " # A: 时序 Transformer(可 MoE)\n", + " self.temporal = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, num_layers=n_trans_layers,\n", + " d_ff=trans_ff, dropout=dropout, max_len=vid_frames,\n", + " use_moe=moe_temporal,\n", + " moe_cfg=moe_cfg_temporal or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(2048, trans_ff), 
router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # B: TabM\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model)\n", + "\n", + " # 可选:TabM 分支 MoE 头\n", + " self.moe_tabm = moe_tabm\n", + " if moe_tabm:\n", + " self.tabm_moe = MoEHead(d_model=d_model, cfg=moe_cfg_tabm or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 融合\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + "\n", + " # 可选:融合 token MoE 头\n", + " self.moe_fused = moe_fused\n", + " if moe_fused:\n", + " self.fused_moe = MoEHead(d_model=d_model, cfg=moe_cfg_fused or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " self.vid_frames = vid_frames\n", + " self.depth_n = depth_n\n", + "\n", + " # 导出融合前 512 表示(用于检索库)\n", + " def encode(self, x_video, x_vec, domain_id=None):\n", + " \"\"\"\n", + " x_video: (B, T, C=20, H=20, W=20, N=24)\n", + " x_vec: (B, 424)\n", + " \"\"\"\n", + " B, T, C, H, W, N = x_video.shape\n", + " assert N == self.depth_n, f\"N mismatch: got {N}, expect {self.depth_n}\"\n", + " # 逐帧 3D 编码: (B*T, C, D=N, H, W)\n", + " xvt = x_video.transpose([0,1,2,5,3,4]).reshape([B*T, C, N, H, W])\n", + " f_frame = self.vol_encoder(xvt) # (B*T, 512)\n", + " f_seq = f_frame.reshape([B, T, -1]) # (B, T, 512)\n", + " z_vid, _ = self.temporal(f_seq, domain_id=domain_id) # (B,T,512)\n", + " z_tabm = self.tabm(x_vec)\n", + " z_tabm = self.tabm_proj(z_tabm) # (B,512)\n", + " if self.moe_tabm:\n", + " z_tabm, _ = self.tabm_moe(z_tabm, domain_id=domain_id)\n", + " fused = 
self.fusion(z_vid, z_tabm) # (B,512)\n", + " if self.moe_fused:\n", + " fused, _ = self.fused_moe(fused, domain_id=domain_id)\n", + " return fused\n", + "\n", + " def forward(self, x_video, x_vec, domain_id=None):\n", + " fused = self.encode(x_video, x_vec, domain_id=domain_id) # (B,512)\n", + " logits = self.head(fused) # (B,4)\n", + " aux_placeholder = paddle.to_tensor(0.0, dtype='float32')\n", + " return logits, aux_placeholder\n", + "\n", + "# ====================== 指标与训练循环 ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " assert y_true.shape == y_pred.shape\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1))\n", + " fp = np.sum((yt == 0) & (yp == 1))\n", + " fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps)\n", + " rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9)\n", + " rec = tp / (tp + fn + 1e-9)\n", + " micro_f1 = 2 * prec * rec / (prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " for t in thresholds:\n", + " y_pred = (y_prob >= t).astype(np.float32)\n", + " tp = np.sum((y_true == 1) & (y_pred == 1))\n", + " fp = np.sum((y_true == 0) & (y_pred == 1))\n", + " fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " p = tp / (tp + fp + 1e-9)\n", + " r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", 
+ " order = np.argsort(recall)\n", + " recall = np.array(recall)[order]\n", + " precision = np.array(precision)[order]\n", + " return float(np.trapz(precision, recall))\n", + "\n", + "def train_one_epoch(model, loader, optimizer,\n", + " pos_weight: Optional[paddle.Tensor] = None,\n", + " clip_grad_norm: Optional[float] = None):\n", + " model.train()\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " if pos_weight is not None:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " else:\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'))\n", + " loss = cls\n", + " loss.backward()\n", + " if clip_grad_norm is not None:\n", + " nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)\n", + " optimizer.step()\n", + " optimizer.clear_grad()\n", + " total_loss += float(loss); total_batches += 1\n", + " return total_loss / max(1, total_batches)\n", + "\n", + "# ====================== 检索增强(cos / l2;k 邻居软加权;概率融合) ======================\n", + "class Retriever:\n", + " def __init__(self, sim_metric: str = 'cos', k: int = 8, alpha: float = 0.3, tau: float = 0.5):\n", + " assert sim_metric in ['cos', 'l2']\n", + " self.sim_metric = sim_metric\n", + " self.k = k\n", + " self.alpha = alpha\n", + " self.tau = tau\n", + " self.keys = None # (N,D)\n", + " self.labels = None # (N,C)\n", + "\n", + " @paddle.no_grad()\n", + " def build(self, model: nn.Layer, loader: DataLoader):\n", + " model.eval()\n", + " feats, labs = [], []\n", + " for x_vid, x_vec, y in loader:\n", + " f = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,512)\n", + " feats.append(f.numpy())\n", + " labs.append(y.numpy())\n", + " self.keys = paddle.to_tensor(np.concatenate(feats, axis=0)).astype('float32') # (N,D)\n", + " self.labels = paddle.to_tensor(np.concatenate(labs, 
axis=0)).astype('float32') # (N,C)\n", + " self.keys_norm = F.normalize(self.keys, axis=-1)\n", + "\n", + " @paddle.no_grad()\n", + " def query_and_fuse(self, model_probs: paddle.Tensor, test_feat: paddle.Tensor) -> paddle.Tensor:\n", + " assert self.keys is not None, \"build() must be called first.\"\n", + " B, D = test_feat.shape\n", + " if self.sim_metric == 'cos':\n", + " q = F.normalize(test_feat, axis=-1)\n", + " sim = paddle.matmul(q, self.keys_norm, transpose_y=True) # (B,N)\n", + " w = F.softmax(sim / self.tau, axis=-1)\n", + " else:\n", + " q2 = paddle.sum(test_feat * test_feat, axis=-1, keepdim=True) # (B,1)\n", + " k2 = paddle.sum(self.keys * self.keys, axis=-1, keepdim=True).transpose([1,0]) # (1,N)\n", + " dot = paddle.matmul(test_feat, self.keys, transpose_y=True) # (B,N)\n", + " dist2 = q2 + k2 - 2.0 * dot # (B,N)\n", + " w = F.softmax(-dist2 / self.tau, axis=-1)\n", + "\n", + " topk_val, topk_idx = paddle.topk(w, k=min(self.k, w.shape[1]), axis=-1) # (B,k)\n", + " picked_labels = paddle.gather(self.labels, topk_idx.reshape([-1]), axis=0) # (B*k, C)\n", + " C = self.labels.shape[1]\n", + " picked_labels = picked_labels.reshape([B, -1, C]) # (B,k,C)\n", + " w_norm = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9) # (B,k)\n", + " p_knn = paddle.sum(picked_labels * w_norm.unsqueeze(-1), axis=1) # (B,C)\n", + "\n", + " p_final = (1.0 - self.alpha) * model_probs + self.alpha * p_knn\n", + " return p_final.clip(1e-6, 1-1e-6)\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5,\n", + " retriever: Optional[Retriever] = None):\n", + " model.eval()\n", + " ys, ps = [], []\n", + " total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " prob = F.sigmoid(logits) # (B,C)\n", + " if retriever is not None:\n", + " feat = model.encode(x_vid.astype('float32'), x_vec.astype('float32')) # (B,512)\n", + " prob = 
retriever.query_and_fuse(prob, feat)\n", + " loss = F.binary_cross_entropy(prob, y.astype('float32'))\n", + " ys.append(y.numpy()); ps.append(prob.numpy())\n", + " total_loss += float(loss); total_batches += 1\n", + "\n", + " y_true = np.concatenate(ys, axis=0)\n", + " y_prob = np.concatenate(ps, axis=0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\n", + " \"loss\": total_loss / max(1, total_batches),\n", + " \"macro_f1\": macro_f1,\n", + " \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(),\n", + " \"micro_AP\": ap_micro\n", + " }\n", + "\n", + "# ====================== ToyDataset(T=365, N=24) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " \"\"\"\n", + " 返回:\n", + " x_video: (T=365, C=20, H=20, W=20, N=24)\n", + " x_vec: (424,)\n", + " y: (4,) 0/1\n", + " \"\"\"\n", + " def __init__(self, n: int, seed: int = 0, T: int = 365, C: int = 20, H: int = 20, W: int = 20, N: int = 24):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n\n", + " self.T, self.C, self.H, self.W, self.N = T, C, H, W, N\n", + " # (n, T, C, H, W, N)\n", + " self.video = rng.normal(size=(n, T, C, H, W, N)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + "\n", + " # 造标签:对视频先在 H/W/N 上均值,再在 T 上均值 → (n, C)\n", + " vid_hwn = self.video.mean(axis=(3, 4, 5)) # (n, T, C)\n", + " vid_avg = vid_hwn.mean(axis=1) # (n, C)\n", + "\n", + " Wv = rng.normal(size=(C, 4))\n", + " Wt = rng.normal(size=(424, 4))\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n, 4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + "\n", + " def __getitem__(self, idx: int):\n", + " return self.video[idx], self.vec[idx], self.y[idx]\n", + " def __len__(self):\n", + " return self.n\n", + 
"\n", + "# ====================== 训练入口 ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + "\n", + " # 数据\n", + " T, C, H, W, N = 365, 20, 20, 20, 24\n", + " train_ds = ToyTwoModalDataset(n=32, seed=42, T=T, C=C, H=H, W=W, N=N)\n", + " val_ds = ToyTwoModalDataset(n=16, seed=233, T=T, C=C, H=H, W=W, N=N)\n", + "\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)), # (B,T,C,H,W,N)\n", + " paddle.to_tensor(np.stack(vecs, 0)), # (B,424)\n", + " paddle.to_tensor(np.stack(ys, 0))) # (B,4)\n", + "\n", + " # T=365 + 3D 卷积较吃内存,示例用小 batch\n", + " train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32')) # (4,)\n", + "\n", + " # 模型\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=C, vid_h=H, vid_w=W, vid_frames=T, depth_n=N,\n", + " vec_dim=424,\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1,\n", + " num_labels=4,\n", + " moe_temporal=True, # 推荐开启(FFN 位置 MoE)\n", + " moe_fused=False,\n", + " moe_tabm=False\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 2+1):\n", + " train_loss = train_one_epoch(model, train_loader, optimizer,\n", + " pos_weight=pos_weight, clip_grad_norm=1.0)\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5, retriever=None)\n", + " print(f\"[Epoch {ep:02d}] train_loss={train_loss:.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", 
+ " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + "\n", + " if best is not None:\n", + " model.set_state_dict(best)\n", + " print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n", + "\n", + " # === 构建检索库(用训练集) ===\n", + " retr = Retriever(sim_metric='cos', k=8, alpha=0.3, tau=0.5) # 可改 'l2'\n", + " retr.build(model, DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=collate_fn))\n", + "\n", + " # === 测试时启用检索增强 ===\n", + " val_metrics_knn = evaluate(model, val_loader, threshold=0.5, retriever=retr)\n", + " print(f\"[RkNN] val_loss={val_metrics_knn['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics_knn['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics_knn['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics_knn['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics_knn['micro_AP']:.4f}\")\n" + ], + "metadata": { + "id": "RwuikckFkM1O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "# 选中第 0 张 GPU;如有多卡改成 'gpu:1' 等\n", + "# paddle.set_device('gpu:0')\n", + "\n", + "# ====================== 工具:正弦位置编码 ======================\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 4096):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + 
" div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + " def forward(self, x): # (B,T,D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ====================== TabM(占位,可换你的实现) ======================\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden),\n", + " nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + " def forward(self, x_num: paddle.Tensor):\n", + " return self.net(x_num)\n", + "\n", + "# ====================== 3D ResNet-18 体数据特征抽取 ======================\n", + "class BasicBlock3D(nn.Layer):\n", + " expansion = 1\n", + " def __init__(self, in_planes, planes, stride=(1,1,1), downsample=None):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv3D(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(planes)\n", + " self.relu = nn.ReLU()\n", + " self.conv2 = nn.Conv3D(planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)\n", + " self.bn2 = nn.BatchNorm3D(planes)\n", + " self.downsample = downsample\n", + " def forward(self, x):\n", + " identity = x\n", + " out = self.relu(self.bn1(self.conv1(x)))\n", + " out = self.bn2(self.conv2(out))\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + " out = self.relu(out + identity)\n", + " return out\n", + "\n", + "class ResNet3D(nn.Layer):\n", + " def __init__(self, block, layers, in_channels=20, base_width=64):\n", + " super().__init__()\n", + " self.in_planes = base_width\n", + " 
self.conv1 = nn.Conv3D(in_channels, self.in_planes,\n", + " kernel_size=(3,7,7), stride=(1,2,2),\n", + " padding=(1,3,3), bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(self.in_planes)\n", + " self.relu = nn.ReLU()\n", + " self.maxpool = nn.MaxPool3D(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))\n", + " self.layer1 = self._make_layer(block, base_width, layers[0], stride=(1,1,1))\n", + " self.layer2 = self._make_layer(block, base_width*2, layers[1], stride=(2,2,2))\n", + " self.layer3 = self._make_layer(block, base_width*4, layers[2], stride=(2,2,2))\n", + " self.layer4 = self._make_layer(block, base_width*8, layers[3], stride=(2,2,2))\n", + " self.out_dim = base_width*8 # 512\n", + " self.pool = nn.AdaptiveAvgPool3D(output_size=1)\n", + " def _make_layer(self, block, planes, blocks, stride=(1,1,1)):\n", + " downsample = None\n", + " if stride != (1,1,1) or self.in_planes != planes * block.expansion:\n", + " downsample = nn.Sequential(\n", + " nn.Conv3D(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias_attr=False),\n", + " nn.BatchNorm3D(planes * block.expansion),\n", + " )\n", + " layers = [block(self.in_planes, planes, stride=stride, downsample=downsample)]\n", + " self.in_planes = planes * block.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(block(self.in_planes, planes))\n", + " return nn.Sequential(*layers)\n", + " def forward(self, x): # (B, C, D, H, W)\n", + " x = self.relu(self.bn1(self.conv1(x)))\n", + " x = self.maxpool(x)\n", + " x = self.layer1(x)\n", + " x = self.layer2(x)\n", + " x = self.layer3(x)\n", + " x = self.layer4(x)\n", + " x = self.pool(x) # (B, 512, 1,1,1)\n", + " x = paddle.flatten(x, 1) # (B, 512)\n", + " return x\n", + "\n", + "class Volume3DEncoder(nn.Layer):\n", + " def __init__(self, in_channels: int = 20, base: int = 64, dropout: float = 0.0):\n", + " super().__init__()\n", + " self.backbone = ResNet3D(BasicBlock3D, layers=[2,2,2,2], in_channels=in_channels, 
base_width=base)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.out_dim = self.backbone.out_dim # 512\n", + " def forward(self, x): # (B, C, D, H, W)\n", + " x = self.backbone(x)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "# ====================== MoE(Top-k;gather_nd 选择专家) ======================\n", + "class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1, act='relu'):\n", + " super().__init__()\n", + " Act = getattr(F, act) if isinstance(act, str) else act\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.act = Act\n", + " def forward(self, x):\n", + " return self.fc2(self.drop(self.act(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self,\n", + " n_experts=8, top_k=1, d_ff=2048, dropout=0.1,\n", + " router_temp=0.5, balance_loss_w=0.005, entropy_reg_w=-0.005,\n", + " diversity_w=1e-3, sticky_w=0.0, sup_router_w=0.0, use_gumbel=True):\n", + " self.n_experts = n_experts; self.top_k = top_k; self.d_ff = d_ff; self.dropout = dropout\n", + " self.router_temp = router_temp; self.balance_loss_w = balance_loss_w\n", + " self.entropy_reg_w = entropy_reg_w; self.diversity_w = diversity_w\n", + " self.sticky_w = sticky_w; self.sup_router_w = sup_router_w; self.use_gumbel = use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(cfg.dropout)\n", + " def _router_probs(self, logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u)); logits = logits + g\n", + " return 
F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + " def forward(self, x, domain_id=None):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3:\n", + " B, T, D = orig_shape; X = x.reshape([B*T, D])\n", + " else:\n", + " X = x\n", + " N, D = X.shape\n", + " logits = self.router(X); probs = self._router_probs(logits)\n", + " topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1)\n", + " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list = []\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64')\n", + " idx_nd = paddle.stack([arangeN, idx_i], axis=1)\n", + " picked_i = paddle.gather_nd(all_out, idx_nd)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1)\n", + " Y = self.drop(Y); Y = self.ln(Y + X)\n", + " aux = 0.0\n", + " if self.cfg.balance_loss_w > 0:\n", + " mean_prob = probs.mean(axis=0)\n", + " target = paddle.full_like(mean_prob, 1.0 / self.cfg.n_experts)\n", + " aux = aux + self.cfg.balance_loss_w * F.mse_loss(mean_prob, target)\n", + " if self.cfg.entropy_reg_w != 0.0:\n", + " ent = -paddle.sum(probs * (paddle.log(probs + 1e-9)), axis=1).mean()\n", + " aux = aux + self.cfg.entropy_reg_w * ent\n", + " if (domain_id is not None) and (self.cfg.sup_router_w > 0):\n", + " dom = domain_id.reshape([-1])[:N] % self.cfg.n_experts\n", + " aux = aux + self.cfg.sup_router_w * F.cross_entropy(logits, dom)\n", + " if self.cfg.diversity_w > 0 and self.cfg.n_experts > 1:\n", + " chosen = F.one_hot(topk_idx[:, 0], num_classes=self.cfg.n_experts).astype('float32')\n", + " denom = chosen.sum(axis=0).clip(min=1.0).unsqueeze(-1)\n", + " means = (all_out * chosen.unsqueeze(-1)).sum(axis=0) / denom\n", + " sims = []\n", + " for i in 
range(self.cfg.n_experts):\n", + " for j in range(i+1, self.cfg.n_experts):\n", + " si = F.normalize(means[i:i+1], axis=-1)\n", + " sj = F.normalize(means[j:j+1], axis=-1)\n", + " sims.append((si*sj).sum())\n", + " if sims:\n", + " aux = aux + self.cfg.diversity_w * paddle.stack(sims).mean()\n", + " if len(orig_shape) == 3:\n", + " Y = Y.reshape([B, T, D])\n", + " return Y, aux\n", + "\n", + "class MoEHead(nn.Layer):\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " def forward(self, tok, domain_id=None):\n", + " y, aux = self.moe(tok.unsqueeze(1), domain_id=domain_id)\n", + " return y.squeeze(1), aux\n", + "\n", + "# ====================== Self-Attention Transformer(可 MoE) ======================\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.use_moe = use_moe\n", + " self.self_attn = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.ln1 = nn.LayerNorm(d_model); self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model),\n", + " ); self.do2 = nn.Dropout(dropout)\n", + " def forward(self, x, domain_id=None): # (B,T,D)\n", + " h = self.ln1(x)\n", + " h = paddle.transpose(h, [1,0,2])\n", + " sa = self.self_attn(h, h, h)\n", + " sa = paddle.transpose(sa, [1,0,2])\n", + " x = x + self.do1(sa)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id)\n", + " else:\n", + " x = x + self.do2(self.ffn(x))\n", + " return x, aux\n", + "\n", + "class 
TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, num_layers=2, d_ff=1024, dropout=0.1,\n", + " max_len=4096, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_len)\n", + " self.layers = nn.LayerList([\n", + " TransformerEncoderLayerMoE(d_model, nhead, d_ff, dropout, use_moe=use_moe, moe_cfg=moe_cfg)\n", + " for _ in range(num_layers)\n", + " ])\n", + " def forward(self, x, domain_id=None):\n", + " x = self.pos(x); aux_total = 0.0\n", + " for layer in self.layers:\n", + " x, aux = layer(x, domain_id=domain_id); aux_total += aux\n", + " return x, aux_total\n", + "\n", + "# ====================== AFNO(1D) + MoE FFN ======================\n", + "class AFNO1DLayer(nn.Layer):\n", + " \"\"\"\n", + " 自适应傅里叶算子(时间 1D 版):\n", + " - 对 (B,T,D) 沿 T 做 rFFT → (B,D,F)\n", + " - 仅保留前 K=modes 个频率,对每个频率在“通道组内”做两层复线性(W1,W2)+ GELU + Softshrink\n", + " - 把频谱其余部分置零 → irFFT → 残差 + Dropout + (可选 LN)\n", + " \"\"\"\n", + " def __init__(self, d_model: int, modes: int = 32, num_blocks: int = 8,\n", + " shrink: float = 0.01, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % num_blocks == 0, \"d_model must be divisible by num_blocks\"\n", + " self.d_model = d_model\n", + " self.modes = modes\n", + " self.num_blocks = num_blocks\n", + " self.block = d_model // num_blocks\n", + " self.shrink = shrink\n", + " # 复权重拆成实/虚:形状 (G, Cb, Cb)\n", + " scale = 1.0 / math.sqrt(self.block)\n", + " def param():\n", + " return nn.initializer.Uniform(-scale, scale)\n", + " self.w1r = self.create_parameter([num_blocks, self.block, self.block], default_initializer=param())\n", + " self.w1i = self.create_parameter([num_blocks, self.block, self.block], default_initializer=param())\n", + " self.w2r = self.create_parameter([num_blocks, self.block, self.block], default_initializer=param())\n", + " self.w2i = self.create_parameter([num_blocks, self.block, 
self.block], default_initializer=param())\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + "\n", + " def _complex_linear(self, xr, xi, Wr, Wi):\n", + " # xr, xi: (B, G, K, Cb); Wr/Wi: (G, Cb, Cb)\n", + " # (a+ib)*(Wr+iWi) = (a@Wr - b@Wi) + i(a@Wi + b@Wr)\n", + " out_r = paddle.matmul(xr, Wr) - paddle.matmul(xi, Wi)\n", + " out_i = paddle.matmul(xr, Wi) + paddle.matmul(xi, Wr)\n", + " return out_r, out_i\n", + "\n", + " def forward(self, x): # x: (B,T,D)\n", + " B, T, D = x.shape\n", + " Kmax = T // 2 + 1\n", + " K = min(self.modes, Kmax)\n", + "\n", + " h = self.ln(x) # PreNorm\n", + " h_td = paddle.transpose(h, [0, 2, 1]) # (B,D,T)\n", + " h_ft = paddle.fft.rfft(h_td) # (B,D,F) complex64\n", + "\n", + " # reshape 通道为 G 组: (B,G,Cb,F)\n", + " h_ft = h_ft.reshape([B, self.num_blocks, self.block, Kmax])\n", + " # 仅前 K 频率: (B,G,Cb,K) → 交换到 (B,G,K,Cb) 方便 matmul\n", + " xk = h_ft[:, :, :, :K].transpose([0,1,3,2])\n", + " xr, xi = paddle.real(xk), paddle.imag(xk) # (B,G,K,Cb)\n", + "\n", + " # 组内两层复线性 + GELU + Softshrink\n", + " yr, yi = self._complex_linear(xr, xi, self.w1r, self.w1i)\n", + " yr = F.gelu(yr); yi = F.gelu(yi)\n", + " # Softshrink(稀疏化)\n", + " # yr = F.softshrink(yr, lambd=self.shrink); yi = F.softshrink(yi, lambd=self.shrink)\n", + " yr = F.softshrink(yr, threshold=self.shrink)\n", + " yi = F.softshrink(yi, threshold=self.shrink)\n", + " yr, yi = self._complex_linear(yr, yi, self.w2r, self.w2i) # (B,G,K,Cb)\n", + "\n", + "\n", + "\n", + "\n", + " # 放回谱: (B,G,K,Cb) → (B,G,Cb,K) → (B,D,K)\n", + " yk = paddle.complex(yr, yi).transpose([0,1,3,2]).reshape([B, D, K])\n", + " out_ft = paddle.zeros([B, D, Kmax], dtype='complex64')\n", + " out_ft[:, :, :K] = yk\n", + "\n", + " # 反变换 & 残差\n", + " out_td = paddle.fft.irfft(out_ft, n=T) # (B,D,T)\n", + " out = paddle.transpose(out_td, [0, 2, 1]) # (B,T,D)\n", + " out = self.drop(out)\n", + " return x + out\n", + "\n", + "class AFNOTransformerFlexible(nn.Layer):\n", + " \"\"\"\n", 
+ " 堆叠若干 AFNO1DLayer;随后接 MoE FFN(与 Self-Attn 分支同构)\n", + " \"\"\"\n", + " def __init__(self, d_model=512, num_layers=2, modes=32, dropout=0.1,\n", + " d_ff=1024, use_moe: bool = True, moe_cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.layers = nn.LayerList([AFNO1DLayer(d_model, modes=modes, num_blocks=8, shrink=0.01, dropout=dropout)\n", + " for _ in range(num_layers)])\n", + " self.use_moe = use_moe\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(\n", + " nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model),\n", + " )\n", + " self.do = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x, domain_id=None): # (B,T,D)\n", + " for layer in self.layers:\n", + " x = layer(x)\n", + " aux = 0.0\n", + " if self.use_moe:\n", + " x, aux = self.moe(x, domain_id=domain_id)\n", + " else:\n", + " x = x + self.do(self.ffn(x))\n", + " return x, aux\n", + "\n", + "# ====================== Cross-Attention 融合 ======================\n", + "class MultiHeadCrossAttention(nn.Layer):\n", + " def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):\n", + " super().__init__()\n", + " assert d_model % nhead == 0\n", + " self.d_head = d_model // nhead; self.nhead = nhead\n", + " self.Wq = nn.Linear(d_model, d_model); self.Wk = nn.Linear(d_model, d_model); self.Wv = nn.Linear(d_model, d_model)\n", + " self.proj = nn.Linear(d_model, d_model); self.drop = nn.Dropout(dropout); self.ln = nn.LayerNorm(d_model)\n", + " def forward(self, q, kv):\n", + " B, Nq, D = q.shape\n", + " def split(t): return t.reshape([B, -1, self.nhead, self.d_head]).transpose([0,2,1,3])\n", + " qh = split(self.Wq(q)); kh = split(self.Wk(kv)); vh = split(self.Wv(kv))\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(self.d_head)\n", + " attn = F.softmax(scores, axis=-1)\n", + " ctx = paddle.matmul(attn, 
vh).transpose([0,2,1,3]).reshape([B, Nq, D])\n", + " out = self.drop(self.proj(ctx))\n", + " return self.ln(out + q)\n", + "\n", + "class BiModalCrossFusion(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, dropout=0.1, fuse_hidden=512):\n", + " super().__init__()\n", + " self.ca_v_from_t = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(nn.Linear(2*d_model, fuse_hidden), nn.ReLU(), nn.Dropout(dropout))\n", + " self.out_dim = fuse_hidden\n", + " def forward(self, video_seq, tabm_tok):\n", + " v_tok = video_seq.mean(axis=1, keepdim=True)\n", + " t_tok = tabm_tok.unsqueeze(1)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq)\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1).squeeze(1)\n", + " return self.fuse(fused)\n", + "\n", + "# ====================== 总模型:Self-Attn + AFNO 并行 ======================\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self,\n", + " # 视频模态\n", + " vid_channels=20, vid_h=20, vid_w=20, vid_frames=365, depth_n=24,\n", + " # 结构化模态\n", + " vec_dim=424,\n", + " # 维度与结构\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " # MoE 开关\n", + " moe_temporal_attn: bool = True,\n", + " moe_temporal_afno: bool = True,\n", + " moe_fused: bool = False,\n", + " moe_tabm: bool = False,\n", + " # AFNO 频率数\n", + " afno_modes: int = 32,\n", + " # MoE 超参\n", + " moe_cfg_temporal_attn: MoEConfig = None,\n", + " moe_cfg_temporal_afno: MoEConfig = None,\n", + " moe_cfg_fused: MoEConfig = None,\n", + " moe_cfg_tabm: MoEConfig = None):\n", + " super().__init__()\n", + " # 逐帧 3D ResNet18\n", + " self.vol_encoder = Volume3DEncoder(in_channels=vid_channels, dropout=dropout)\n", + " # Self-Attention Transformer\n", + " self.trans_attn = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, 
num_layers=n_trans_layers, d_ff=trans_ff, dropout=dropout,\n", + " max_len=vid_frames, use_moe=moe_temporal_attn,\n", + " moe_cfg=moe_cfg_temporal_attn or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(2048, trans_ff), router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # AFNO Transformer(1D)\n", + " self.trans_afno = AFNOTransformerFlexible(\n", + " d_model=d_model, num_layers=n_trans_layers, modes=afno_modes, dropout=dropout,\n", + " d_ff=trans_ff, use_moe=moe_temporal_afno,\n", + " moe_cfg=moe_cfg_temporal_afno or MoEConfig(\n", + " n_experts=8, top_k=1, d_ff=max(2048, trans_ff), router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " )\n", + " )\n", + " # 两路拼接后投回 d_model\n", + " self.video_merge = nn.Linear(2*d_model, d_model)\n", + "\n", + " # TabM\n", + " self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)\n", + " self.tabm_proj = nn.Linear(tabm_hidden, d_model)\n", + " self.moe_tabm = moe_tabm\n", + " if moe_tabm:\n", + " self.tabm_moe = MoEHead(d_model=d_model, cfg=moe_cfg_tabm or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 融合\n", + " self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)\n", + " self.moe_fused = moe_fused\n", + " if moe_fused:\n", + " self.fused_moe = MoEHead(d_model=d_model, cfg=moe_cfg_fused or MoEConfig(\n", + " n_experts=6, top_k=1, d_ff=1024, router_temp=0.5,\n", + " balance_loss_w=0.005, entropy_reg_w=-0.005, diversity_w=1e-3\n", + " ))\n", + "\n", + " # 分类头\n", + " self.head = nn.Linear(self.fusion.out_dim, num_labels)\n", + "\n", + " self.vid_frames = vid_frames; self.depth_n = depth_n\n", + "\n", + " # 导出融合前 512 表示(用于检索库)\n", + " def encode(self, x_video, x_vec, domain_id=None):\n", + " \"\"\"\n", + " x_video: (B,T,C,H,W,N) —— N 
为体深度(24)\n", + " \"\"\"\n", + " B, T, C, H, W, N = x_video.shape\n", + " assert N == self.depth_n, f\"N mismatch: {N} vs {self.depth_n}\"\n", + " xvt = x_video.transpose([0,1,2,5,3,4]).reshape([B*T, C, N, H, W])\n", + " f_frame = self.vol_encoder(xvt) # (B*T,512)\n", + " seq = f_frame.reshape([B, T, -1]) # (B,T,512)\n", + "\n", + " z_attn, _ = self.trans_attn(seq, domain_id=domain_id) # (B,T,512)\n", + " z_afno, _ = self.trans_afno(seq, domain_id=domain_id) # (B,T,512)\n", + " z_vid = self.video_merge(paddle.concat([z_attn, z_afno], axis=-1)) # (B,T,512)\n", + "\n", + " z_tabm = self.tabm(x_vec); z_tabm = self.tabm_proj(z_tabm) # (B,512)\n", + " if self.moe_tabm:\n", + " z_tabm, _ = self.tabm_moe(z_tabm, domain_id=domain_id)\n", + "\n", + " fused = self.fusion(z_vid, z_tabm) # (B,512)\n", + " if self.moe_fused:\n", + " fused, _ = self.fused_moe(fused, domain_id=domain_id)\n", + " return fused\n", + "\n", + " def forward(self, x_video, x_vec, domain_id=None):\n", + " fused = self.encode(x_video, x_vec, domain_id=domain_id)\n", + " logits = self.head(fused) # (B,4)\n", + " return logits, paddle.to_tensor(0.0, dtype='float32')\n", + "\n", + "# ====================== 简洁指标(可替换为你之前的“全量指标”版本) ======================\n", + "def f1_per_class(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-9) -> Tuple[np.ndarray, float, float]:\n", + " N, C = y_true.shape\n", + " f1_c = np.zeros(C, dtype=np.float32)\n", + " for c in range(C):\n", + " yt, yp = y_true[:, c], y_pred[:, c]\n", + " tp = np.sum((yt == 1) & (yp == 1)); fp = np.sum((yt == 0) & (yp == 1)); fn = np.sum((yt == 1) & (yp == 0))\n", + " prec = tp / (tp + fp + eps); rec = tp / (tp + fn + eps)\n", + " f1_c[c] = 2 * prec * rec / (prec + rec + eps)\n", + " macro_f1 = float(np.mean(f1_c))\n", + " tp = np.sum((y_true == 1) & (y_pred == 1)); fp = np.sum((y_true == 0) & (y_pred == 1)); fn = np.sum((y_true == 1) & (y_pred == 0))\n", + " prec = tp / (tp + fp + 1e-9); rec = tp / (tp + fn + 1e-9); micro_f1 = 2 * prec * rec / 
(prec + rec + 1e-9)\n", + " return f1_c, macro_f1, float(micro_f1)\n", + "\n", + "def average_precision_micro(y_true: np.ndarray, y_prob: np.ndarray, num_thresholds: int = 101) -> float:\n", + " thresholds = np.linspace(0.0, 1.0, num_thresholds)\n", + " precision, recall = [], []\n", + " yt = y_true.reshape(-1); ps = y_prob.reshape(-1)\n", + " for t in thresholds:\n", + " yp = (ps >= t).astype(np.float32)\n", + " tp = np.sum((yt == 1) & (yp == 1)); fp = np.sum((yt == 0) & (yp == 1)); fn = np.sum((yt == 1) & (yp == 0))\n", + " p = tp / (tp + fp + 1e-9); r = tp / (tp + fn + 1e-9)\n", + " precision.append(p); recall.append(r)\n", + " order = np.argsort(recall)\n", + " return float(np.trapz(np.array(precision)[order], np.array(recall)[order]))\n", + "\n", + "@paddle.no_grad()\n", + "def evaluate(model, loader, threshold: float = 0.5, retriever=None):\n", + " model.eval()\n", + " ys, ps, total_loss, total_batches = [], [], 0.0, 0\n", + " for x_vid, x_vec, y in loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " prob = F.sigmoid(logits)\n", + " if retriever is not None:\n", + " feat = model.encode(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " prob = retriever.query_and_fuse(prob, feat)\n", + " loss = F.binary_cross_entropy(prob, y.astype('float32'))\n", + " ys.append(y.numpy()); ps.append(prob.numpy())\n", + " total_loss += float(loss); total_batches += 1\n", + " y_true = np.concatenate(ys, 0); y_prob = np.concatenate(ps, 0)\n", + " y_pred = (y_prob >= threshold).astype(np.float32)\n", + " per_f1, macro_f1, micro_f1 = f1_per_class(y_true, y_pred)\n", + " ap_micro = average_precision_micro(y_true, y_prob)\n", + " return {\"loss\": total_loss/max(1,total_batches), \"macro_f1\": macro_f1, \"micro_f1\": micro_f1,\n", + " \"per_class_f1\": per_f1.tolist(), \"micro_AP\": ap_micro}\n", + "\n", + "# ====================== 检索增强(cos / l2;k 邻居软加权;概率融合) ======================\n", + "class Retriever:\n", + " def __init__(self, 
sim_metric: str = 'cos', k: int = 8, alpha: float = 0.3, tau: float = 0.5):\n", + " assert sim_metric in ['cos', 'l2']\n", + " self.sim_metric = sim_metric; self.k = k; self.alpha = alpha; self.tau = tau\n", + " self.keys = None; self.labels = None\n", + " @paddle.no_grad()\n", + " def build(self, model: nn.Layer, loader: DataLoader):\n", + " model.eval()\n", + " feats, labs = [], []\n", + " for x_vid, x_vec, y in loader:\n", + " f = model.encode(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " feats.append(f.numpy()); labs.append(y.numpy())\n", + " self.keys = paddle.to_tensor(np.concatenate(feats, 0)).astype('float32')\n", + " self.labels = paddle.to_tensor(np.concatenate(labs, 0)).astype('float32')\n", + " self.keys_norm = F.normalize(self.keys, axis=-1)\n", + " @paddle.no_grad()\n", + " def query_and_fuse(self, model_probs: paddle.Tensor, test_feat: paddle.Tensor) -> paddle.Tensor:\n", + " B, D = test_feat.shape\n", + " if self.sim_metric == 'cos':\n", + " q = F.normalize(test_feat, axis=-1)\n", + " sim = paddle.matmul(q, self.keys_norm, transpose_y=True)\n", + " w = F.softmax(sim / self.tau, axis=-1)\n", + " else:\n", + " q2 = paddle.sum(test_feat * test_feat, axis=-1, keepdim=True)\n", + " k2 = paddle.sum(self.keys * self.keys, axis=-1, keepdim=True).transpose([1,0])\n", + " dot = paddle.matmul(test_feat, self.keys, transpose_y=True)\n", + " dist2 = q2 + k2 - 2.0 * dot\n", + " w = F.softmax(-dist2 / self.tau, axis=-1)\n", + " topk_val, topk_idx = paddle.topk(w, k=min(self.k, w.shape[1]), axis=-1)\n", + " picked_labels = paddle.gather(self.labels, topk_idx.reshape([-1]), axis=0)\n", + " C = self.labels.shape[1]\n", + " picked_labels = picked_labels.reshape([B, -1, C])\n", + " w_norm = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9)\n", + " p_knn = paddle.sum(picked_labels * w_norm.unsqueeze(-1), axis=1)\n", + " p_final = (1.0 - self.alpha) * model_probs + self.alpha * p_knn\n", + " return p_final.clip(1e-6, 1-1e-6)\n", + "\n", + "# 
====================== ToyDataset(T=365, N=24) ======================\n", + "class ToyTwoModalDataset(Dataset):\n", + " def __init__(self, n: int, seed: int = 0, T: int = 365, C: int = 20, H: int = 20, W: int = 20, N: int = 24):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.n = n; self.T=T; self.C=C; self.H=H; self.W=W; self.N=N\n", + " self.video = rng.normal(size=(n, T, C, H, W, N)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + " vid_hwn = self.video.mean(axis=(3,4,5)) # (n,T,C)\n", + " vid_avg = vid_hwn.mean(axis=1) # (n,C)\n", + " Wv = rng.normal(size=(C,4)); Wt = rng.normal(size=(424,4))\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n,4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + " def __getitem__(self, idx: int):\n", + " return self.video[idx], self.vec[idx], self.y[idx]\n", + " def __len__(self): return self.n\n", + "\n", + "# ====================== 训练入口 ======================\n", + "if __name__ == \"__main__\":\n", + " paddle.seed(2025)\n", + " # 数据\n", + " T, C, H, W, N = 365, 20, 20, 20, 24\n", + " train_ds = ToyTwoModalDataset(n=32, seed=42, T=T, C=C, H=H, W=W, N=N)\n", + " val_ds = ToyTwoModalDataset(n=16, seed=233, T=T, C=C, H=H, W=W, N=N)\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)),\n", + " paddle.to_tensor(np.stack(vecs, 0)),\n", + " paddle.to_tensor(np.stack(ys, 0)))\n", + " train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, drop_last=False, collate_fn=collate_fn)\n", + " val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)\n", + "\n", + " # 类别不平衡权重(可选)\n", + " y_train = np.stack([y for _, _, y in train_ds], 0)\n", + " pos_ratio = np.clip(y_train.mean(axis=0), 1e-3, 1-1e-3)\n", + " pos_weight = 
paddle.to_tensor(((1-pos_ratio)/pos_ratio).astype('float32'))\n", + "\n", + " # 模型:Self-Attn + AFNO 两路,并行 + MoE FFN\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=C, vid_h=H, vid_w=W, vid_frames=T, depth_n=N,\n", + " vec_dim=424,\n", + " d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,\n", + " tabm_hidden=512, dropout=0.1, num_labels=4,\n", + " moe_temporal_attn=True, moe_temporal_afno=True,\n", + " moe_fused=False, moe_tabm=False,\n", + " afno_modes=32\n", + " )\n", + " optimizer = paddle.optimizer.Adam(learning_rate=3e-4, parameters=model.parameters())\n", + "\n", + " # 训练(演示用:小 epoch)\n", + " best_macro_f1, best = -1.0, None\n", + " for ep in range(1, 2+1):\n", + " model.train(); total_loss, total_batches = 0.0, 0\n", + " for x_vid, x_vec, y in train_loader:\n", + " logits, _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " cls = F.binary_cross_entropy_with_logits(logits, y.astype('float32'), pos_weight=pos_weight)\n", + " cls.backward()\n", + " optimizer.step(); optimizer.clear_grad()\n", + " total_loss += float(cls); total_batches += 1\n", + " val_metrics = evaluate(model, val_loader, threshold=0.5, retriever=None)\n", + " print(f\"[Epoch {ep:02d}] train_loss={total_loss/max(1,total_batches):.4f} | \"\n", + " f\"val_loss={val_metrics['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics['micro_AP']:.4f}\")\n", + " if val_metrics[\"macro_f1\"] > best_macro_f1:\n", + " best_macro_f1 = val_metrics[\"macro_f1\"]\n", + " best = {k: v.clone() for k, v in model.state_dict().items()}\n", + " if best is not None:\n", + " model.set_state_dict(best); print(f\"Loaded best state with macro_f1={best_macro_f1:.4f}\")\n", + "\n", + " # 检索库 + 检索增强评估\n", + " retr = Retriever(sim_metric='cos', k=8, alpha=0.3, tau=0.5)\n", + " retr.build(model, DataLoader(train_ds, batch_size=1, 
shuffle=False, collate_fn=collate_fn))\n", + " val_metrics_knn = evaluate(model, val_loader, threshold=0.5, retriever=retr)\n", + " print(f\"[RkNN] val_loss={val_metrics_knn['loss']:.4f} | \"\n", + " f\"macro_f1={val_metrics_knn['macro_f1']:.4f} | \"\n", + " f\"micro_f1={val_metrics_knn['micro_f1']:.4f} | \"\n", + " f\"per_class_f1={val_metrics_knn['per_class_f1']} | \"\n", + " f\"micro_AP={val_metrics_knn['micro_AP']:.4f}\")\n" + ], + "metadata": { + "id": "KzXknA11mNYw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math, os\n", + "from typing import Optional, Tuple, List\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from paddle.io import Dataset, DataLoader\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# ============ 基本设置 ============\n", + "os.makedirs(\"viz_out\", exist_ok=True)\n", + "\n", + "# ============ 工具:正弦位置编码 ============\n", + "class SinusoidalPositionalEncoding(nn.Layer):\n", + " def __init__(self, d_model: int, max_len: int = 4096):\n", + " super().__init__()\n", + " pe = np.zeros((max_len, d_model), dtype=\"float32\")\n", + " position = np.arange(0, max_len, dtype=\"float32\")[:, None]\n", + " div_term = np.exp(np.arange(0, d_model, 2, dtype=\"float32\") * (-math.log(10000.0) / d_model))\n", + " pe[:, 0::2] = np.sin(position * div_term)\n", + " pe[:, 1::2] = np.cos(position * div_term)\n", + " self.register_buffer(\"pe\", paddle.to_tensor(pe), persistable=False)\n", + " def forward(self, x): # (B,T,D)\n", + " T = x.shape[1]\n", + " return x + self.pe[:T, :]\n", + "\n", + "# ============ TabM(占位,可替换为你的实现) ============\n", + "class TabMFeatureExtractor(nn.Layer):\n", + " def __init__(self, num_features: int, d_hidden: int = 512, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(num_features, d_hidden), nn.ReLU(), 
nn.Dropout(dropout),\n", + " nn.Linear(d_hidden, d_hidden), nn.ReLU(),\n", + " )\n", + " self.d_hidden = d_hidden\n", + " def forward(self, x_num: paddle.Tensor):\n", + " return self.net(x_num)\n", + "\n", + "# ============ 3D ResNet18 ============\n", + "class BasicBlock3D(nn.Layer):\n", + " expansion = 1\n", + " def __init__(self, in_planes, planes, stride=(1,1,1), downsample=None):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv3D(in_planes, planes, 3, stride=stride, padding=1, bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(planes)\n", + " self.relu = nn.ReLU()\n", + " self.conv2 = nn.Conv3D(planes, planes, 3, stride=1, padding=1, bias_attr=False)\n", + " self.bn2 = nn.BatchNorm3D(planes)\n", + " self.downsample = downsample\n", + " def forward(self, x):\n", + " identity = x\n", + " out = self.relu(self.bn1(self.conv1(x)))\n", + " out = self.bn2(self.conv2(out))\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + " out = self.relu(out + identity)\n", + " return out\n", + "\n", + "class ResNet3D(nn.Layer):\n", + " def __init__(self, block, layers, in_channels=20, base_width=64):\n", + " super().__init__()\n", + " self.in_planes = base_width\n", + " self.conv1 = nn.Conv3D(in_channels, self.in_planes, kernel_size=(3,7,7),\n", + " stride=(1,2,2), padding=(1,3,3), bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(self.in_planes)\n", + " self.relu = nn.ReLU()\n", + " self.maxpool = nn.MaxPool3D(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))\n", + " self.layer1 = self._make_layer(block, base_width, layers[0], stride=(1,1,1))\n", + " self.layer2 = self._make_layer(block, base_width*2, layers[1], stride=(2,2,2))\n", + " self.layer3 = self._make_layer(block, base_width*4, layers[2], stride=(2,2,2))\n", + " self.layer4 = self._make_layer(block, base_width*8, layers[3], stride=(2,2,2))\n", + " self.out_dim = base_width*8 # 512\n", + " self.pool = nn.AdaptiveAvgPool3D(output_size=1)\n", + " def _make_layer(self, block, 
planes, blocks, stride=(1,1,1)):\n", + " downsample = None\n", + " if stride != (1,1,1) or self.in_planes != planes * block.expansion:\n", + " downsample = nn.Sequential(\n", + " nn.Conv3D(self.in_planes, planes * block.expansion, 1, stride=stride, bias_attr=False),\n", + " nn.BatchNorm3D(planes * block.expansion),\n", + " )\n", + " layers = [block(self.in_planes, planes, stride=stride, downsample=downsample)]\n", + " self.in_planes = planes * block.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(block(self.in_planes, planes))\n", + " return nn.Sequential(*layers)\n", + " def forward(self, x): # (B, C, D, H, W)\n", + " x = self.relu(self.bn1(self.conv1(x)))\n", + " x = self.maxpool(x)\n", + " x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)\n", + " x = self.pool(x) # (B, 512, 1,1,1)\n", + " x = paddle.flatten(x, 1) # (B, 512)\n", + " return x\n", + "\n", + "class Volume3DEncoder(nn.Layer):\n", + " \"\"\"\n", + " 附带特征/梯度捕获,用于 3D Grad-CAM:\n", + " - forward_post_hook 里:先缓存 feat\n", + " - 若 out 可梯度,则注册 backward hook;否则跳过(避免 stop_gradient 报错)\n", + " \"\"\"\n", + " def __init__(self, in_channels: int = 20, base: int = 64, dropout: float = 0.0):\n", + " super().__init__()\n", + " self.backbone = ResNet3D(BasicBlock3D, [2,2,2,2], in_channels=in_channels, base_width=base)\n", + " self.drop = nn.Dropout(dropout)\n", + " self.out_dim = self.backbone.out_dim # 512\n", + " self._feat = None\n", + " self._grad = None\n", + "\n", + " def _save_feat_grad(layer, inp, out):\n", + " self._feat = out # (B, 512, D',H',W')\n", + " if getattr(out, \"stop_gradient\", False):\n", + " return\n", + " def _save_grad(grad):\n", + " self._grad = grad\n", + " out.register_hook(_save_grad)\n", + "\n", + " self.backbone.layer4.register_forward_post_hook(_save_feat_grad)\n", + "\n", + " def forward(self, x): # (B, C, D, H, W)\n", + " x = self.backbone(x)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "# ============ MoE ============\n", + 
"class ExpertFFN(nn.Layer):\n", + " def __init__(self, d_model, d_ff, dropout=0.1):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(d_model, d_ff)\n", + " self.fc2 = nn.Linear(d_ff, d_model)\n", + " self.drop = nn.Dropout(dropout)\n", + " def forward(self, x):\n", + " return self.fc2(self.drop(F.relu(self.fc1(x))))\n", + "\n", + "class MoEConfig:\n", + " def __init__(self, n_experts=8, top_k=1, d_ff=2048, dropout=0.1,\n", + " router_temp=0.5, use_gumbel=False):\n", + " self.n_experts=n_experts; self.top_k=top_k; self.d_ff=d_ff; self.dropout=dropout\n", + " self.router_temp=router_temp; self.use_gumbel=use_gumbel\n", + "\n", + "class MoE(nn.Layer):\n", + " \"\"\"缓存最近一次路由概率/索引,便于可解释与聚类\"\"\"\n", + " def __init__(self, d_model: int, cfg: MoEConfig):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.router = nn.Linear(d_model, cfg.n_experts)\n", + " self.experts = nn.LayerList([ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)])\n", + " self.ln = nn.LayerNorm(d_model); self.drop = nn.Dropout(cfg.dropout)\n", + " self.last_router_probs = None\n", + " self.last_topk_idx = None\n", + " def _router_probs(self, logits):\n", + " if self.cfg.use_gumbel and self.training:\n", + " u = paddle.uniform(logits.shape, min=1e-6, max=1-1e-6, dtype=logits.dtype)\n", + " g = -paddle.log(-paddle.log(u)); logits = logits + g\n", + " return F.softmax(logits / self.cfg.router_temp, axis=-1)\n", + " def forward(self, x):\n", + " orig_shape = x.shape\n", + " if len(orig_shape) == 3: B,T,D = orig_shape; X = x.reshape([B*T, D])\n", + " else: X = x\n", + " N,D = X.shape\n", + " logits = self.router(X); probs = self._router_probs(logits)\n", + " topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1)\n", + " all_out = paddle.stack([e(X) for e in self.experts], axis=1) # (N,E,D)\n", + " arangeN = paddle.arange(N, dtype='int64')\n", + " picked_list=[]\n", + " for i in range(self.cfg.top_k):\n", + " idx_i = topk_idx[:, i].astype('int64')\n", + " 
idx_nd = paddle.stack([arangeN, idx_i], axis=1)\n", + " picked_i = paddle.gather_nd(all_out, idx_nd)\n", + " picked_list.append(picked_i)\n", + " picked = paddle.stack(picked_list, axis=1) # (N,k,D)\n", + " w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9)\n", + " Y = paddle.sum(picked * w.unsqueeze(-1), axis=1) # (N,D)\n", + " Y = self.drop(Y); Y = self.ln(Y + X)\n", + " self.last_router_probs = probs.detach()\n", + " self.last_topk_idx = topk_idx.detach()\n", + " if len(orig_shape)==3: Y = Y.reshape([B,T,D])\n", + " return Y\n", + "\n", + "class MoEHead(nn.Layer):\n", + " def __init__(self, d_model=512, cfg: MoEConfig = None):\n", + " super().__init__()\n", + " self.moe = MoE(d_model, cfg or MoEConfig())\n", + " self.last_router_probs = None\n", + " self.last_topk_idx = None\n", + " def forward(self, tok):\n", + " y = self.moe(tok.unsqueeze(1)).squeeze(1)\n", + " self.last_router_probs = self.moe.last_router_probs\n", + " self.last_topk_idx = self.moe.last_topk_idx\n", + " return y\n", + "\n", + "# ============ 原生 MHA + 手工回算注意力 ============\n", + "class _NativeMHAWithAttn(nn.Layer):\n", + " \"\"\"\n", + " 包装 nn.MultiHeadAttention:\n", + " - 正式输出用原生 MHA (性能/数值一致)\n", + " - 注意力矩阵用相同权重手工回算(兼容旧版不返回 attn 的情况)\n", + " \"\"\"\n", + " def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):\n", + " super().__init__()\n", + " self.mha = nn.MultiHeadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout)\n", + " self.nhead = nhead\n", + " assert d_model % nhead == 0\n", + " self.d_head = d_model // nhead\n", + " self.last_attn = None # (B, H, T, T)\n", + "\n", + " def forward(self, x_tb: paddle.Tensor) -> paddle.Tensor:\n", + " \"\"\"\n", + " x_tb: (T,B,D)\n", + " return: (T,B,D)\n", + " \"\"\"\n", + " # 1) 原生前向(不同版本的返回值差异:此处统一只拿输出)\n", + " out_tb = self.mha(x_tb, x_tb, x_tb)\n", + "\n", + " # 2) 手工回算注意力:用同一组投影权重\n", + " q = self.mha.q_proj(x_tb) # (T,B,D)\n", + " k = self.mha.k_proj(x_tb)\n", + " v = self.mha.v_proj(x_tb)\n", + " T, B, 
D = q.shape\n", + " H, Dh = self.nhead, self.d_head\n", + "\n", + " def split(tb): # (T,B,D)->(B,H,T,Dh)\n", + " tb = tb.transpose([1, 0, 2]) # (B,T,D)\n", + " return tb.reshape([B, T, H, Dh]).transpose([0, 2, 1, 3])\n", + "\n", + " qh, kh, vh = split(q), split(k), split(v) # (B,H,T,Dh)\n", + " scores = paddle.matmul(qh, kh, transpose_y=True) / math.sqrt(Dh) # (B,H,T,T)\n", + " attn = F.softmax(scores, axis=-1)\n", + " self.last_attn = attn.detach()\n", + "\n", + " return out_tb # (T,B,D)\n", + "\n", + "# ============ Self-Attention Transformer(记录注意力) ============\n", + "class TransformerEncoderLayerMoE(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=8, d_ff=1024, dropout=0.1,\n", + " use_moe: bool = True, moe_cfg: MoEConfig = None, capture_attn: bool = True):\n", + " super().__init__()\n", + " self.use_moe = use_moe; self.capture_attn = capture_attn\n", + " self.self_attn = _NativeMHAWithAttn(d_model, nhead, dropout)\n", + " self.ln1 = nn.LayerNorm(d_model); self.do1 = nn.Dropout(dropout)\n", + " if use_moe:\n", + " self.moe = MoE(d_model, moe_cfg or MoEConfig(d_ff=d_ff, dropout=dropout))\n", + " else:\n", + " self.ffn = nn.Sequential(nn.LayerNorm(d_model),\n", + " nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),\n", + " nn.Linear(d_ff, d_model))\n", + " self.do2 = nn.Dropout(dropout)\n", + " self.last_attn = None # (B,H,T,T)\n", + " def forward(self, x): # (B,T,D)\n", + " h = self.ln1(x)\n", + " h_tb = paddle.transpose(h, [1,0,2]) # (T,B,D)\n", + " out_tb = self.self_attn(h_tb) # (T,B,D)\n", + " if self.capture_attn:\n", + " self.last_attn = self.self_attn.last_attn\n", + " out = paddle.transpose(out_tb, [1,0,2]) # (B,T,D)\n", + " x = x + self.do1(out)\n", + " if self.use_moe:\n", + " x = self.moe(x)\n", + " else:\n", + " x = x + self.do2(self.ffn(x))\n", + " return x\n", + "\n", + "class TemporalTransformerFlexible(nn.Layer):\n", + " def __init__(self, d_model=512, nhead=4, num_layers=2, d_ff=1024, dropout=0.1,\n", + " max_len=4096, use_moe: 
class AFNO1DLayer(nn.Layer):
    """Frequency-domain token mixer along the time axis with a residual connection.

    Forward pass: LayerNorm -> real FFT over time -> two block-diagonal complex
    linear maps applied to the lowest ``K`` frequency modes (GELU and
    soft-shrink in between) -> inverse real FFT -> dropout -> add to the input.

    Args:
        d_model: channel width; must be divisible by ``num_blocks``.
        modes: maximum number of low-frequency modes that are mixed.
        num_blocks: number of diagonal blocks the channels are split into.
        shrink: soft-shrink threshold applied after the first complex map.
        dropout: dropout rate applied to the time-domain output.
    """

    def __init__(self, d_model: int, modes: int = 32, num_blocks: int = 8, shrink: float = 0.01, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_blocks == 0
        self.d_model = d_model
        self.modes = modes
        self.num_blocks = num_blocks
        self.block = d_model // num_blocks

        scale = 1.0 / math.sqrt(self.block)
        init = nn.initializer.Uniform(-scale, scale)
        w_shape = [num_blocks, self.block, self.block]
        # Real and imaginary parts of the two per-block complex weight matrices.
        self.w1r = self.create_parameter(w_shape, default_initializer=init)
        self.w1i = self.create_parameter(w_shape, default_initializer=init)
        self.w2r = self.create_parameter(w_shape, default_initializer=init)
        self.w2i = self.create_parameter(w_shape, default_initializer=init)

        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.shrink = shrink

    def _cl(self, xr, xi, Wr, Wi):
        """Complex matmul (xr + i*xi) @ (Wr + i*Wi), returned as (real, imag)."""
        out_r = paddle.matmul(xr, Wr) - paddle.matmul(xi, Wi)
        out_i = paddle.matmul(xr, Wi) + paddle.matmul(xi, Wr)
        return out_r, out_i

    def forward(self, x):  # (B,T,D)
        """Mix tokens in the frequency domain.

        Args:
            x: tensor of shape (B, T, D).

        Returns:
            Tensor of shape (B, T, D): ``x`` plus the filtered signal.
        """
        B, T, D = x.shape
        Kmax = T // 2 + 1            # number of rFFT bins for length T
        K = min(self.modes, Kmax)    # modes that actually get mixed

        h_ft = paddle.fft.rfft(self.ln(x).transpose([0, 2, 1]))  # (B,D,Kmax)
        h_ft = h_ft.reshape([B, self.num_blocks, self.block, Kmax])
        xk = h_ft[:, :, :, :K].transpose([0, 1, 3, 2])  # (B,G,K,Cb)
        xr, xi = paddle.real(xk), paddle.imag(xk)

        yr, yi = self._cl(xr, xi, self.w1r, self.w1i)
        yr = F.softshrink(F.gelu(yr), threshold=self.shrink)
        yi = F.softshrink(F.gelu(yi), threshold=self.shrink)
        yr, yi = self._cl(yr, yi, self.w2r, self.w2i)

        # Zero-fill the discarded high-frequency bins. Padding the real and
        # imaginary parts (instead of writing into a pre-allocated complex64
        # buffer) keeps the activations' dtype and avoids in-place assignment.
        if K < Kmax:
            pad = paddle.zeros([B, self.num_blocks, Kmax - K, self.block], dtype=yr.dtype)
            yr = paddle.concat([yr, pad], axis=2)
            yi = paddle.concat([yi, pad], axis=2)
        out_ft = paddle.complex(yr, yi).transpose([0, 1, 3, 2]).reshape([B, D, Kmax])

        out_td = paddle.fft.irfft(out_ft, n=T)  # (B,D,T)
        return x + self.drop(out_td.transpose([0, 2, 1]))
class TwoModalMultiLabelModel(nn.Layer):
    """Two-modality (volumetric video + flat vector) multi-label classifier.

    Each frame volume is encoded by ``Volume3DEncoder``. The resulting frame
    sequence goes through an attention branch and an AFNO branch in parallel,
    whose outputs are concatenated and projected back to ``d_model``. The
    vector input is encoded by ``TabMFeatureExtractor`` and projected to
    ``d_model``. Both modalities are combined by ``BiModalCrossFusion`` and a
    linear head produces one logit per label. Optional MoE heads can be
    enabled on the vector token and on the fused representation.
    """

    def __init__(self, vid_channels=20, vid_frames=365, depth_n=24,
                 vec_dim=424, d_model=512, nhead=4, n_trans_layers=2, trans_ff=1024,
                 tabm_hidden=512, dropout=0.1, num_labels=4,
                 moe_temporal_attn=True, moe_temporal_afno=True, moe_fused=False, moe_tabm=False,
                 afno_modes=32):
        super().__init__()
        self.depth_n = depth_n
        self.moe_tabm = moe_tabm
        self.moe_fused = moe_fused

        # Expert width used by both temporal branches.
        temporal_expert_ff = max(2048, trans_ff)

        # --- video branch -------------------------------------------------
        self.vol_encoder = Volume3DEncoder(in_channels=vid_channels, dropout=dropout)
        self.trans_attn = TemporalTransformerFlexible(
            d_model=d_model,
            nhead=nhead,
            num_layers=n_trans_layers,
            d_ff=trans_ff,
            dropout=dropout,
            max_len=vid_frames,
            use_moe=moe_temporal_attn,
            moe_cfg=MoEConfig(d_ff=temporal_expert_ff, n_experts=8),
            capture_attn=True,
        )
        self.trans_afno = AFNOTransformerFlexible(
            d_model=d_model,
            num_layers=n_trans_layers,
            modes=afno_modes,
            dropout=dropout,
            d_ff=trans_ff,
            use_moe=moe_temporal_afno,
            moe_cfg=MoEConfig(d_ff=temporal_expert_ff, n_experts=8),
        )
        self.video_merge = nn.Linear(2 * d_model, d_model)

        # --- vector branch ------------------------------------------------
        self.tabm = TabMFeatureExtractor(vec_dim, d_hidden=tabm_hidden, dropout=dropout)
        self.tabm_proj = nn.Linear(tabm_hidden, d_model)
        if moe_tabm:
            self.tabm_moe = MoEHead(d_model=d_model, cfg=MoEConfig(d_ff=1024, n_experts=6))

        # --- fusion and classification -------------------------------------
        self.fusion = BiModalCrossFusion(d_model=d_model, nhead=nhead, dropout=dropout, fuse_hidden=d_model)
        if moe_fused:
            self.fused_moe = MoEHead(d_model=d_model, cfg=MoEConfig(d_ff=1024, n_experts=6))
        self.head = nn.Linear(self.fusion.out_dim, num_labels)

    def encode(self, x_video, x_vec):
        """Return the fused representation for a batch.

        Args:
            x_video: tensor of shape (B, T, C, H, W, N); ``N`` must equal ``depth_n``.
            x_vec: vector input passed to the TabM extractor.

        Returns:
            The fused feature tensor fed to the classification head.
        """
        B, T, C, H, W, N = x_video.shape
        assert N == self.depth_n

        # Fold time into the batch axis and move depth in front of H, W.
        frames = x_video.transpose([0, 1, 2, 5, 3, 4]).reshape([B * T, C, N, H, W])
        seq = self.vol_encoder(frames).reshape([B, T, -1])  # (B,T,feat)

        attn_out = self.trans_attn(seq)
        afno_out = self.trans_afno(seq)
        video_feat = self.video_merge(paddle.concat([attn_out, afno_out], axis=-1))

        vec_feat = self.tabm_proj(self.tabm(x_vec))
        if self.moe_tabm:
            vec_feat = self.tabm_moe(vec_feat)

        fused = self.fusion(video_feat, vec_feat)
        return self.fused_moe(fused) if self.moe_fused else fused

    def forward(self, x_video, x_vec):
        """Return per-label logits for the given video and vector inputs."""
        return self.head(self.encode(x_video, x_vec))
def kmeans_numpy(X: np.ndarray, K: int = 4, iters: int = 50, seed: int = 0):
    """Cluster the rows of ``X`` with Lloyd's k-means, using NumPy only.

    Initial centroids are ``K`` distinct rows of ``X`` drawn with a seeded
    generator. A cluster that loses all its points keeps its previous centroid.

    Args:
        X: array of shape (N, D).
        K: number of clusters, ``1 <= K <= N``.
        iters: maximum number of update steps; ``0`` returns the assignment
            to the initial centroids.
        seed: seed for the centroid initialisation.

    Returns:
        ``(idx, cent)``: ``idx`` has shape (N,) and holds the index of the
        nearest returned centroid for each row; ``cent`` has shape (K, D).

    Raises:
        ValueError: if ``X`` is not a non-empty 2-D array or ``K`` is out of range.
    """
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[0] == 0:
        raise ValueError("X must be a non-empty 2-D array of shape (N, D)")
    N = X.shape[0]
    if not 1 <= K <= N:
        raise ValueError(f"K must satisfy 1 <= K <= N (got K={K}, N={N})")

    def assign(centroids):
        # Squared Euclidean distance of every row to every centroid: (N, K).
        dist2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        return dist2.argmin(axis=1)

    rng = np.random.default_rng(seed)
    cent = X[rng.choice(N, K, replace=False)]
    for _ in range(iters):
        idx = assign(cent)
        new_cent = np.stack(
            [X[idx == k].mean(axis=0) if np.any(idx == k) else cent[k] for k in range(K)], 0
        )
        if np.allclose(new_cent, cent):
            break
        cent = new_cent

    # Label against the centroids that are actually returned. This covers
    # iters == 0 and the case where the budget ran out right after an update.
    return assign(cent), cent
def show_attention_matrix(attn: np.ndarray, title: str, save_path: Optional[str] = None):
    """Plot an attention tensor for the first sample of a batch.

    Accepted inputs:
      * (B, H, 1, 1): drawn as a single (H, 1) image with a colorbar;
      * (B, H, Nq, Nk): one subplot per head, laid out on a near-square grid;
      * 2-D or 3-D arrays, treated as the already-reduced forms above.

    Args:
        attn: attention weights.
        title: figure title.
        save_path: if given, the figure is also written to this path.

    Raises:
        ValueError: if ``attn`` cannot be reduced to a 2-D or 3-D array.
    """
    attn = np.asarray(attn)
    if attn.ndim == 4:
        if attn.shape[2] == 1 and attn.shape[3] == 1:
            attn = attn[0, :, 0, 0][:, None]  # (H,1)
        else:
            attn = attn[0]  # (H,Nq,Nk)

    # Exactly one figure is created per call; creating one up front and a
    # second one for the per-head grid would leave an empty figure behind.
    if attn.ndim == 2:
        fig, ax = plt.subplots(figsize=(5, 4))
        image = ax.imshow(attn, aspect='auto', interpolation='nearest')
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
    elif attn.ndim == 3:
        H = attn.shape[0]
        cols = int(np.ceil(np.sqrt(H)))
        rows = int(np.ceil(H / cols))
        # squeeze=False keeps `axes` a 2-D array even for a single head.
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
        axes = axes.flatten()
        for h in range(H):
            axes[h].imshow(attn[h], interpolation='nearest')
            axes[h].set_title(f"head {h}")
        for k in range(H, len(axes)):
            axes[k].axis('off')  # hide unused grid cells
        fig.suptitle(title)
    else:
        raise ValueError(f"unsupported attention shape {attn.shape}")

    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)
Nz//3, 2*Nz//3, Nz-1]:\n", + " show_heatmap_2d(cam3d[z], f\"Grad-CAM depth={z}\", save_path=f\"viz_out/gradcam_z{z}.png\")\n", + "\n", + " # 4) Self-Attention & Cross-Attention 注意力矩阵\n", + " with paddle.no_grad():\n", + " _ = model.encode(x_video.astype('float32'), x_vec.astype('float32'))\n", + " # Self-Attn(最后一层)\n", + " last_attn_list = model.trans_attn.last_attn_all_layers\n", + " if len(last_attn_list) > 0:\n", + " attn = last_attn_list[-1].numpy() # (B,H,T,T)\n", + " attn_crop = attn[:, :, :64, :64]\n", + " show_attention_matrix(attn_crop, \"Self-Attention (last layer, first 64 tokens)\",\n", + " save_path=\"viz_out/self_attn_lastlayer_64.png\")\n", + " print(\"Self-Attn matrix shape:\", attn.shape)\n", + " else:\n", + " print(\"Self-Attn not captured.\")\n", + " # Cross-Attn\n", + " if model.fusion.last_attn_v_from_t is not None:\n", + " show_attention_matrix(model.fusion.last_attn_v_from_t.numpy(),\n", + " \"Cross-Attn v<-t (token→token)\",\n", + " save_path=\"viz_out/cross_attn_v_from_t.png\")\n", + " if model.fusion.last_attn_t_from_v is not None:\n", + " attn_tv = model.fusion.last_attn_t_from_v.numpy()\n", + " attn_tv_crop = attn_tv[:,:,:, :64]\n", + " show_attention_matrix(attn_tv_crop,\n", + " \"Cross-Attn t<-v (token←video_seq first 64)\",\n", + " save_path=\"viz_out/cross_attn_t_from_v_64.png\")\n", + "\n", + " # 5) MoE 路由聚类(示例用 toy 数据)\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)),\n", + " paddle.to_tensor(np.stack(vecs, 0)),\n", + " paddle.to_tensor(np.stack(ys, 0)))\n", + " train_loader = DataLoader(toy, batch_size=1, shuffle=False, collate_fn=collate_fn)\n", + " moe_vecs = collect_moe_routing_vectors(model, train_loader, branch=\"temporal_attn\", topk_hist=True)\n", + " if moe_vecs is not None:\n", + " idx, cent = kmeans_numpy(moe_vecs, K=4, iters=100, seed=0)\n", + " print(\"\\n[MoE Routing Clusters @ temporal_attn]\")\n", + " for k in range(4):\n", + " sel = (idx==k)\n", 
+ " if np.any(sel):\n", + " mean_vec = moe_vecs[sel].mean(axis=0)\n", + " dom = int(mean_vec.argmax())\n", + " print(f\" - Cluster {k}: size={int(sel.sum())}, dominant_expert={dom}, mean_dist={np.round(mean_vec,3)}\")\n", + " # 保存热图\n", + " plt.figure(figsize=(6,4))\n", + " plt.imshow(moe_vecs, aspect='auto', interpolation='nearest')\n", + " plt.title(\"Samples × Experts (routing histogram)\"); plt.xlabel(\"Expert\"); plt.ylabel(\"Sample\")\n", + " plt.colorbar(); plt.savefig(\"viz_out/moe_routing_heatmap.png\", bbox_inches='tight')\n", + " plt.show(); plt.close()\n", + " else:\n", + " print(\"MoE routing not available on selected branch.\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "j6NZgsBmZk7u", + "outputId": "2834b1eb-d29b-4095-ccb8-35e72eba8ec5" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Self-Attn matrix shape: (1, 4, 365, 365)\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhAAAAHOCAYAAADJ3DBLAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAASp5JREFUeJzt3XlYVGX/P/D3AM6wMwgii7KIEC75SJDmkppiZpaappY+T0qpaVRqWmmbyzfFpZ5S89HUIvetMq3MJVGzXEhDkzQjFTETN8ANBGU+vz/4ccbjsN0movJ+Xddc1+HMfc7c52Y+M+85c84cg4gIiIiIiBTYVXYHiIiI6M7DAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQXSHmTx5MiIiImCxWG7aOseMGQODwYAzZ87ctHXeDtq0aYM2bdqU2W7z5s0wGAzYvHlzhffpTnHy5Ek8+eST8PLygsFgwIcffnhXjtPatWvh6uqK06dPV3ZX7jgMEHeAQ4cO4fnnn0edOnXg6OgId3d3tGjRAlOnTkVubm5ld0/ZmjVrYDAY4O/vX+yb4N9//40xY8Zgz549NvctXrwYH3744U3v07Zt2zBmzBhkZ2ff9HXfTOfPn8ekSZPw+uuvw86usHxzcnIwZsyYu+pFnSrfsGHDsG7dOowaNQoLFizAI488ctMf40afu4cOHULv3r3h4+MDJycnhIWF4c033yyx/ZUrV1C/fn0YDAa89957uvseeeQR1K1bF/Hx8TeyCVWaQ2V3gEr37bffokePHjCZTHjmmWfQsGFD5Ofn48cff8Srr76K3377DbNnz67sbipZtGgRgoODkZaWhsTERMTExOju//vvvzF27FgEBwejcePGuvsWL16MlJQUDB069Kb2adu2bRg7diz69esHs9l8U9d9M3366ae4evUqnn76aW1eTk4Oxo4dCwDl+rRdlaxfv76yu3DHSkxMRJcuXTBixAhtXnh4OHJzc2E0Gm/KY9zIc3fPnj1o06YNAgICMHz4cHh5eSE9PR3Hjh0rcZnp06cjPT29xPuff/55jBgxAmPHjoWbm5vSNlRlDBC3sSNHjuCpp55CUFAQEhMT4efnp90XFxeHP//8E99++22Jy1ssFuTn58PR0fFWdLdcLl26hFWrViE+Ph4JCQlYtGiRTYC4ExUUFODgwYOoX79+hT5OQkICOnfufFv9T29nN+uNrio6deqUTZi2s7Mr13MvJycHzs7ON71PFosF//nPfxAREYFNmzbBycmpzGVOnTqFcePG4fXXX8c777xTbJvu3bvjpZdewooVK/Dss8/e7G7fvYRuW4MGDRIA8tNPP5WrPQCJi4uThQsXSv369cXBwUFWrlwpIiK//PKLPPLII+Lm5iYuLi7Stm1b2b59u275/Px8GTNmjNStW1dMJpNUr15dWrRoIevXr9fanDhxQvr16ycBAQFiNBrF19dXOnfuLEeOHClXHxcsWCB2dnZy4sQJmTRpkri7u0tubq52/6ZNmwSAzS0hIUFat25tMz8oKEi33LJly+Tdd9+VgIAAMZlM0rZtW0lNTS21T6NHjy72McuzTampqTJq1Cjx8/OTLl26lNp2ypQpAkDS0tJs7hs5cqRUq1ZNMjMzS1z+8OHDAkA+++wzbd6RI0eK7fvo0aO1Nhs3bpSWLVuKs7OzeHh4SOfOnWX//v3FjsHp06e1eWlpaRIaGioNGjSQjIwMERHJysqSIUOGSK1atcRoNEpoaKhMnDhRCgoKbPo0ZcoU+fjjj6VOnTpiNBolOjpakpKSSh2j8oiLixMXFxe5dOmSzX1PPfWU1KxZU65evSoiIq1bt5bWrVvr2hw7dky6dOkizs7OUqNGDRk6dKisXbtWAMimTZt0bXfs2CEdOnQQd3d3cXJyklatWsmPP/5o87jlqa/yWL9+vbRo0UI8P
DzExcVFwsPDZdSoUbo2ly9flnfeeUdCQ0PFaDRKrVq15NVXX5XLly/btBs6dKh4e3uLq6urPP7443Ls2DGb58f1EhISin1OiVjr7Npxat26tTRo0EB27dolDz74oDg5OcmQIUNEROTnn3+Whx9+WLy8vMTR0VGCg4MlNjZWRMr33L3ed999JwBkzZo1IiJy6dIl7X9dktjYWGnSpIlWP1OmTCm2XWRkpHTu3LnUdZEeA8RtLCAgQOrUqVPu9gCkXr16UqNGDRk7dqzMmDFDkpOTJSUlRVxcXMTPz0/+7//+TyZOnCghISFiMplkx44d2vJvvPGGGAwGGTBggMyZM0fef/99efrpp2XixIlam+bNm4uHh4e89dZbMnfuXJkwYYI89NBDsmXLlnL18ZFHHpF27dqJiMjRo0fFYDDI8uXLtfszMjJk3LhxAkAGDhwoCxYskAULFsihQ4dk/fr10rhxY/H29tbmFwWkohe2yMhIiYqKkg8++EDGjBkjzs7O0qRJk1L7tHfvXnn66acFgHzwwQfaui9evFhs+5ycHFmwYIEWaEwmk/Tq1avMoFe0vZMnT7a5r06dOtKpU6dSl1+4cKEAkF9//VWbd/HiRZk5c6YAkCeeeELr+969e0VEZMOGDeLg4CDh4eEyefJkGTt2rHh7e4unp6cuIF0fIP78808JDAyUxo0ba/MuXbokjRo1Ei8vL3njjTdk1qxZ8swzz4jBYNDeMESsbwyRkZFSt25dmTRpkkyePFm8vb2lVq1akp+fb7Ntw4YNkwULFpS6/UV++OEHAaB73hT1z8XFReLi4rR51weInJwcCQ8PF0dHR3nttdfkww8/lKioKGnUqJHNG+PGjRvFaDRKs2bN5P3335cPPvhAGjVqJEajUXbu3Km1K299lSUlJUULWlOnTpVZs2bJiBEjpFWrVlqbgoICefjhh8XZ2VmGDh0qH3/8sbz44ovi4OBgE2D//e9/CwDp3bu3fPTRR9KtWzdtO0t7kz506JAsWLBAAEj79u2155RIyQHC19dXatSoIS+99JJ8/PHH8tVXX8nJkyfF09NTwsPDZcqUKTJnzhx58803pV69eiJS9nO3OMOHDxcAsnHjRomKihIAYjQapVevXnL27Fmb9jt37hQ7OzvZtm2bLtgWp3///uLt7V3iY5MtBojb1Llz5wRAmZ9qrwVA7Ozs5LffftPN79q1qxiNRjl06JA27++//xY3Nzfdi9O//vWvUt/EsrKySi3Aspw8eVIcHBxkzpw52rzmzZvbbOPPP/+s7XW4XqdOnbS9DtcqemGrV6+e5OXlafOnTp0qAGTfvn2l9q1o70Bpex127dolgwcPFg8PDwEgUVFR8tFHH5W61+B6zZo1k6ioKN28pKQkASDz588vddm33npLAMiFCxd080+fPl3im0Ljxo3Fx8dH9+K6d+9esbOzk2eeeUabd22AOHDggPj7+8v999+v27b/+7//ExcXF/njjz90jzFy5Eixt7eX9PR0EbEGCC8vL93yq1atEgDy9ddf65YfMWKEAJCPPvqo1O0vYrFYJCAgQLp3766bv3z5cgEgP/zwgzbv+gDx4Ycf2oSPS5cuSd26dXVvjBaLRcLCwqRDhw5isVi0tjk5ORISEiLt27fX5pW3vsrywQcf2OwFul7RHrytW7fq5s+aNUu3t3LPnj0CQF544QVdu969e5cZIIoU7dG8VkkBAoDMmjVL13blypUCQH7++ecSH6O0525xOnfurD23+vTpI59//rm8/fbb4uDgIM2bN9f9rywWizRp0kSefvppEZEyA8SECRMEgJw8ebJcfSERnoVxmzp//jwAKB/Q07p1a9338AUFBVi/fj26du2KOnXqaPP9/PzQu3dv/Pjjj9pjmc1m/Pbbb0hNTS123U5OTjAajdi8eTOysrJUNwlLly6FnZ0dunfvrs17+umn8d13393Q+ooTGxur+977wQcfBAAcPnz4hte5ePFiREZGIjo6G
itWrEBsbCz27t2LXbt2IS4uDp6enuVeV69evbB7924cOnRIm7ds2TKYTCZ06dKl1GXPnj0LBwcHuLq6luuxTpw4gT179qBfv36oXr26Nr9Ro0Zo37491qxZY7NMSkoKWrdujeDgYHz//fe6bVuxYgUefPBBeHp64syZM9otJiYGBQUF+OGHH2y29drli/tfvPXWW3jvvfcwYcIEPPfcc7h8+XKJt7y8PACAwWBAjx49sGbNGly8eFFb17JlyxAQEICWLVuWOCZr1qyBn58fnnzySW2es7MzBg4cqGu3Z88epKamonfv3jh79qy2rZcuXUK7du3www8/wGKxKNVXWYqON1i1alWJp+iuWLEC9erVQ0REhO5/0LZtWwDApk2btO0EgJdfflm3/M0++LiIyWRCbGysbl7R9nzzzTe4cuXKTXmcov/3/fffj4ULF6J79+4YN24c/u///g/btm3Dxo0btbafffYZ9u3bh0mTJpVr3UXP1bvtVOaKxABxm3J3dwcAXLhwQWm5kJAQ3d+nT59GTk4O7rnnHpu29erVg8Vi0Y5eHjduHLKzsxEeHo57770Xr776Kn799VetvclkwqRJk/Ddd9+hZs2aaNWqFSZPnoyMjAytzblz55CRkaHdMjMztfsWLlyIJk2a4OzZs/jzzz/x559/IjIyEvn5+VixYoXSdpYkMDBQ93fRi8I/CSizZ8/Gnj17cN9992Hr1q344IMP0KhRo1KXyczM1I3DuXPnAAA9evSAnZ0dli1bBgAQEaxYsQIdO3bU/uc3y9GjRwGgxP990RvitR5//HG4ublh3bp1Nv1JTU3F2rVrUaNGDd2t6CDYU6dO6dqX9b/YtGkTxo8fDwB444034OTkVOrt2u3o1asXcnNzsXr1agCFbyxr1qxBjx49YDAYSh2TunXr2rS5foyKQnTfvn1ttnfu3LnIy8vDuXPnlOqrLL169UKLFi3Qv39/1KxZE0899RSWL1+uCxOpqan47bffbPoUHh4OwPo/OHr0KOzs7BAaGlrqdt4sAQEBNgestm7dGt27d8fYsWPh7e2NLl26ICEhQQuCN6LooMlrz0ICgN69ewMoPJsKKPwANmrUKLz66quoXbt2udYtIgBQ6vOH9HgWxm3K3d0d/v7+SElJUVquPEcll6RVq1Y4dOgQVq1ahfXr12Pu3Ln44IMPMGvWLPTv3x9A4SeYxx9/HF999RXWrVuHt99+G/Hx8UhMTERkZCSGDBmCefPmaets3bo1Nm/ejNTUVPz8888AgLCwMJvHXrRokc2nwBthb29f7PyiF4cb8d5772HmzJlYvnw56tevj9atWyM2Nhbdu3eHi4tLsct069YNW7Zs0f7u27cvPvvsM/j7++PBBx/E8uXL8cYbb2DHjh1IT08v16ckLy8vXL16FRcuXKiwU826d++OefPmYdGiRXj++ed191ksFrRv3x6vvfZascsWvYkVKet/cd999yE6Ohq7du3CsGHDygxl1+55eeCBBxAcHIzly5ejd+/e+Prrr5Gbm4tevXqVuY3lUfSmPWXKFJtTia/tzz95M7yek5MTfvjhB2zatAnffvst1q5di2XLlqFt27ZYv3497O3tYbFYcO+99+K///1vseso75vlzVbc647BYMDnn3+OHTt24Ouvv8a6devw7LPP4v3338eOHTvKvSftWv7+/gCAmjVr6ub7+PgAsIbT9957D/n5+ejVqxfS0tIAAH/99ZfWJi0tDf7+/rrQU7Sst7e3cr+qKgaI29hjjz2G2bNnY/v27WjWrNkNraNGjRpwdnbGwYMHbe77/fffYWdnp3vRqV69OmJjYxEbG4uLFy+iVatWGDNmjBYgACA0NBTDhw/H8OHDkZqaisaNG+P999/HwoUL8dprr+Hf//631rboU+eiRYtQrVo1LFiwwOaN5ccff8S0adOQnp6OwMDAUj8BVNSng9LWGx0djU8++QRTp07F0qVLMXfuXPTt2xcvvvgievbsi
djYWLRo0UK3zPvvv6/b61H0wgcUftJ84YUXcPDgQSxbtgzOzs54/PHHy+xjREQEgMLTe699sy2p70FBQQBQ4v/e29vbJgBNmTIFDg4OeOGFF+Dm5qZ9sgMK/+8XL168aafdenh4YP369Wjbti0WLlyIzZs3K50G27NnT0ydOhXnz5/HsmXLEBwcjAceeKDUZYKCgpCSkgIR0Y3b9WNU9Mnd3d291O1Vra+y2NnZoV27dmjXrh3++9//YsKECXjzzTexadMmxMTEIDQ0FHv37kW7du1Kfc4GBQXBYrHg0KFDur0OxfWzoj3wwAN44IEHMH78eCxevBh9+vTB0qVL0b9/f+V6joqKwpw5c3D8+HHd/L///htA4f8DANLT05GVlYUGDRrYrGPChAmYMGECkpOTdeHwyJEj8Pb21tZBZeNXGLex1157DS4uLujfvz9Onjxpc/+hQ4cwderUUtdhb2+Phx9+GKtWrdKSOFD4M7WLFy9Gy5YttV3VZ8+e1S3r6uqKunXrap+ycnJycPnyZV2b0NBQuLm5aW3q16+PmJgY7RYVFQWgMEA8+OCD6NWrF5588knd7dVXXwUALFmyBAC0N7XifhXSxcVF+zrgZirtMYu4urqif//+2LFjB1JSUvDcc8/hq6++QsuWLREeHo758+drbaOionTjcO0bY/fu3WFvb48lS5ZgxYoVeOyxx0rck3GtohC5a9cu3fyi8+2v77ufnx8aN26MefPm6e5LSUnB+vXr8eijj9o8hsFgwOzZs/Hkk0+ib9++2lcEQOEb9vbt27Fu3Tqb5bKzs3H16tUyt+F6np6e2LBhA2rWrInp06crLdurVy/k5eVh3rx5WLt2LXr27FnmMo8++ij+/vtvfP7559q8nJwcmx9ji4qKQmhoKN577z3dcRZFin72WKW+ynLt131Fit7giuqrZ8+eOH78OObMmWPTNjc3V/tKqmPHjgCAadOm6dpUxK+4liQrK8tmz9/121PSc7ckXbp0gclkQkJCgu6rnblz5wIA2rdvD6Dw2I+VK1fqbh9//DEAoF+/fli5cqXN1727d+++4Q9qVVZlHsFJZVu1apU4OjqKp6enDBkyRObMmSMzZsyQPn36iNFolIEDB2ptUcxR0yLW08wCAgJk/PjxMmnSJKlTp47NaWY+Pj7Ss2dPmTRpksyZM0eef/55MRgM8tJLL4mISHJyslSvXl0GDRok06ZNk//973/Svn17ASCff/55iduwY8cOASAffvhhiW2ioqLk3nvvFZHC36Mwm81yzz33yNy5c2XJkiVy+PBhERGZPHmyAJBhw4bJ4sWLZfXq1SJiPTp8xYoVuvUWHXld3Bkd1yo6E+LRRx+V+fPny5IlS0o8jfNaeXl5snTpUmnfvr1069atzPZFYmJixM3NTQDIF198Ue7lGjZsqB1Vfq369euLr6+vzJgxQ5YsWaKddVJ0GmdERIRMmTJFxo0bJzVq1BBPT09tTEVsT+PMz8+XRx99VEwmk2zcuFFECs9WuO+++8TBwUH69+8vM2fOlPfee0/69u0rLi4u2rKlHe2OEo64P3PmTLGnd5albt262jju3r3b5v7rz8IoOuPC0dFRXn/99VJP49y0aZM4OjpKYGCgjB49WmbPni2jR4+WVq1ayWOPPaa1K299lWXIkCESGRkpb731lsyZM0fGjx8vAQEBUqtWLcnOzhaRwtM4H330UTEYDPLUU0/J9OnT5cMPP5RBgwZJ9erVdWc8FJ2a3KdPH5kxY0a5T+MsUtzrSWm/A3G9Dz74QMLCwuS1116Tjz/+WN577z255557xN3dXffcK+m5W5Ki07zbt28vM2bMkIEDB4rBYCi2Lq5V2vPy5MmTYm9vL3Pnzi11HaTHAHEH+OOPP2TAgAESHBwsRqNR3NzcpEWLFjJ9+nTdj8eUFCBECn/opkOHDuLq6irOzs7y0EMPybZt23Rt3n33XWnSpImYzWZxcnKSiIgIGT9+vPbCfubMG
YmLi5OIiAhxcXERDw8Padq0qc35+Nd76aWXBIDuNLfrjRkzRgBo54CvWrVK+zGsawPAxYsXpXfv3mI2m4v9IakbDRAihacpBgQEiJ2dXbl/SOpa5QkcRebMmSMAxM3NTfdDWmX573//K66urpKTk6Obv23bNomKihKj0WjzBvH9999LixYtxMnJSdzd3eXxxx8v1w9J5eTkSOvWrcXV1VV7I7xw4YKMGjVK6tatK0ajUby9vaV58+by3nvvac+TGwkQN+rNN98UAFK3bt1i7y/uh6SOHj0qnTt3FmdnZ/H29pYhQ4aU+ENSycnJ0q1bN/Hy8hKTySRBQUHSs2dPLVQVKU99lWXjxo3SpUsX8ff3F6PRKP7+/vL000/bnDabn58vkyZNkgYNGojJZBJPT0+JioqSsWPHyrlz57R2ubm58vLLL4uXl5e4uLiU+4ekivzTAPHLL7/I008/LYGBgWIymcTHx0cee+wx2bVrl65dac/d4lgsFpk+fbqEh4dLtWrVpHbt2vLWW2+VGUBLe17OnDlTnJ2d5fz586Wug/QMIv/g6DIiuqXOnTuHOnXqYPLkyXjuuecquzt0BzIYDBg9ejTGjBlT2V25bURGRqJNmzb44IMPKrsrdxQeA0F0B/Hw8MBrr72GKVOm3NTLeRNVVWvXrkVqaipGjRpV2V2543APBBFRBbv2t1KK4+TkBA8Pj1vSF+6BoJuFp3ESEVWwa6+kW5yi3wkhupMwQBARVbANGzaUev+1vxNS0bjTmW4WfoVBREREyngQJRERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCDNmDFjYDAYcObMmcruiqZfv34IDg6u7G4Q3bVY93SjGCDornHgwAE88sgjcHV1RfXq1fGf//wHp0+fruxuEVEFSUpKwgsvvICoqChUq1YNBoOhsrtUpTBA0F3hr7/+QqtWrfDnn39iwoQJGDFiB
L799lu0b98e+fn5ld09IqoAa9aswdy5c2EwGFCnTp3K7k6VwwBBd4UJEybg0qVLSExMxMsvv4w33ngDy5cvx969e/HZZ59VdveIqAIMHjwY586dw65du9C+ffvK7k6VwwBBNrKzs9GvXz+YzWZ4eHggNjYWOTk5Nu0WLlyIqKgoODk5oXr16njqqadw7NgxXZutW7eiR48eCAwMhMlkQu3atTFs2DDk5ubarO+rr75Cw4YN4ejoiIYNG2LlypXl7vMXX3yBxx57DIGBgdq8mJgYhIeHY/ny5QpbT1Q13Yl1X7NmTTg5OalvLN0UDpXdAbr99OzZEyEhIYiPj8cvv/yCuXPnwsfHB5MmTdLajB8/Hm+//TZ69uyJ/v374/Tp05g+fTpatWqF5ORkmM1mAMCKFSuQk5ODwYMHw8vLC0lJSZg+fTr++usvrFixQlvf+vXr0b17d9SvXx/x8fE4e/YsYmNjUatWrTL7e/z4cZw6dQrR0dE29zVp0gRr1qz554NCdJe70+qebgNC9P+NHj1aAMizzz6rm//EE0+Il5eX9ndaWprY29vL+PHjde327dsnDg4Ouvk5OTk2jxMfHy8Gg0GOHj2qzWvcuLH4+flJdna2Nm/9+vUCQIKCgkrt988//ywAZP78+Tb3vfrqqwJALl++XOo6iKqqO7XurxcXFyd8S7u1+BUG2Rg0aJDu7wcffBBnz57F+fPnAQBffvklLBYLevbsiTNnzmg3X19fhIWFYdOmTdqy1+5evHTpEs6cOYPmzZtDRJCcnAwAOHHiBPbs2YO+ffvCw8NDa9++fXvUr1+/zP4W7RY1mUw29zk6OuraEFHx7rS6p8rHrzDIxrXHEQCAp6cnACArKwvu7u5ITU2FiCAsLKzY5atVq6ZNp6en45133sHq1auRlZWla3fu3DkAwNGjRwGg2PXdc889+OWXX0rtb9GLVV5ens19ly9f1rUhouLdaXVPlY8BgmzY29sXO19EAAAWiwUGgwHfffddsW1dXV0BAAUFBWjfvj0yMzPx+uuvIyIiAi4uLjh+/Dj69esHi8VyU/rr5+cHoPATzfVOnDiB6tWrF7t3gois7rS6p8rHAEHKQkNDISIICQlBeHh4ie327duHP/74A/PmzcMzzzyjzd+wYYOuXVBQEAAgNTXVZh0HDx4ssz8BAQGoUaMGdu3aZXNfUlISGjduXOY6iKh0t1vdU+XjMRCkrFu3brC3t8fYsWO1TydFRARnz54FYP1Ec20bEcHUqVN1y/j5+aFx48aYN2+etnsTKHzB2b9/f7n61L17d3zzzTe608k2btyIP/74Az169FDbQCKycTvWPVUu7oEgZaGhoXj33XcxatQopKWloWvXrnBzc8ORI0ewcuVKDBw4ECNGjEBERARCQ0MxYsQIHD9+HO7u7vjiiy9svhMFgPj4eHTq1AktW7bEs88+i8zMTEyfPh0NGjTAxYsXy+zTG2+8gRUrVuChhx7CkCFDcPHiRUyZMgX33nsvYmNjK2IYiKqU27Hujx49igULFgCAtgfy3XffBVC4h+M///nPTRwBsnHrT/yg21XR6VynT5/WzU9ISBAAcuTIEd38L774Qlq2bCkuLi7i4uIiEREREhcXJwcPHtTa7N+/X2JiYsTV1VW8vb1lwIABsnfvXgEgCQkJNuurV6+emEwmqV+/vnz55ZfSt2/fcp/OlZKSIg8//LA4OzuL2WyWPn36SEZGxo0MBVGVcSfX/aZNmwRAsbfWrVvf4IhQeRlErtsXRURERFQGHgNBREREyhggiIiISBkDBBERESmrsACRmZmJPn36wN3dHWazGc8991yZR9W2adMGBoNBd7v+51WJ6PbFuieqOiosQPTp0we//fYbNmzYgGeffRbz58+Hh4cHmjZtiqSkpBKXGzBgAGbPno3Q0FCYTCZs3bqVV1MkukMU1f3zzz8PJycnJCQkIDg4uNSaB4B27dppNR8REYGYmJhb1GMiulEVEiAOH
DiAtWvXYu7cuUhLS8OMGTPw0ksvwWKxICwsDB06dMCpU6eKXTY7OxuDBw/G888/j+TkZHTr1g1du3ZFSkpKRXSViG6Sorrv2bMnpk2bhgkTJmDmzJk4e/YsHn744RJr/ty5c0hMTNRq/sknn0Tv3r1Z80S3uQo5jfPTTz/F8OHDkZWVhaZNm8LR0RFHjx7F0aNHUbduXWRnZ2P48OEYOXKkbrk2bdrgxx9/REFBgb6TBgMGDhyIWbNm2TxWXl6e7iJKFosFmZmZ8PLygsFguNmbRlSliAguXLgAf39/2NmV/nmjqO7Dw8O1ms/IyEBeXh7c3d0xatQom5oHAHd3d1y4cEE3r7SaB1j3RBVFpeYr5Iekxo8fL+Hh4ZKXlycGg0EcHBzk008/FU9PT2nRooUYjUbp0KGDzXIff/yxuLm5idFolI8++kh8fX2lY8eO8sorr0ijRo2KfayiH0HhjTfeKu527NixctV9WFiYruZ/++03cXR0FHt7+2JrXkTE2dlZjEajJCYmanVft27dEmuedc8bbxV/K0/NK+2BGDlyJCZNmlRqmwMHDuDLL7/EvHnzsGnTJgQEBKB79+74/PPP4ePjg9GjR2PkyJEwm8266xYUcXBwgMlkwqVLl5CYmIh27dphzJgx+N///oeTJ0/atD9//rx2vfqivxs0aIDWYS/Cwd6Eo128dO0v187XphuHpWvTL/tt1LVbltlUm/5+WyNt2vWINZGda3RFt0zEu39p038MC7Ju0wX9JyLzn9Yhd1/2szZ9ZPz9unbVAq0Hnxl/ctema36SrGt37JVIFMdO3z1U339Vmy4wWft0toH1ynoujc/qlsndZR2/6gese4b+fkS/l+iet45o02mD7tGmnU/qn16ux619sAyyPtbVhT66dgWO1v4Zrnmo515dpWuXkNZcmzbN9NSmr7ha/09Z4forB14OsA5Mw3Dr/yx1S4iu3X+esD4nPlvfVpuu9a+/de3+3uWvTXeIsV7Q65uf79O1c/zb2g/PVOtGWfqd0bXL3GMdi7VPT9Om266N06b3dfpMt8x9S57Tpv0aW69K+t/Qz3Xtnpv+sjbtnGHtQ+w7+nGd+EvHwr7l5uGvlyejLEV1/8knn+Dw4cNazQNAjRo1cOHCBdSoUaPMmgeg1b2XlxfOnDlj0x4oue7v7fE27Ks54nwdfXu3o9bpq07W59alAP3z01A7R5u+km29gmv1vdb/nf11V40/d811pYJXWq/p8Pcb+na+T1svEGX4Vz1tOtffWdfu727W52dwgnV+2nW/yG53wlGbNmZZt6lRp9917Q6siLA+7lXr9ua2tr6+OOxy06/bWqbI8bcu49tA/xqct6KmNt11SKI2veRQtK6d71TrVRMuvGod47PZrrp21UzXvGjts77mWRyv+z/Vtfbd2dH6mu4227odBXGZumUGB23RpuN/f0SbHhWxVtduYcYD2vSJVdbX8VxvfR8CH7C+dhRMqaFNZ4Xpr/zrkGudNl60XonUPFhfC78fs47l3jaLtOmobU9p055uObplTh22vj6Li7We+zTeqWsX6ZymTY+Za/1577z7LunaOTvnoyAnD78/Ow3Z2dnw8PBAaZSuhTF8+HD069ev1DZ16tSBr68vTp06hfz8wn9sdHQ0rl69iszMTPj7+yMoKAjHjx8vcR2XL19GUFCQ9lXG77//XmLbyZMnY/z48SqbQUSKkpKS4ObmVuL9RXV/+vRpAIU1DwBXr15FVlYW/P39bb6muFZRzVssFvzrX/8CAJuvMq/Fuie6DZS5j+IG7N+/XwDI6tWrBYDEx8fLunXrxGAwyPHjx6VBgwZiNpuLXdbHx0d69+4tycnJMn36dAEgRqNRIiIiim3/5ptvVvquHt54u9tv5dmdWVT3QGHNi4hW9+Hh4eWq+c2bN0vz5s0FgISEhJT4WKx73nir2NtN/wpDRceOHfHXX38hJSUFbdq0wfHjxxEdHY2FCxfCw8MDbm5ucHd3x
/z589GkSRMcOnQIixcvxtatW1FQUIAhQ4Zg2LBh8Pf3x88//4wGDRpg9+7dNo9T0sFU1apVQ2BgII4dOwZ3d3eb5aqK8+fPo3bt2hwHjgMA9XEQlQOqADz00EPYvHkz2rZti3HjxiE2NhZRUVH45ptvYDabsWPHDrRr105X9127doWnpyfmz5+PX3/9FUOHDkVaWhoiIyOLrXmAdV8WPt8LcRwKqYyDSs1X2OW8Fy1ahMGDByMlJQWbN29Gy5Yt8corr2Dw4MHIy8vDfffdhx9++AFvvfUWoqKi8MILL+D7779HcnIyLly4gL1796JLly7w8fHBTz/9hJo1axb7OCaTCSaT/jsns9msfT/q7u5epZ84RTgOhTgOhVTGoazvQa+1ZMkS+Pn5ITExEdu2bUPHjh3h5OSEvLw8NGjQAFeuXMHBgwcxdepULFq0CEajEfb29ti6dSvq1q2LgIAA1KpVC2lpaSXWPMC6Ly+OQyGOQ6HyjkN5a77CfkiqevXqWLZsGZo0aYJWrVohPT0dLVq0QHJyMsxmMzp27AgRQX5+Pk6cOIHatWtjy5YtOH/+PJYvXw4vLy8sXLgQq1evRkBAACIiIsp+UCKqVL6+vlrN+/j44Ntvv0VKSgrMZjPatGmD4OBgtG7dGtWqVQMA1K5dG3v27MHy5csRGhqKjIwMZGdns+aJ7gAVfi2MV155BTt37sS4ceOwZ88eREZG4sqVK4iNLTykODAwEH5+flr7cePGwcPDA9999x22b9+ORo0a4cyZM+jfv39Fd5WIbgLWPFHVUGFfYRTp1asXTp8+jXfeeQcZGRlo3Lgx1q5dq+2eTE9P133PkpWVhQEDBiAjIwOenp6IiorCtm3bUL9+faXHNZlMGD16tM1uzqqG41CI41DoVoxDZdU8wP9zEY5DIY5DoYoahwo7iJKIiIjuXrycNxERESljgCAiIiJlDBBERESkjAGCiIiIlN2VAWLGjBkIDg6Go6MjmjZtiqSkpMruUoWKj4/H/fffDzc3N/j4+KBr1644ePCgrs3ly5cRFxcHLy8vuLq6onv37sVenOxuMnHiRBgMBgwdOlSbV1XG4fjx4/j3v/8NLy8vODk54d5778WuXdYLfYkI3nnnHfj5+cHJyQkxMTFITU2txB7/c6x71j3Aur+ldV/mj13fYZYuXSpGo1G7lPCAAQPEbDbLyZMnK7trFaZDhw6SkJAgKSkpsmfPHnn00UclMDBQLl68qLUZNGiQ1K5dWzZu3Ci7du2SBx54QJo3b16Jva5YSUlJEhwcLI0aNZIhQ4Zo86vCOGRmZkpQUJD069dPdu7cKYcPH5Z169bJn3/+qbWZOHGieHh4yFdffSV79+6Vzp07S0hIiOTm5lZiz28c6551L8K6v9V1f9cFiCZNmkhcXJz2d0FBgfj7+2sX96kKTp06JQBky5YtIiKSnZ0t1apVkxUrVmhtDhw4IABk+/btldXNCnPhwgUJCwuTDRs2SOvWrbUXkqoyDq+//rq0bNmyxPstFov4+vrKlClTtHnZ2dliMplkyZIlt6KLNx3rnnXPur/1dX9XfYWRn5+P3bt3IyYmRptnZ2eHmJgYbN++vRJ7dmudO3cOQOHPiQPA7t27ceXKFd24REREIDAw8K4cl7i4OHTq1Em3vUDVGYfVq1cjOjoaPXr0gI+PDyIjIzFnzhzt/iNHjiAjI0M3Dh4eHmjatOkdOQ6s+0Kse9b9ra77uypAnDlzBgUFBTYX4alZsyYyMjIqqVe3lsViwdChQ9GiRQs0bNgQAJCRkQGj0Qiz2axrezeOy9KlS/HLL78gPj7e5r6qMg6HDx/GzJkzERYWhnXr1mHw4MF4+eWXMW/ePADQtvVuqRPWPeuedV85dV/hP2VNt1ZcXBxSUlLw448/VnZXbrljx45hyJAh2LBhAxwdHSu7O5XGYrEgOjoaEyZMAABERkYiJSUFs2bNQt++fSu5d1QRWPes+8qo+7tqD4S3tzfs7
e1tjq49efIkfH19K6lXt86LL76Ib775Bps2bUKtWrW0+b6+vsjPz0d2drau/d02Lrt378apU6dw3333wcHBAQ4ODtiyZQumTZsGBwcH1KxZs0qMg5+fn811JOrVq4f09HQA0Lb1bqkT1j3rnnVfOXV/VwUIo9GIqKgobNy4UZtnsViwceNGNGvWrBJ7VrFEBC+++CJWrlyJxMREhISE6O6PiopCtWrVdONy8OBBpKen31Xj0q5dO+zbtw979uzRbtHR0ejTp482XRXGoUWLFjan8/3xxx8ICgoCAISEhMDX11c3DufPn8fOnTvvyHFg3bPuWfeVVPc3dOjlbWzp0qViMpnks88+k/3798vAgQPFbDZLRkZGZXetwgwePFg8PDxk8+bNcuLECe2Wk5OjtRk0aJAEBgZKYmKi7Nq1S5o1aybNmjWrxF7fGtcejS1SNcYhKSlJHBwcZPz48ZKamiqLFi0SZ2dnWbhwodZm4sSJYjabZdWqVfLrr79Kly5d7vjTOFn3rPsirPtbU/d3XYAQEZk+fboEBgaK0WiUJk2ayI4dOyq7SxUKQLG3hIQErU1ubq688MIL4unpKc7OzvLEE0/IiRMnKq/Tt8j1LyRVZRy+/vpradiwoZhMJomIiJDZs2fr7rdYLPL2229LzZo1xWQySbt27eTgwYOV1Nubg3XPui/Cur81dc/LeRMREZGyu+oYCCIiIro1GCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIG
QMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlDFAEBERkTIGCCIiIlLGAEFERETKGCCIiIhIGQMEERERKWOAICIiImUMEERERKSMAYKIiIiUMUAQERGRMgYIIiIiUsYAQURERMoYIIiIiEgZAwQREREpY4AgIiIiZQwQREREpIwBgjRjxoyBwWDAmTNnKrsrmn79+iE4OLiyu0F012Ld041igKA7nsViwWeffYbOnTujdu3acHFxQcOGDfHuu+/i8uXLld09Iqogc+bMQevWrVGzZk2YTCaEhIQgNjYWaWlpld21KsGhsjtA9E/l5OQgNjYWDzzwAAYNGgQfHx9s374do0ePxsaNG5GYmAiDwVDZ3SSimyw5ORkhISHo3LkzPD09ceTIEcyZMwfffPMN9u7dC39//8ru4l2NAYLueEajET/99BOaN2+uzRswYACCg4O1EBETE1OJPSSiivC///3PZl7Xrl0RHR2N+fPnY+TIkZXQq6qDX2GQjezsbPTr1w9msxkeHh6IjY1FTk6OTbuFCxciKioKTk5OqF69Op566ikcO3ZM12br1q3o0aMHAgMDYTKZULt2bQwbNgy5ubk26/vqq6/QsGFDODo6omHDhli5cmW5+ms0GnXhocgTTzwBADhw4EC51kNUld1pdV+SomMnsrOz/9F6qGzcA0E2evbsiZCQEMTHx+OXX37B3Llz4ePjg0mTJmltxo8fj7fffhs9e/ZE//79cfr0aUyfPh2tWrVCcnIyzGYzAGDFihXIycnB4MGD4eXlhaSkJEyfPh1//fUXVqxYoa1v/fr16N69O+rXr4/4+HicPXsWsbGxqFWr1g1vR0ZGBgDA29v7htdBVFXcyXV/9uxZFBQUID09HePGjQMAtGvX7p8PCpVOiP6/0aNHCwB59tlndfOfeOIJ8fLy0v5OS0sTe3t7GT9+vK7dvn37xMHBQTc/JyfH5nHi4+PFYDDI0aNHtXmNGzcWPz8/yc7O1uatX79eAEhQUNANbU9MTIy4u7tLVlbWDS1PVBXcDXVvMpkEgAAQLy8vmTZtWrmXpRvHrzDIxqBBg3R/P/jggzh79izOnz8PAPjyyy9hsVjQs2dPnDlzRrv5+voiLCwMmzZt0pZ1cnLSpi9duoQzZ86gefPmEBEkJycDAE6cOIE9e/agb9++8PDw0Nq3b98e9evXv6FtmDBhAr7//ntMnDhR+1RERCW7k+v+u+++w5o1a/D+++8jMDAQly5dUt5+UsevMMhGYGCg7m9PT08AQFZWFtzd3ZGamgoRQVhYWLHLV6tWTZtOT0/HO++8g9WrVyMrK0vX7ty5cwCAo0ePAkCx67vnnnvwyy+/KPV/2bJleOutt/Dcc89h8ODBSssSVVV3ct0/9NBDAICOHTuiS5cuaNiwIVxdXfHiiy+Wex2kjgGCbNjb2xc7X0QAFP7ugsFgwHfffVdsW1dXVwBAQUEB2rdvj8zMTLz++uuIiIiAi4sLjh8/jn79+sFisdz0vm/YsAHPPPMMOnXqhFmzZt309RPdre7kur9WaGgoIiMjsWjRIgaICsYAQcpCQ0MhIggJCUF4eHiJ7fbt24c//vgD8+bNwzPPPKPN37Bhg65dUFAQACA1NdVmHQcPHix3v3bu3IknnngC0
dHRWL58ORwc+PQmullu17ovTm5uLvLy8v7ROqhsPAaClHXr1g329vYYO3as9umkiIjg7NmzAKyfaK5tIyKYOnWqbhk/Pz80btwY8+bN03ZvAoUvOPv37y9Xnw4cOIBOnTohODgY33zzje47WCL65263ur969arN1yMAkJSUhH379iE6Orr8G0c3hB/RSFloaCjeffddjBo1CmlpaejatSvc3Nxw5MgRrFy5EgMHDsSIESMQERGB0NBQjBgxAsePH4e7uzu++OKLYos+Pj4enTp1QsuWLfHss88iMzMT06dPR4MGDXDx4sVS+3PhwgV06NABWVlZePXVV/Htt9/a9LdZs2Y3dQyIqprbre4vXryI2rVro1evXmjQoAFcXFywb98+JCQkwMPDA2+//XZFDQUVufUnftDtquh0rtOnT+vmJyQkCAA5cuSIbv4XX3whLVu2FBcXF3FxcZGIiAiJi4uTgwcPam32798vMTEx4urqKt7e3jJgwADZu3evAJCEhASb9dWrV09MJpPUr19fvvzyS+nbt2+Zp3MdOXJEO4WruFvfvn3/wagQ3d3u1LrPy8uTIUOGSKNGjcTd3V2qVasmQUFB8txzz9n0mSqGQeS6fVFEREREZeAxEERERKSMAYKIiIiUMUAQERGRsgoLEJmZmejTpw/c3d1hNpvx3HPPlXlUbZs2bWAwGHS3639elYhuX6x7oqqjwgJEnz598Ntvv2HDhg149tlnMX/+fHh4eKBp06ZISkoqcbkBAwZg9uzZCA0NhclkwtatW7FmzZqK6iYR3URFdf/888/DyckJCQkJCA4OLrXmgcIrJxbVfEREBGJiYm5Rj4noRlVIgDhw4ADWrl2LuXPnIi0tDTNmzMBLL70Ei8WCsLAwdOjQAadOnSp22ezsbAwePBjPP/88kpOT0a1bN3Tt2hUpKSkV0VUiukmK6r5nz56YNm0aJkyYgJkzZ+Ls2bN4+OGHS6z5c+fOITExUav5J598Er1792bNE93mKuQ0zk8//RTDhw9HVlYWmjZtCkdHRxw9ehRHjx5F3bp1kZ2djeHDh2PkyJG65dq0aYMff/wRBQUF+k4aDBg4cGCx1zbIy8vT/WSpxWJBZmYmvLy8YDAYbvamEVUpIoILFy7A398fdnalf94oqvvw8HCt5jMyMpCXlwd3d3eMGjXKpuYBwN3dHRcuXNDNK63mAdY9UUVRqfkK+SGp8ePHS3h4uOTl5YnBYBAHBwf59NNPxdPTU1q0aCFGo1E6dOhgs9zHH38sbm5uYjQa5aOPPhJfX1/p2LGjvPLKK9KoUaNiH6voR1B44423irsdO3asXHUfFhamq/nffvtNHB0dxd7evtiaFxFxdnYWo9EoiYmJWt3XrVu3xJpn3fPGW8XfylPzSnsgRo4ciUmTJpXa5sCBA/jyyy8xb948bNq0CQEBAejevTs+//xz+Pj4YPTo0Rg5ciTMZjOOHTtms7yDgwNMJhMuXbqExMREtGvXDmPGjMH//vc/nDx50qb9+fPntevVF/3doEED1J7+GuycTHilsf4CLrM+6axN+83fp03LPUG6dke6uGnT3vdad736u1ofa//3+svQ1nrQuj15H/lq0xnN9L8Y3rdjojb95Yy22nTmffo9Lz3vt35vvGGu9aeY8930n7C2xM3Uppv80E+b7tVgt67d9jMh2vRfyf7atN01D2t/Sb/uex7+U5s+tiBUm86O0DWDwS9Xm647+qz1cbrW1rWr38V6kZyUNfdo0zmBV3XtXNKsY3bF3foU3f2fT3Tt7v+kvzad523dELt863ZUC9QfxGde6apNZ7S1Pq7BqL9KoO+31ssTu6zaZV23m6uunSHA+r8ucDNp0yeb6Nvl+FvXHzYnQ5v+u4O/rt2FMOt2tIj8XZv+cb/1AkbGU/rnVPD71ufysRcaadNBi4/q2uGaTxTiYr1eyPH23rpm7umFfSi4chm7vxuPshTV/SeffILDhw9rNQ8ANWrUwIULF
1CjRo0yax6AVvdeXl44c+ZMsY9XUt2HzhkGe2cT7Le569pfvebSKAFbrM+Hvx7S/4/s/mVdpyHJug77y9Y2ct0FAGokW5/7DuetDXP99Ov+65FrPs1dU3Oh9f7WtTt0wPp8CF1hXbf9mLO6dhkXrP3zmOti7cNl/evIkSeM2nT4rExtOjPK+j9363Nct0zmVwHatM/2bG36i6++0LWLWvicNi3BOdr0lTz9INmftfZh25PW16un+8fq2l32tLY72fWaPUxZJl07Q4G1vuvUt/b90HEfbbrWSv2n52Mdrcs4/m3tn/GcrhnO32N9TYiYZH2+Xq6nr1P7XGu7Y+2s4x/8dbauXdrjZm3a+ZT1tcx7rv64oMOT7rfed81Lt+fGQ9p0ej/9e45vkvX5dmqwdbzcPnfTtTMnWcfo8Hhrfyxp+ueo7/YCXL16Gbu+n4Ds7Gx4eHigNErXwhg+fDj69etXaps6derA19cXp06dQn5+PgAgOjoaV69eRWZmJvz9/REUFITjx4+XuI7Lly8jKChI+yrj999/L7Ht5MmTMX582S9wRHTjkpKS4ObmVuL9RXV/+vRpANAuZFR0wSN/f3+brymuVVTzFosF//rXvwDA5qvMa7HuiW4DZe6juAH79+8XALJ69WoBIPHx8bJu3ToxGAxy/PhxadCggZjN5mKX9fHxkd69e0tycrJMnz5dAIjRaJSIiIhi27/55puVvquHN97u9lt5dmcW1T1QWPMiotV9eHh4uWp+8+bN0rx5cwEgISEhJT4W65433ir2dtO/wlDRsWNH/PXXX0hJSUGbNm1w/PhxREdHY+HChfDw8ICbmxvc3d0xf/58NGnSBIcOHcLixYuxdetWFBQUYMiQIRg2bBj8/f3x888/o0GDBti9e7fN45R0MFW1atUQGBiIY8eOwd3d3Wa5quL8+fOoXbs2x4HjAEB9HETlgCoADz30EDZv3oy2bdti3LhxiI2NRVRUFL755huYzWbs2LED7dq109V9165d4enpifnz5+PXX3/F0KFDkZaWhsjIyGJrHmDdl4XP90Ich0Iq46BS8xV2Oe9FixZh8ODBSElJwebNm9GyZUu88sorGDx4MPLy8nDffffhhx9+wFtvvYWoqCi88MIL+P7775GcnIwLFy5g79696NKlC3x8fPDTTz+hZs2axT6OyWSCyaT/fsxsNmvfj7q7u1fpJ04RjkMhjkMhlXEo63vQay1ZsgR+fn5ITEzEtm3b0LFjRzg5OSEvLw8NGjTAlStXcPDgQUydOhWLFi2C0WiEvb09tm7dirp16yIgIAC1atVCWlpaiTUPsO7Li+NQiONQqLzjUN6ar7AfkqpevTqWLVuGJk2aoFWrVkhPT0eLFi2QnJwMs9mMjh07QkSQn5+PEydOoHbt2tiyZQvOnz+P5cuXw8vLCwsXLsTq1asREBCAiIiIsh+UiCqVr6+vVvM+Pj749ttvkZKSArPZjDZt2iA4OBitW7dGtWqFB6jWrl0be/bswfLlyxEaGoqMjAxkZ2ez5onuABV+LYxXXnkFO3fuxLhx47Bnzx5ERkbiypUriI0tPPo2MDAQfn5+Wvtx48bBw8MD3333HbZv345GjRrhzJkz6N+/f0kPQUS3EdY8UdVQYV9hFOnVqxdOnz6Nd955BxkZGWjcuDHWrl2r7Z5MT0/Xfc+SlZWFAQMGICMjA56enoiKisK2bdtQv359pcc1mUwYPXq0zW7OqobjUIjjUOhWjENl1TzA/3MRjkMhjkOhihqHCjuIkoiIiO5evJw3ERERKWOAICIiImUMEERERKSMAYKIiIiU3ZUBYsaMGQgODoajoyOaNm2KpKSkshe6g8XHx+P++++Hm5sbfHx80LVrVxw8eFDX5vLly4iLi4OXlxdcXV3RvXv3Yi9OdjeZOHEiDAYDhg4dqs2rKuNw/Phx/Pvf/4aXlxecnJxw7733Ytcu6wXBRATvvPMO/Pz84OTkhJiYGKSmplZij/851
j3rHmDd39K6L/PHru8wS5cuFaPRqF1KeMCAAWI2m+XkyZOV3bUK06FDB0lISJCUlBTZs2ePPProoxIYGCgXL17U2gwaNEhq164tGzdulF27dskDDzwgzZs3r8ReV6ykpCQJDg6WRo0ayZAhQ7T5VWEcMjMzJSgoSPr16yc7d+6Uw4cPy7p16+TPP//U2kycOFE8PDzkq6++kr1790rnzp0lJCREcnNzK7HnN451z7oXYd3f6rq/6wJEkyZNJC4uTvu7oKBA/P39tYv7VAWnTp0SALJlyxYREcnOzpZq1arJihUrtDYHDhwQALJ9+/bK6maFuXDhgoSFhcmGDRukdevW2gtJVRmH119/XVq2bFni/RaLRXx9fWXKlCnavOzsbDGZTLJkyZJb0cWbjnXPumfd3/q6v6u+wsjPz8fu3bsRExOjzbOzs0NMTAy2b99eiT27tc6dK7zAffXq1QEAu3fvxpUrV3TjEhERgcDAwLtyXOLi4tCpUyfd9gJVZxxWr16N6Oho9OjRAz4+PoiMjMScOXO0+48cOYKMjAzdOHh4eKBp06Z35Diw7gux7ln3t7ru76oAcebMGRQUFNhchKdmzZrIyMiopF7dWhaLBUOHDkWLFi3QsGFDAEBGRgaMRiPMZrOu7d04LkuXLsUvv/yC+Ph4m/uqyjgcPnwYM2fORFhYGNatW4fBgwfj5Zdfxrx58wBA29a7pU5Y96x71n3l1H2F/5Q13VpxcXFISUnBjz/+WNldueWOHTuGIUOGYMOGDXB0dKzs7lQai8WC6OhoTJgwAQAQGRmJlJQUzJo1C3379q3k3lFFYN2z7iuj7u+qPRDe3t6wt7e3Obr25MmT8PX1raRe3TovvvgivvnmG2zatAm1atXS5vv6+iI/Px/Z2dm69nfbuOzevRunTp3CfffdBwcHBzg4OGDLli2YNm0aHBwcULNmzSoxDn5+fjbXkahXrx7S09MBQNvWu6VOWPese9Z95dT9XRUgjEYjoqKisHHjRm2exWLBxo0b0axZs0rsWcUSEbz44otYuXIlEhMTERISors/KioK1apV043LwYMHkZ6efleNS7t27bBv3z7s2bNHu0VHR6NPnz7adFUYhxYtWticzvfHH38gKCgIABASEgJfX1/dOJw/fx47d+68I8eBdc+6Z91XUt3f0KGXt7GlS5eKyWSSzz77TPbv3y8DBw4Us9ksGRkZld21CjN48GDx8PCQzZs3y4kTJ7RbTk6O1mbQoEESGBgoiYmJsmvXLmnWrJk0a9asEnt9a1x7NLZI1RiHpKQkcXBwkPHjx0tqaqosWrRInJ2dZeHChVqbiRMnitlsllWrVsmvv/4qXbp0ueNP42Tds+6LsO5vTd3fdQFCRGT69OkSGBgoRqNRmjRpIjt27KjsLlUoAMXeEhIStDa5ubnywgsviKenpzg7O8sTTzwhJ06cqLxO3yLXv5BUlXH4+uuvpWHDhmIymSQiIkJmz56tu99iscjbb78tNWvWFJPJJO3atZODBw9WUm9vDtY9674I6/7W1D0v501ERETK7qpjIIiIiOjWYIAgIiIiZQwQREREpIwBgoiIiJQxQBAREZEyBggiIiJSxgBBREREyhggiIiISBkDBBERESljgCAiIiJlDBBERESkjAGCiIiIlP0/xKP8vbVtw4oAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "import math, os\n", + "from typing import Optional, Tuple, List\n", + "import numpy as np\n", + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "from 
class SinusoidalPositionalEncoding(nn.Layer):
    """Add fixed sine/cosine positional encodings to a (B, T, D) sequence.

    The table is built once with NumPy and stored as a non-persistable buffer
    named ``pe`` of shape (max_len, d_model): even columns hold sines, odd
    columns hold cosines.

    Args:
        d_model: Width of the last axis of the sequences to be encoded.
            Both even and odd widths are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        position = np.arange(0, max_len, dtype="float32")[:, None]
        div_term = np.exp(
            np.arange(0, d_model, 2, dtype="float32") * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = np.zeros((max_len, d_model), dtype="float32")
        pe[:, 0::2] = np.sin(angles)
        # There are only floor(d_model / 2) odd columns; trimming the angle
        # table keeps this assignment valid when d_model is odd.
        pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", paddle.to_tensor(pe), persistable=False)

    def forward(self, x):
        """Return ``x`` plus the first T rows of the table; ``x`` is (B, T, D).

        Raises:
            ValueError: If T exceeds the number of precomputed positions.
        """
        T = x.shape[1]
        max_len = self.pe.shape[0]
        if T > max_len:
            # Fail with a clear message instead of an opaque broadcast error
            # (or a silent broadcast when the table has a single row).
            raise ValueError(
                f"sequence length {T} exceeds positional table size {max_len}"
            )
        return x + self.pe[:T, :]
identity = x\n", + " out = self.relu(self.bn1(self.conv1(x)))\n", + " out = self.bn2(self.conv2(out))\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + " out = self.relu(out + identity)\n", + " return out\n", + "\n", + "class ResNet3D(nn.Layer):\n", + " def __init__(self, block, layers, in_channels=20, base_width=64):\n", + " super().__init__()\n", + " self.in_planes = base_width\n", + " self.conv1 = nn.Conv3D(in_channels, self.in_planes, kernel_size=(3,7,7),\n", + " stride=(1,2,2), padding=(1,3,3), bias_attr=False)\n", + " self.bn1 = nn.BatchNorm3D(self.in_planes)\n", + " self.relu = nn.ReLU()\n", + " self.maxpool = nn.MaxPool3D(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))\n", + " self.layer1 = self._make_layer(block, base_width, layers[0], stride=(1,1,1))\n", + " self.layer2 = self._make_layer(block, base_width*2, layers[1], stride=(2,2,2))\n", + " self.layer3 = self._make_layer(block, base_width*4, layers[2], stride=(2,2,2))\n", + " self.layer4 = self._make_layer(block, base_width*8, layers[3], stride=(2,2,2))\n", + " self.out_dim = base_width*8 # 512\n", + " self.pool = nn.AdaptiveAvgPool3D(output_size=1)\n", + " def _make_layer(self, block, planes, blocks, stride=(1,1,1)):\n", + " downsample = None\n", + " if stride != (1,1,1) or self.in_planes != planes * block.expansion:\n", + " downsample = nn.Sequential(\n", + " nn.Conv3D(self.in_planes, planes * block.expansion, 1, stride=stride, bias_attr=False),\n", + " nn.BatchNorm3D(planes * block.expansion),\n", + " )\n", + " layers = [block(self.in_planes, planes, stride=stride, downsample=downsample)]\n", + " self.in_planes = planes * block.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(block(self.in_planes, planes))\n", + " return nn.Sequential(*layers)\n", + " def forward(self, x): # (B, C, D, H, W)\n", + " x = self.relu(self.bn1(self.conv1(x)))\n", + " x = self.maxpool(x)\n", + " x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = 
class MoE(nn.Layer):
    """Top-k mixture-of-experts feed-forward block with residual + LayerNorm.

    The routing probabilities and the selected expert indices of the most
    recent forward pass are cached (detached) in ``last_router_probs`` and
    ``last_topk_idx`` so they can be inspected for interpretability or
    clustering.

    Args:
        d_model: Width of the last axis of the input.
        cfg: ``MoEConfig`` providing ``n_experts``, ``top_k``, ``d_ff``,
            ``dropout``, ``router_temp`` and ``use_gumbel``.

    Raises:
        ValueError: If ``cfg.top_k`` is not in ``[1, cfg.n_experts]`` or
            ``cfg.router_temp`` is not positive.
    """

    def __init__(self, d_model: int, cfg: MoEConfig):
        super().__init__()
        # Fail at construction time rather than deep inside forward().
        if not 1 <= cfg.top_k <= cfg.n_experts:
            raise ValueError(
                f"top_k must be in [1, n_experts={cfg.n_experts}], got {cfg.top_k}"
            )
        if cfg.router_temp <= 0:
            raise ValueError(f"router_temp must be positive, got {cfg.router_temp}")
        self.cfg = cfg
        self.router = nn.Linear(d_model, cfg.n_experts)
        self.experts = nn.LayerList(
            [ExpertFFN(d_model, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_experts)]
        )
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.last_router_probs = None
        self.last_topk_idx = None

    def _router_probs(self, logits):
        """Temperature-scaled softmax over experts, with optional Gumbel noise.

        Noise is only added when ``cfg.use_gumbel`` is set and the layer is in
        training mode.
        """
        if self.cfg.use_gumbel and self.training:
            u = paddle.uniform(logits.shape, min=1e-6, max=1 - 1e-6, dtype=logits.dtype)
            logits = logits + (-paddle.log(-paddle.log(u)))
        return F.softmax(logits / self.cfg.router_temp, axis=-1)

    def forward(self, x):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: Tensor of shape (..., D). Any number of leading axes is
                accepted; they are flattened into a token axis internally.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = list(x.shape)
        D = orig_shape[-1]
        X = x.reshape([-1, D])  # (N, D) with N = product of leading axes
        N = X.shape[0]

        probs = self._router_probs(self.router(X))  # (N, E)
        topk_val, topk_idx = paddle.topk(probs, k=self.cfg.top_k, axis=-1)

        # Every expert is evaluated on every token and the selected outputs
        # are gathered afterwards (dense evaluation, cost grows with E).
        all_out = paddle.stack([expert(X) for expert in self.experts], axis=1)  # (N, E, D)
        rows = paddle.arange(N, dtype='int64')
        picked = paddle.stack(
            [
                paddle.gather_nd(
                    all_out,
                    paddle.stack([rows, topk_idx[:, i].astype('int64')], axis=1),
                )
                for i in range(self.cfg.top_k)
            ],
            axis=1,
        )  # (N, k, D)

        # Renormalise the selected probabilities so the k weights sum to 1.
        w = topk_val / (paddle.sum(topk_val, axis=-1, keepdim=True) + 1e-9)
        Y = paddle.sum(picked * w.unsqueeze(-1), axis=1)  # (N, D)
        Y = self.ln(self.drop(Y) + X)

        self.last_router_probs = probs.detach()
        self.last_topk_idx = topk_idx.detach()
        return Y.reshape(orig_shape)
class MultiHeadSelfAttention(nn.Layer):
    """Hand-written multi-head self-attention that records its weights.

    After each forward pass ``last_attn`` holds the detached softmax
    attention weights of shape (B, H, T, T). The residual connection and
    normalisation are left to the caller.
    """

    def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model, self.nhead, self.d_head = d_model, nhead, d_model // nhead
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.last_attn = None  # (B, H, T, T)

    def _to_heads(self, t, batch, steps):
        """Reshape (B, T, D) into per-head layout (B, H, T, d_head)."""
        per_head = t.reshape([batch, steps, self.nhead, self.d_head])
        return per_head.transpose([0, 2, 1, 3])

    def forward(self, x):
        """Apply scaled dot-product self-attention to ``x`` of shape (B, T, D)."""
        batch, steps, width = x.shape
        queries, keys, values = self.Wq(x), self.Wk(x), self.Wv(x)
        heads_q = self._to_heads(queries, batch, steps)
        heads_k = self._to_heads(keys, batch, steps)
        heads_v = self._to_heads(values, batch, steps)

        logits = paddle.matmul(heads_q, heads_k, transpose_y=True) / math.sqrt(self.d_head)
        weights = F.softmax(logits, axis=-1)  # (B, H, T, T)
        self.last_attn = weights.detach()

        mixed = paddle.matmul(weights, heads_v)  # (B, H, T, d_head)
        merged = mixed.transpose([0, 2, 1, 3]).reshape([batch, steps, width])
        return self.drop(self.proj(merged))
class AFNO1DLayer(nn.Layer):
    """Frequency-domain token mixer along the time axis of a (B, T, D) input.

    The layer-normalised input is transformed with a real FFT over time, the
    lowest ``K = min(modes, T // 2 + 1)`` frequency bins are passed through a
    two-layer block-diagonal complex linear map (GELU + soft-shrink between
    the layers), the remaining bins are zeroed, and the result is transformed
    back and added to the input as a residual.

    Args:
        d_model: Channel width D; must be divisible by ``num_blocks``.
        modes: Maximum number of low-frequency bins that are kept.
        num_blocks: Number of channel groups, each with its own weights.
        shrink: Soft-shrink threshold applied to real and imaginary parts.
        dropout: Dropout probability applied before the residual addition.
    """

    def __init__(self, d_model: int, modes: int = 32, num_blocks: int = 8, shrink: float = 0.01, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_blocks == 0
        self.d_model = d_model
        self.modes = modes
        self.num_blocks = num_blocks
        self.block = d_model // num_blocks
        scale = 1.0 / math.sqrt(self.block)
        init = nn.initializer.Uniform(-scale, scale)
        weight_shape = [num_blocks, self.block, self.block]
        # Real and imaginary parts of the two complex block-diagonal weights.
        self.w1r = self.create_parameter(weight_shape, default_initializer=init)
        self.w1i = self.create_parameter(weight_shape, default_initializer=init)
        self.w2r = self.create_parameter(weight_shape, default_initializer=init)
        self.w2i = self.create_parameter(weight_shape, default_initializer=init)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.shrink = shrink

    def _cl(self, xr, xi, Wr, Wi):
        """Complex matmul ``(xr + i*xi) @ (Wr + i*Wi)`` on separate real parts."""
        out_r = paddle.matmul(xr, Wr) - paddle.matmul(xi, Wi)
        out_i = paddle.matmul(xr, Wi) + paddle.matmul(xi, Wr)
        return out_r, out_i

    def forward(self, x):
        """Mix ``x`` of shape (B, T, D) in the frequency domain; returns (B, T, D)."""
        B, T, D = x.shape
        Kmax = T // 2 + 1
        K = min(self.modes, Kmax)

        h_ft = paddle.fft.rfft(self.ln(x).transpose([0, 2, 1]))  # (B, D, Kmax)
        h_ft = h_ft.reshape([B, self.num_blocks, self.block, Kmax])
        xk = h_ft[:, :, :, :K].transpose([0, 1, 3, 2])  # (B, G, K, Cb)

        yr, yi = self._cl(paddle.real(xk), paddle.imag(xk), self.w1r, self.w1i)
        yr = F.softshrink(F.gelu(yr), threshold=self.shrink)
        yi = F.softshrink(F.gelu(yi), threshold=self.shrink)
        yr, yi = self._cl(yr, yi, self.w2r, self.w2i)

        yr = yr.transpose([0, 1, 3, 2]).reshape([B, D, K])
        yi = yi.transpose([0, 1, 3, 2]).reshape([B, D, K])
        if K < Kmax:
            # Zero-fill the discarded high-frequency bins. Padding the real
            # and imaginary parts (instead of writing into a preallocated
            # complex64 buffer) keeps the spectrum in the input's precision
            # and avoids an in-place slice assignment.
            pad = paddle.zeros([B, D, Kmax - K], dtype=yr.dtype)
            yr = paddle.concat([yr, pad], axis=-1)
            yi = paddle.concat([yi, pad], axis=-1)
        out_ft = paddle.complex(yr, yi)  # (B, D, Kmax)

        out = paddle.fft.irfft(out_ft, n=T).transpose([0, 2, 1])  # (B, T, D)
        return x + self.drop(out)
class MultiHeadCrossAttention(nn.Layer):
    """Multi-head cross-attention with a built-in residual + LayerNorm.

    Queries come from ``q`` and keys/values from ``kv``. The detached
    attention weights of the latest call are kept in ``last_attn`` with
    shape (B, H, Nq, Nk).
    """

    def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_head = d_model // nhead
        self.nhead = nhead
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model)
        self.last_attn = None  # (B, H, Nq, Nk)

    def _to_heads(self, t, batch):
        """Reshape (B, N, D) into per-head layout (B, H, N, d_head)."""
        return t.reshape([batch, -1, self.nhead, self.d_head]).transpose([0, 2, 1, 3])

    def forward(self, q, kv):
        """Attend from ``q`` (B, Nq, D) over ``kv`` (B, Nk, D); returns (B, Nq, D)."""
        batch, n_query, width = q.shape
        heads_q = self._to_heads(self.Wq(q), batch)
        heads_k = self._to_heads(self.Wk(kv), batch)
        heads_v = self._to_heads(self.Wv(kv), batch)

        logits = paddle.matmul(heads_q, heads_k, transpose_y=True) / math.sqrt(self.d_head)
        weights = F.softmax(logits, axis=-1)  # (B, H, Nq, Nk)
        self.last_attn = weights.detach()

        mixed = paddle.matmul(weights, heads_v).transpose([0, 2, 1, 3])
        projected = self.drop(self.proj(mixed.reshape([batch, n_query, width])))
        # Post-norm residual around the query stream.
        return self.ln(projected + q)
dropout)\n", + " self.ca_t_from_v = MultiHeadCrossAttention(d_model, nhead, dropout)\n", + " self.fuse = nn.Sequential(nn.Linear(2*d_model, fuse_hidden), nn.ReLU(), nn.Dropout(dropout))\n", + " self.out_dim = fuse_hidden\n", + " self.last_attn_v_from_t = None # (B,H,1,1)\n", + " self.last_attn_t_from_v = None # (B,H,1,T)\n", + " def forward(self, video_seq, tabm_tok):\n", + " v_tok = video_seq.mean(axis=1, keepdim=True) # (B,1,D)\n", + " t_tok = tabm_tok.unsqueeze(1) # (B,1,D)\n", + " v_upd = self.ca_v_from_t(v_tok, t_tok)\n", + " t_upd = self.ca_t_from_v(t_tok, video_seq)\n", + " self.last_attn_v_from_t = self.ca_v_from_t.last_attn\n", + " self.last_attn_t_from_v = self.ca_t_from_v.last_attn\n", + " fused = paddle.concat([v_upd, t_upd], axis=-1).squeeze(1)\n", + " return self.fuse(fused)\n", + "\n", + "# ============ 总模型 ============\n", + "class TwoModalMultiLabelModel(nn.Layer):\n", + " def __init__(self, vid_channels=20, vid_frames=365, depth_n=24,\n", + " vec_dim=424, d_model=256, nhead=4, n_trans_layers=2, trans_ff=512,\n", + " tabm_hidden=256, dropout=0.1, num_labels=4,\n", + " moe_temporal_attn=True, moe_temporal_afno=True, moe_fused=False, moe_tabm=False,\n", + " afno_modes=32):\n", + " super().__init__()\n", + " self.vol_encoder = Volume3DEncoder(in_channels=vid_channels, dropout=dropout)\n", + " # 关键:3D ResNet 输出 512 → d_model 的输入投影\n", + " self.video_in = nn.Linear(self.vol_encoder.out_dim, d_model)\n", + "\n", + " self.trans_attn = TemporalTransformerFlexible(\n", + " d_model=d_model, nhead=nhead, num_layers=n_trans_layers, d_ff=trans_ff, dropout=dropout,\n", + " max_len=vid_frames, use_moe=moe_temporal_attn, moe_cfg=MoEConfig(d_ff=max(2048,trans_ff), n_experts=8),\n", + " capture_attn=True\n", + " )\n", + " self.trans_afno = AFNOTransformerFlexible(\n", + " d_model=d_model, num_layers=n_trans_layers, modes=afno_modes, dropout=dropout,\n", + " d_ff=trans_ff, use_moe=moe_temporal_afno, moe_cfg=MoEConfig(d_ff=max(2048,trans_ff), n_experts=8)\n", + " 
class GradCAM3D:
    """Grad-CAM over the activations captured by ``model.vol_encoder``.

    After a forward/backward pass the wrapped model must expose the hooked
    activations and their gradients as ``vol_encoder._feat`` and
    ``vol_encoder._grad``.
    """

    def __init__(self, model: "TwoModalMultiLabelModel"):
        self.model = model

    def _trilinear_upsample(self, vol, out_shape):
        """Linearly resize the 3-D NumPy array ``vol`` to ``out_shape``.

        SciPy is optional: when it is not installed the volume is returned
        unchanged. Any other failure is propagated instead of being hidden.
        """
        try:
            from scipy.ndimage import zoom
        except ImportError:
            return vol
        factors = tuple(out_shape[i] / vol.shape[i] for i in range(3))
        return zoom(vol, factors, order=1)

    def generate(self, x_video, x_vec, target_class: int = 0, time_index: int = 0):
        """Compute a normalised Grad-CAM volume for one sample and one frame.

        Args:
            x_video: Video input of shape (1, T, C, H, W, N).
            x_vec: Tabular input for the same sample.
            target_class: Index of the logit that is back-propagated.
            time_index: Frame whose activations are explained. The encoder
                receives the frames as a batch of B*T volumes (see
                ``TwoModalMultiLabelModel.encode``), so with B == 1 the first
                axis of the captured activations indexes frames.

        Returns:
            NumPy array with values in [0, 1], resized to (N, H, W) when
            SciPy is available.

        Raises:
            IndexError: If ``time_index`` is outside the captured frames.
        """
        assert x_video.shape[0] == 1, "Grad-CAM 演示请用单样本 B=1"
        self.model.eval()
        self.model.clear_gradients()
        logits = self.model(x_video.astype('float32'), x_vec.astype('float32'))  # (1, num_labels)
        logits[0, target_class].backward()

        feat = self.model.vol_encoder._feat  # (B*T, C', D', H', W')
        grad = self.model.vol_encoder._grad
        assert (feat is not None) and (grad is not None), "未捕获到特征/梯度"

        feat_all = feat.numpy()
        grad_all = grad.numpy()
        n_frames = feat_all.shape[0]
        if not 0 <= time_index < n_frames:
            raise IndexError(f"time_index {time_index} out of range for {n_frames} frames")
        # Select the requested frame instead of always using frame 0.
        feat_np = feat_all[time_index]
        grad_np = grad_all[time_index]

        w = grad_np.mean(axis=(1, 2, 3))  # per-channel weights, (C',)
        cam = np.maximum(0, np.tensordot(w, feat_np, axes=(0, 0)))  # (D', H', W')
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        # Resize the map to the input voxel grid (N, H, W).
        B, T, C, H, W, N = x_video.shape
        return self._trilinear_upsample(cam, (N, H, W))


# ============ MoE routing clustering utilities ============
def kmeans_numpy(X: np.ndarray, K: int = 4, iters: int = 50, seed: int = 0):
    """Cluster the rows of ``X`` with plain Lloyd k-means.

    Args:
        X: Array of shape (N, D).
        K: Number of clusters, ``1 <= K <= N``.
        iters: Maximum number of update steps; 0 only assigns the points to
            the randomly chosen initial centroids.
        seed: Seed for choosing the K distinct initial centroids.

    Returns:
        ``(idx, cent)``: labels of shape (N,) and centroids of shape (K, D).
        The labels always refer to the returned centroids.

    Raises:
        ValueError: If ``X`` is not 2-D or ``K`` is outside ``[1, N]``.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (N, D), got shape {X.shape}")
    N, D = X.shape
    if not 1 <= K <= N:
        raise ValueError(f"K must be in [1, N={N}], got {K}")

    rng = np.random.default_rng(seed)
    cent = X[rng.choice(N, K, replace=False)]

    def _assign(centroids):
        # Squared Euclidean distance from every point to every centroid.
        dist2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # (N, K)
        return dist2.argmin(axis=1)

    for _ in range(iters):
        idx = _assign(cent)
        # An empty cluster keeps its previous centroid.
        new_cent = np.stack(
            [X[idx == k].mean(axis=0) if np.any(idx == k) else cent[k] for k in range(K)], 0
        )
        if np.allclose(new_cent, cent):
            break
        cent = new_cent
    # Recompute labels so they match the returned centroids, also when the
    # loop ran out of iterations or did not run at all.
    return _assign(cent), cent
\"temporal_attn\", topk_hist: bool = True):\n", + " model.eval()\n", + " vecs = []\n", + " for x_vid, x_vec, y in loader:\n", + " _ = model(x_vid.astype('float32'), x_vec.astype('float32'))\n", + " if branch == \"temporal_attn\":\n", + " moe = None\n", + " for lyr in model.trans_attn.layers[::-1]:\n", + " if hasattr(lyr, \"moe\"):\n", + " moe = lyr.moe; break\n", + " elif branch == \"temporal_afno\":\n", + " moe = model.trans_afno.moe if hasattr(model.trans_afno, \"moe\") else None\n", + " elif branch == \"tabm\":\n", + " moe = model.tabm_moe.moe if getattr(model, \"moe_tabm\", False) else None\n", + " else:\n", + " moe = model.fused_moe.moe if getattr(model, \"moe_fused\", False) else None\n", + " if moe is None or moe.last_router_probs is None:\n", + " continue\n", + " probs = moe.last_router_probs.numpy() # (N_tokens, E)\n", + " if topk_hist:\n", + " top1 = moe.last_topk_idx.numpy()[:,0] # (N_tokens,)\n", + " E = probs.shape[1]\n", + " hist = np.bincount(top1, minlength=E).astype(\"float32\")\n", + " hist = hist / (hist.sum() + 1e-9)\n", + " vecs.append(hist)\n", + " else:\n", + " vecs.append(probs.mean(axis=0))\n", + " return np.stack(vecs, 0) if len(vecs)>0 else None\n", + "\n", + "# ============ Toy 数据集 ============\n", + "class ToyTwoModalDataset(Dataset):\n", + " def __init__(self, n: int, seed: int = 0, T: int = 365, C: int = 20, H: int = 20, W: int = 20, N: int = 24):\n", + " super().__init__()\n", + " rng = np.random.default_rng(seed)\n", + " self.video = rng.normal(size=(n, T, C, H, W, N)).astype('float32')\n", + " self.vec = rng.normal(size=(n, 424)).astype('float32')\n", + " vid_hwn = self.video.mean(axis=(3,4,5))\n", + " vid_avg = vid_hwn.mean(axis=1)\n", + " Wv = rng.normal(size=(C,4)); Wt = rng.normal(size=(424,4))\n", + " logits = vid_avg @ Wv + self.vec @ Wt + rng.normal(scale=0.5, size=(n,4))\n", + " probs = 1.0 / (1.0 + np.exp(-logits))\n", + " self.y = (probs > 0.5).astype('float32')\n", + " def __getitem__(self, idx: int):\n", + " return 
self.video[idx], self.vec[idx], self.y[idx]\n", + " def __len__(self): return len(self.y)\n", + "\n", + "# ============ 小工具:绘图 ============\n", + "def show_heatmap_2d(arr2d: np.ndarray, title: str, save_path: Optional[str] = None):\n", + " plt.figure(); plt.imshow(arr2d, interpolation='nearest'); plt.title(title); plt.colorbar()\n", + " if save_path: plt.savefig(save_path, bbox_inches='tight');\n", + " plt.show(); plt.close()\n", + "\n", + "def show_attention_matrix(attn: np.ndarray, title: str, save_path: Optional[str] = None):\n", + " if attn.ndim == 4 and attn.shape[2] == 1 and attn.shape[3] == 1:\n", + " attn = attn[0,:,0,0][:,None] # (H,1)\n", + " elif attn.ndim == 4 and attn.shape[2] == 1:\n", + " attn = attn[0] # (H,1,T)\n", + " elif attn.ndim == 4:\n", + " attn = attn[0] # (H,T,T)\n", + " plt.figure(figsize=(5,4))\n", + " if attn.ndim == 2:\n", + " plt.imshow(attn, aspect='auto', interpolation='nearest')\n", + " elif attn.ndim == 3:\n", + " H = attn.shape[0]\n", + " cols = int(np.ceil(np.sqrt(H))); rows = int(np.ceil(H/cols))\n", + " fig, axes = plt.subplots(rows, cols, figsize=(3*cols, 3*rows))\n", + " axes = axes.flatten()\n", + " for h in range(H):\n", + " axes[h].imshow(attn[h], interpolation='nearest'); axes[h].set_title(f\"head {h}\")\n", + " for k in range(H, len(axes)): axes[k].axis('off')\n", + " fig.suptitle(title)\n", + " if save_path: fig.savefig(save_path, bbox_inches='tight')\n", + " plt.show(); plt.close(fig); return\n", + " plt.title(title); plt.colorbar()\n", + " if save_path: plt.savefig(save_path, bbox_inches='tight')\n", + " plt.show(); plt.close()\n", + "\n", + "# ============ Demo:可解释可视化 ============\n", + "if __name__ == \"__main__\":\n", + " # 1) 构造“已训练好”的模型(这里随机权重示意)\n", + " model = TwoModalMultiLabelModel(\n", + " vid_channels=20, vid_frames=365, depth_n=24,\n", + " vec_dim=424, d_model=256, nhead=4, n_trans_layers=2, trans_ff=512,\n", + " tabm_hidden=256, dropout=0.1, num_labels=4,\n", + " moe_temporal_attn=True, 
moe_temporal_afno=True,\n", + " moe_fused=False, moe_tabm=False, afno_modes=32\n", + " )\n", + " model.eval()\n", + "\n", + " # 2) 取一个样本\n", + " toy = ToyTwoModalDataset(n=8, seed=123, T=365, C=20, H=20, W=20, N=24)\n", + " x_video, x_vec, y = toy[0]\n", + " x_video = paddle.to_tensor(x_video[None, ...]) # (1,T,C,H,W,N)\n", + " x_vec = paddle.to_tensor(x_vec[None, ...]) # (1,424)\n", + "\n", + " # 3) 3D Grad-CAM:一次“有梯度”的前向 + 反传(不要 no_grad)\n", + " model.clear_gradients()\n", + " logits = model(x_video.astype('float32'), x_vec.astype('float32'))\n", + " target_class = int(paddle.argmax(logits, axis=-1)[0])\n", + " cam3d = GradCAM3D(model).generate(\n", + " x_video.astype('float32'), x_vec.astype('float32'),\n", + " target_class=target_class, time_index=0\n", + " ) # (N,H,W) or (D',H',W')\n", + "\n", + " # 展示几个深度切片\n", + " Nz = cam3d.shape[0]\n", + " for z in [0, Nz//3, 2*Nz//3, Nz-1]:\n", + " show_heatmap_2d(cam3d[z], f\"Grad-CAM depth={z}\", save_path=f\"viz_out/gradcam_z{z}.png\")\n", + "\n", + " # 4) Self-Attention & Cross-Attention 注意力矩阵\n", + " with paddle.no_grad():\n", + " _ = model.encode(x_video.astype('float32'), x_vec.astype('float32'))\n", + " last_attn_list = model.trans_attn.last_attn_all_layers\n", + " if len(last_attn_list) > 0:\n", + " attn = last_attn_list[-1].numpy() # (B,H,T,T)\n", + " attn_crop = attn[:, :, :64, :64]\n", + " show_attention_matrix(attn_crop, \"Self-Attention (last layer, first 64 tokens)\",\n", + " save_path=\"viz_out/self_attn_lastlayer_64.png\")\n", + " print(\"Self-Attn matrix shape:\", attn.shape)\n", + " else:\n", + " print(\"Self-Attn not captured.\")\n", + " if model.fusion.last_attn_v_from_t is not None:\n", + " show_attention_matrix(model.fusion.last_attn_v_from_t.numpy(),\n", + " \"Cross-Attn v<-t (token→token)\",\n", + " save_path=\"viz_out/cross_attn_v_from_t.png\")\n", + " if model.fusion.last_attn_t_from_v is not None:\n", + " attn_tv = model.fusion.last_attn_t_from_v.numpy()\n", + " attn_tv_crop = attn_tv[:,:,:, 
:64]\n", + " show_attention_matrix(attn_tv_crop,\n", + " \"Cross-Attn t<-v (token←video_seq first 64)\",\n", + " save_path=\"viz_out/cross_attn_t_from_v_64.png\")\n", + "\n", + " # 5) MoE 路由聚类(示例用 toy 数据)\n", + " def collate_fn(batch):\n", + " vids, vecs, ys = zip(*batch)\n", + " return (paddle.to_tensor(np.stack(vids, 0)),\n", + " paddle.to_tensor(np.stack(vecs, 0)),\n", + " paddle.to_tensor(np.stack(ys, 0)))\n", + " train_loader = DataLoader(toy, batch_size=1, shuffle=False, collate_fn=collate_fn)\n", + " moe_vecs = collect_moe_routing_vectors(model, train_loader, branch=\"temporal_attn\", topk_hist=True)\n", + " if moe_vecs is not None:\n", + " idx, cent = kmeans_numpy(moe_vecs, K=4, iters=100, seed=0)\n", + " print(\"\\n[MoE Routing Clusters @ temporal_attn]\")\n", + " for k in range(4):\n", + " sel = (idx==k)\n", + " if np.any(sel):\n", + " mean_vec = moe_vecs[sel].mean(axis=0)\n", + " dom = int(mean_vec.argmax())\n", + " print(f\" - Cluster {k}: size={int(sel.sum())}, dominant_expert={dom}, mean_dist={np.round(mean_vec,3)}\")\n", + " plt.figure(figsize=(6,4))\n", + " plt.imshow(moe_vecs, aspect='auto', interpolation='nearest')\n", + " plt.title(\"Samples × Experts (routing histogram)\"); plt.xlabel(\"Expert\"); plt.ylabel(\"Sample\")\n", + " plt.colorbar(); plt.savefig(\"viz_out/moe_routing_heatmap.png\", bbox_inches='tight')\n", + " plt.show(); plt.close()\n", + " else:\n", + " print(\"MoE routing not available on selected branch.\")\n" + ], + "metadata": { + "id": "3FGUHDpRVnm6" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git "a/jointContribution/AI_Climate_Diseases/\347\273\223\351\241\271\346\212\245\345\221\212.md" "b/jointContribution/AI_Climate_Diseases/\347\273\223\351\241\271\346\212\245\345\221\212.md" new file mode 100644 index 0000000000..d1a9424161 --- /dev/null +++ "b/jointContribution/AI_Climate_Diseases/\347\273\223\351\241\271\346\212\245\345\221\212.md" @@ -0,0 +1,437 @@ +# 项目信息 + +## 项目名称 
+基于多模态深度学习的眼科疾病发病预测:融合 ERA5 气象数据与 UK Biobank 患者特征 + +--- + +## 时间规划(3 个月细化) + +### 第 1 月:数据准备与预处理 +- **第 1 周** + - 熟悉 ERA5-Land 数据接口与变量选择(确认 50 个气象特征变量)。 + - 搭建多进程下载框架,测试不同并发参数下的下载速度与稳定性。 +- **第 2 周** + - 批量下载并缓存 ERA5-Land 数据(覆盖 2010–2020 年英国地区)。 + - 完成数据格式转换(NetCDF → Zarr/HDF5),并进行时空重采样(0.1° → 20×20×24)。 +- **第 3 周** + - 收集 UK Biobank 患者特征(424 维),统计缺失值分布。 + - 尝试多种填补方法(均值/中位数、MICE、KNN、XGBoost),比较效果。 +- **第 4 周** + - 确定最终缺失值填补策略(XGBoost)。 + - 建立 ERA5 与 UKB 患者数据的对齐方案(空间位置 + 时间窗口)。 + - 输出对齐后的多模态训练样本。 + +--- + +### 第 2 月:模型开发与基线训练 +- **第 1 周** + - 搭建 3D ResNet18 编码器,用于逐日气象立体数据的空间特征提取。 + - 实现 TabM 模块,对 UKB 患者表型进行建模。 +- **第 2 周** + - 实现 Transformer 分支(时序注意力)。 + - 实现 AFNO 分支(频域稀疏建模),完成与 Transformer 的特征拼接与融合。 +- **第 3 周** + - 集成 MoE 模块(Transformer FFN、AFNO FFN、TabM projection、融合层可开关)。 + - 完成 Cross-Attention 融合机制,实现双模态交互。 +- **第 4 周** + - 进行基线训练(toy 数据 + 部分真实样本)。 + - 记录多标签分类指标(AUC、F1、Recall、Precision、PR-AUC、Hamming Loss)。 + - 调参与优化:batch size、学习率、专家数量(MoE)。 + +--- + +### 第 3 月:可解释性与总结 +- **第 1 周** + - 实现并测试 3D Grad-CAM,输出气象模态的空间关注区域。 + - 可视化 Self-Attention Heatmap(365 天的时序权重)。 +- **第 2 周** + - 实现 MoE 路由记录与聚类,统计不同专家的样本分布。 + - 探索 domain-specific 专家(如“冬季患者”、“老年患者”)是否自动分离。 +- **第 3 周** + - 整合可解释性结果: + - Grad-CAM(空间热点图) + - Attention Map(时间热点图) + - MoE 聚类(样本-专家分布) + - 撰写可解释性分析小结。 +- **第 4 周** + - 完成整体技术报告撰写(含方法、实验、困难与总结)。 + - 整理成果,准备论文初稿框架。 + +--- + + +## 方案描述 + +### 1. 数据准备 + +我们结合了两类异质模态的数据: + +- **气象模态(ERA5-Land)** + 我们从 ERA5-Land 下载了 **50 个气象与环境变量**(详见附录,例如 *2m 温度、土壤温度层、雪覆盖、降水量、辐射通量、植被指数* 等)。 + - 原始空间分辨率:0.1°(约 10×10 网格覆盖每个城市)。 + - 每个像素包含 **24 个垂直层**(如土壤层、大气层)。 + - 数据被统一 resize 为 **20×20×24** 的体素,每位患者匹配发病前 **365 天**的数据。 + - 我们使用 **多进程并行下载**加快 ERA5-Land 的数据拉取和处理。 + +- **患者模态(UK Biobank)** + 我们从 UK Biobank 中提取了 **424 项患者特征**(包括人口学、生活习惯、临床变量)。 + 对缺失值,我们采用 **XGBoost 预测填补**方法,相比均值/众数填充在临床异质数据上表现更好。 + 每个患者特征通过 **发病日期**和**居住城市**与 ERA5 气象数据对齐,实现时空匹配。 + +任务为 **多标签分类**:预测 **4 种眼科疾病**(如青光眼、白内障、AMD、糖尿病视网膜病变)的发病风险。由于患者可能合并多种眼病,因此使用多标签框架。 + +--- + +### 2. 
模型结构 + +整体结构为一个 **双模态神经网络**,用于捕捉 **时空气象模式**和 **个体患者属性**,并在融合时保持可解释性和灵活性。 + +#### 2.1 气象分支:3D ResNet + Transformer + AFNO + +- **3D ResNet18 主干** + 输入为 **365 天 × 50 通道 × 20×20×24 体素**。 + - 采用 **3D 卷积**在 (time × space × depth) 三个维度上同时建模,能够捕捉 **气候随时间的演变、地表到土壤的能量传输、雪深积累**等模式。 + - 输出为 **每日一个 512 维 embedding**。 + +- **时序 Transformer(含 MoE)** + 每日 embedding 输入到 **Transformer 编码器**: + - **自注意力**用于捕捉 365 天长时依赖。 + - **前馈层 (FFN) 替换为 MoE**:每个时间片的 token 被路由到不同专家,使得模型可以专门化处理 **不同气候类型**(如海洋性气候 vs. 大陆性气候)。 + +- **AFNO 分支(Adaptive Fourier Neural Operator)** + 并行使用 **AFNO** 捕捉气象的周期性: + - 将时间序列通过 **FFT** 转换到频率域。 + - 使用 **分块对角的复数线性变换**学习主要频率模式(如季节性波动、短期振荡)。 + - 通过 **soft-shrinkage 稀疏化**抑制非主要频率的噪声。 + - **逆 FFT**还原时域信号。 + 同时,AFNO 分支也加入了 **MoE**,使不同频率特征由不同专家处理,适应气候区域差异。 + +- **气象时序特征融合** + Transformer 与 AFNO 的输出拼接后,经线性层映射回 512 维,得到综合的气象时序表示。 + +--- + +#### 2.2 患者分支:TabM (Tabular Mixture) + +患者 424 维特征通过 **TabM**(Yandex Research 提出)处理。 + +- **核心原理** + TabM 提出了一种 **高效集成 (efficient ensemble)** 机制: + - 主权重矩阵 \(W\) 在所有子模型间共享。 + - 每个子模型(专家)通过一对 **低秩缩放向量** \((r_e, s_e)\) 调节输入/输出: + \[ + y_e = \big[ (x \odot r_e) W^\top \big] \odot s_e + b_e + \] + 其中 \(x\) 为输入特征,\(\odot\) 表示逐元素乘,\(b_e\) 为每个子模型的偏置。 + - 这样可以在几乎不增加参数量的情况下,构建一个内部的“打包集成模型”。 + +- **优点** + - 让模型在面对 **分布差异的亚群体**时更加鲁棒(例如不同生活方式的人群)。 + - 内部集成提高了 **泛化能力**和 **不确定性估计**。 + - 计算成本几乎与单模型相同。 + +我们将 TabM 输出投影到 512 维,与气象分支对齐。 + +--- + +#### 2.3 跨模态融合 + +采用 **双向交叉注意力 (bi-directional cross-attention)**: + +- **气象 → 患者**:患者 token 在气象序列上查询,关注与疾病最相关的时间片。 +- **患者 → 气象**:气象 summary token 在患者 embedding 上查询,将气候解释与个体特征结合。 + +两个更新后的 token 拼接,经 MLP 得到融合表示。 +可选地在这一层加入 **MoE 头**,让模型专门化处理不同疾病亚型。 + +--- + +#### 2.4 分类器 + +最终融合表示经线性分类头,预测 **4 个眼科疾病的风险概率**。 +损失函数使用 **binary cross-entropy with logits**。 + +--- + +### 3. 
训练与推理 + +- **损失函数**:带类别权重的 BCE。 + +- **优化器**:Adam,带梯度裁剪。 + +- **推理时的检索增强 (Retrieval-Augmented Inference)** + 推理时,我们利用训练集构建一个特征库: + - 用模型的中间表示作为索引。 + - 在测试样本预测时,检索出 k 个最相似的训练样本(相似度可选 **余弦**或 **欧式距离**)。 + - 计算邻居的平均概率 \(p_{knn}\)。 + - 与模型预测概率 \(p_{model}\) 融合: + \[ + p_{final} = (1-\alpha)\,p_{model} + \alpha\,p_{knn} + \] + +这样能缓解训练/测试分布差异。 + +--- + +### 4. 可解释性 + +我们设计了多层次的可解释机制: + +- **3D Grad-CAM**:可视化气象体素中(纬度 × 经度 × 深度)最关键的区域。 +- **Transformer 注意力图**:显示模型在 365 天中关注的关键时段(如冬季骤降)。 +- **AFNO 频率分析**:指出模型利用的主导频率成分。 +- **MoE 路由可视化**:分析样本被分配到的专家,揭示潜在的病人亚群体或气候模式(例如“雪覆盖驱动的风险群体”)。 + + + +# 项目总结 + +## 已完成工作 +- 搭建多进程 ERA5-Land 数据下载框架,完成 50 个气象变量的收集。 +- 实现数据格式转换与重采样(0.1° → 20×20×24),构建日尺度气象数据立方体。 +- 收集并清洗 UK Biobank 患者特征(424 维),通过 XGBoost 完成缺失值填补。 +- 设计并实现多模态模型框架(3D ResNet + Transformer + AFNO + TabM + MoE + Cross-Attn)。 +- 完成 toy dataset 与真实数据子集的基线训练,验证模型结构可行性。 + +## 遇到的问题及解决方案 +- **数据下载效率低** → 使用多进程并行下载,大幅缩短获取 ERA5 数据的时间。 +- **数据申请受阻** → 原计划使用 UKB 基因数据库以及细粒度到经纬度的数据,但申请流程太长,因此改用城市地理中心坐标;这样一来,原先预设的时空细粒度难以对齐的问题反而不严重了。 +- **缺失值比例高** → 采用 XGBoost 学习型填补方法,相比均值/中位数填补更符合变量间分布关系。 +- **数据对齐复杂** → 按照患者发病时间窗口(365 天)和居住城市坐标匹配 ERA5 数据,构建个体化时空样本。 +- **PaddlePaddle 模块限制** → 例如 `nn.MultiHeadAttention` 不支持 `need_weights` 参数,导致 attention 可解释性实现受限;通过自定义 Cross-Attn 与保存注意力矩阵解决。 +- **气象数据噪声较多** → 采用传统 Transformer 与时间序列维度上的 AFNO 组成的双路结构,分别提取时域与频域特征。 +- **模型需要可解释性方面的贡献** → Transformer 与 MoE 模块默认不输出可解释信息,容易形成“黑箱”;我们通过自定义 Cross-Attention 权重输出、3D Grad-CAM 空间可视化、时序 Attention Map、AFNO 频域分解与 MoE 路由可视化等手段,显式揭示了模型在空间、时间、频率及亚群体层面的关注点。这些改进不仅解决了可解释性瓶颈,也使模型能够为临床专家与政策制定提供透明、可追溯的证据。 +- **计算负担大** → 模型包含 3D CNN + 双 Transformer 分支 + MoE,需依赖多块 A100 GPU 训练;经过分析,主要时间复杂度集中在 3D CNN 上,我们通过气象特征筛选减少 3D CNN 通道数量,并通过对最后一个维度(小时)取平均来压缩维度,将 3D CNN 替换为 2D CNN。 +## 未来工作计划 + + +在已有的 **多模态深度学习框架** 基础上(融合 ERA5 气象数据与 UK Biobank 等多源患者数据),本研究聚焦于 **显式级联建模**(环境 → 系统 → 暴露 → 生物 → 疾病),构建跨模态、可解释的疾病预测与干预模拟平台。总体目标是: + +1. 在 *Nature Communications*、*Nature Medicine*、AAAI、IJCAI 等期刊或会议发表 1–2 篇论文; +2. 提出适用于多模态级联预测的通用框架; +3. 
验证模型在 UKB、CKB、FinnGen、BBJ 等多个 Biobank 上的可扩展性; +4. 提供气象与环境干预下的疾病风险模拟; +5. 为城市规划、空气质量控制、疾病防控政策提供量化证据。 + +## 测试样例 +见 PaddlePaddle 文件;考虑到完整数据运行时间过长,我们在其中提供了由仿真数据构成的 toy dataloader 作为替代。 + + +# Project Information + +## Project Title +Ophthalmic Disease Onset Prediction Based on Multimodal Deep Learning: Integrating ERA5 Meteorological Data and UK Biobank Patient Features + +--- + +## Timeline (Detailed for 3 Months) + +### Month 1: Data Preparation and Preprocessing +- **Week 1** + - Familiarize with ERA5-Land data interfaces and variable selection (confirm 50 meteorological feature variables). + - Build a multi-process downloading framework and test download speed/stability under different concurrency parameters. +- **Week 2** + - Batch download and cache ERA5-Land data (covering UK regions from 2010–2020). + - Complete data format conversion (NetCDF → Zarr/HDF5) and perform spatiotemporal resampling (0.1° → 20×20×24). +- **Week 3** + - Collect UK Biobank patient features (424 dimensions) and analyze missing value distributions. + - Test multiple imputation methods (mean/median, MICE, KNN, XGBoost) and compare performance. +- **Week 4** + - Finalize missing value imputation strategy (XGBoost). + - Establish alignment scheme between ERA5 and UKB patient data (spatial location + time window). + - Output aligned multimodal training samples. + +--- + +### Month 2: Model Development and Baseline Training +- **Week 1** + - Build a 3D ResNet18 encoder for extracting spatial features from daily meteorological volumetric data. + - Implement TabM module for modeling UKB patient phenotypes. +- **Week 2** + - Implement Transformer branch (temporal attention). + - Implement AFNO branch (frequency-domain sparse modeling) and complete feature concatenation and fusion with Transformer outputs. +- **Week 3** + - Integrate MoE modules (Transformer FFN, AFNO FFN, TabM projection, fusion layer with switchable experts). + - Complete Cross-Attention fusion for bimodal interaction. 
+- **Week 4** + - Conduct baseline training (toy dataset + partial real samples). + - Record multi-label classification metrics (AUC, F1, Recall, Precision, PR-AUC, Hamming Loss). + - Hyperparameter tuning: batch size, learning rate, number of experts (MoE). + +--- + +### Month 3: Explainability and Summary +- **Week 1** + - Implement and test 3D Grad-CAM to visualize spatial attention regions in meteorological modality. + - Visualize Self-Attention Heatmap (temporal weights over 365 days). +- **Week 2** + - Implement MoE routing recording and clustering; analyze sample distributions across experts. + - Explore whether domain-specific experts (e.g., "winter patients," "elderly patients") are automatically separated. +- **Week 3** + - Consolidate explainability results: + - Grad-CAM (spatial hotspots) + - Attention Map (temporal hotspots) + - MoE clustering (sample-expert distribution) + - Draft interpretability analysis summary. +- **Week 4** + - Complete technical report (methods, experiments, challenges, and conclusions). + - Organize results and prepare initial paper framework. + +--- + +## Project Design + +### 1. Data Preparation + +We integrate two heterogeneous modalities: + +- **Meteorological Modality (ERA5-Land)** + Downloaded **50 meteorological and environmental variables** (e.g., *2m temperature, soil temperature layers, snow cover, precipitation, radiation flux, vegetation index*). + - Original spatial resolution: 0.1° (~10×10 grid per city). + - Each pixel includes **24 vertical layers** (e.g., soil, atmosphere). + - Resized to **20×20×24** voxels, with each patient matched to **365 days** of data before disease onset. + - Used **multi-process parallel downloading** to accelerate ERA5-Land retrieval and preprocessing. + +- **Patient Modality (UK Biobank)** + Extracted **424 patient features** (demographics, lifestyle, clinical variables). 
+ For missing values, applied **XGBoost-based predictive imputation**, which outperforms mean/median filling for heterogeneous clinical data. + Patient features are aligned with ERA5 meteorological data via **onset date** and **residential location**, enabling spatiotemporal matching. + +**Task**: Multi-label classification predicting **4 ophthalmic diseases** (e.g., glaucoma, cataract, AMD, diabetic retinopathy). As patients may develop multiple diseases, a multi-label framework is required. + +--- + +### 2. Model Architecture + +The overall design is a **bimodal neural network**, capturing both **spatiotemporal meteorological patterns** and **individual patient attributes**, while ensuring interpretability and flexibility. + +#### 2.1 Meteorological Branch: 3D ResNet + Transformer + AFNO + +- **3D ResNet18 Backbone** + Input: **365 days × 50 channels × 20×20×24 voxels**. + - **3D convolutions** jointly model (time × space × depth), capturing **seasonal changes, soil-atmosphere energy transfer, and snow accumulation**. + - Outputs **one 512-d embedding per day**. + +- **Temporal Transformer (with MoE)** + Daily embeddings are fed into a **Transformer encoder**: + - **Self-attention** captures long-term dependencies across 365 days. + - **FFN replaced by MoE**: each token is routed to different experts, allowing specialized handling of **climate types** (e.g., maritime vs. continental). + +- **AFNO (Adaptive Fourier Neural Operator) Branch** + Models periodicity in parallel: + - Convert sequence via **FFT** to frequency domain. + - Apply **block-diagonal complex linear transforms** to learn dominant frequencies (e.g., seasonal cycles, short-term oscillations). + - Use **soft-shrinkage sparsity** to suppress noise. + - Perform **inverse FFT** to reconstruct. + - Added MoE to allow frequency-specific specialization across regions. 
+ +- **Meteorological Temporal Feature Fusion** + Concatenate Transformer and AFNO outputs, then project back to 512-d, forming integrated meteorological representations. + +--- + +#### 2.2 Patient Branch: TabM (Tabular Mixture) + +Patient 424-d features are modeled with **TabM** (proposed by Yandex Research). + +- **Core Idea** + TabM is an **efficient ensemble mechanism**: + - Weight matrix \(W\) is shared across sub-models. + - Each expert adjusts input/output via low-rank scaling vectors \((r_e, s_e)\): + \[ + y_e = \big[ (x \odot r_e) W^\top \big] \odot s_e + b_e + \] + where \(x\) is input, \(\odot\) is element-wise multiplication, and \(b_e\) is bias. + - Builds an ensemble internally with minimal additional parameters. + +- **Advantages** + - Robust to **population subgroup heterogeneity** (e.g., lifestyle differences). + - Improves **generalization** and **uncertainty estimation**. + - Computationally comparable to a single model. + +TabM outputs are projected to 512-d for alignment with meteorological features. + +--- + +#### 2.3 Cross-Modal Fusion + +Implemented **bi-directional cross-attention**: + +- **Meteorology → Patient**: patient tokens query meteorological sequences to attend to disease-relevant time slices. +- **Patient → Meteorology**: meteorological summary tokens query patient embeddings, linking climate patterns with individual traits. + +Fused tokens are concatenated and passed through MLP. +Optionally, a **MoE head** is added to specialize for disease subtypes. + +--- + +#### 2.4 Classifier + +The fused representation is passed to a linear classifier for **multi-label risk prediction of 4 ophthalmic diseases**. +Loss function: **binary cross-entropy with logits**. + +--- + +### 3. Training and Inference + +- **Loss Function**: BCE with class weights. +- **Optimizer**: Adam with gradient clipping. +- **Retrieval-Augmented Inference (RAI)**: + - Build feature index from training embeddings. + - At inference, retrieve *k* nearest neighbors. 
+ - Compute average probability \(p_{knn}\). + - Combine with model prediction \(p_{model}\): + \[ + p_{final} = (1-\alpha)\,p_{model} + \alpha\,p_{knn} + \] + - Mitigates train-test distribution shift. + +--- + +### 4. Explainability + +Multi-level interpretability mechanisms: + +- **3D Grad-CAM**: highlights key spatial voxels (lat × lon × depth). +- **Transformer Attention Maps**: identify critical time windows (e.g., winter drops). +- **AFNO Frequency Analysis**: reveal dominant periodic components. +- **MoE Routing Visualization**: analyze expert assignments, revealing subgroups (e.g., "snow-driven high-risk patients"). + +--- + +# Project Summary + +## Completed Work +- Built multi-process ERA5-Land data downloading framework; collected 50 meteorological variables. +- Converted and resampled data (0.1° → 20×20×24) to daily meteorological cubes. +- Collected and cleaned UK Biobank patient features (424-d); imputed missing values using XGBoost. +- Designed and implemented multimodal model (3D ResNet + Transformer + AFNO + TabM + MoE + Cross-Attn). +- Conducted baseline training on toy and subset datasets, validating feasibility. + +## Challenges and Solutions +- **Low data download efficiency** → Solved with multi-process parallel downloading. +- **High missing rates** → Addressed via XGBoost predictive imputation, outperforming mean/median. +- **Complex data alignment** → Built spatiotemporal matching pipeline using onset date + residential location. +- **Framework limitation (PaddlePaddle)** → `nn.MultiHeadAttention` lacked `need_weights`; resolved by custom Cross-Attn with weight saving. 
+- **Need for interpretability contributions** → Tackled the “black box” issue by integrating: + - Cross-Attention weight outputs + - 3D Grad-CAM spatial visualization + - Temporal attention maps + - AFNO frequency decomposition + - MoE routing visualization + These enhancements provided transparency at spatial, temporal, frequency, and subgroup levels, supporting clinical and policy insights. +- **Heavy computational load** → Model combines 3D CNN + dual Transformers + MoE, requiring multi-GPU (A100); solved using gradient clipping, model quantization, and distributed parallel training. + +--- + +## Future Work Plan + +Building on the existing **multimodal deep learning framework** (ERA5 meteorology + UK Biobank features), the research will focus on **explicit cascade modeling** (Environment → System → Exposure → Biology → Disease) to construct an interpretable, multimodal disease prediction and intervention simulation platform. + +**Goals:** +1. Publish 1–2 papers in *Nature Communications*, *Nature Medicine*, AAAI, or IJCAI. +2. Propose a generalizable multimodal cascade prediction framework. +3. Validate scalability across multiple Biobanks (UKB, CKB, FinnGen, BBJ). +4. Provide disease risk simulations under environmental and climate interventions. +5. Deliver quantitative evidence for urban planning, air quality control, and public health policy.