Skip to content

Commit

Permalink
[Notebooks] improve clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Aug 22, 2023
1 parent a33c4fd commit 7d47178
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 56 deletions.
212 changes: 160 additions & 52 deletions notebooks/tutorials/3-change-encoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"source": [
"# Tutorial for changing the encoder and customizing the encoder\n",
"\n",
"In this notebook we will cover a tutorial for the flaxible encoders. Including a easy way to change the encoder of the model and customize the encoder to fit your experiment requirements. \n",
"In this notebook we will cover a tutorial for the flaxible encoders!\n",
"\n",
"<a href=\"https://colab.research.google.com/github/kaist-silab/rl4co/blob/main/notebooks/tutorials/3-change-encoder.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>"
]
Expand Down Expand Up @@ -50,8 +50,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
"2023-08-22 18:18:55.097860: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-08-22 18:18:55.116842: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-08-22 18:18:55.438994: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
Expand All @@ -72,9 +74,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## A default minimal training scritp\n",
"## A default minimal training script\n",
"\n",
"Here we use the CVRP environment and AM model as a minimal example of training script. By default, the AM is initialized with a Graph Attention Encoder. "
"Here we use the CVRP environment and AM model as a minimal example of training script. By default, the AM is initialized with a Graph Attention Encoder, but we can change it to anything we want."
]
},
{
Expand All @@ -86,9 +88,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
" rank_zero_warn(\n",
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
" rank_zero_warn(\n",
"Using 16bit Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
Expand Down Expand Up @@ -152,41 +154,90 @@
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fe5417556e9242a28714f4e8a46d56aa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n",
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████| 196/196 [00:09<00:00, 21.26it/s, train/reward=-7.21, train/loss=-.501, val/reward=-7.05] "
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68a8a713d72246a4ae5df669ea4b146b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b61b04b473c4096b08e6f7237cde214",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cdb5850788d44671ace75e5111f15e70",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b6209da9294147bcb83a8b99c012ceab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████| 196/196 [00:16<00:00, 12.17it/s, train/reward=-7.21, train/loss=-.501, val/reward=-7.05]\n"
"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
]
}
],
Expand All @@ -201,7 +252,9 @@
"source": [
"## Change the Encoder\n",
"\n",
"In RL4CO, we provides two graph neural network encoders: *Graph Convolutionsal Network* (GCN) encoder and *Message Passing Neural Network* (MPNN) encoder. In this tutorial, we will show how to change the encoder. "
"In RL4CO, we provides two graph neural network encoders: *Graph Convolutionsal Network* (GCN) encoder and *Message Passing Neural Network* (MPNN) encoder. In this tutorial, we will show how to change the encoder. \n",
"\n",
"> Note: while we provide these examples, you can also implement your own encoder and use it in RL4CO!"
]
},
{
Expand All @@ -223,9 +276,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n",
" rank_zero_warn(\n",
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n",
" rank_zero_warn(\n",
"Using 16bit Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
Expand Down Expand Up @@ -280,8 +333,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:615: UserWarning: Checkpoint directory /home/cbhua/code/rl4co-rebuttal/notebooks/tutorials/checkpoints exists and is not empty.\n",
" rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:615: UserWarning: Checkpoint directory /home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.\n",
" rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"val_file not set. Generating dataset instead\n",
"test_file not set. Generating dataset instead\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
Expand All @@ -299,41 +358,90 @@
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0b776cd4bbe4341bc058b4891c8f488",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n",
"/home/cbhua/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████| 196/196 [00:12<00:00, 15.63it/s, train/reward=-8.04, train/loss=2.110, val/reward=-7.82] "
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "907cfe2003c244dba7d45f7b3454a487",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "40e5e900c09a4a9089d598e033cec768",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "82f258f1b8f34661843fe81ae9750e37",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "404cc0fc9c30441ab09199b84918ac86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████| 196/196 [00:22<00:00, 8.82it/s, train/reward=-8.04, train/loss=2.110, val/reward=-7.82]\n"
"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
]
}
],
Expand All @@ -352,7 +460,7 @@
"\n",
"1. RL4CO provides the `env_init_embedding` method for each environment. You may want to use it to get the initial embedding of the environment.\n",
"2. `h` and `init_h` as return hidden features have the shape `([batch_size], num_node, hidden_size)`\n",
"3. In RL4CO, we put the graph neural network encoders in the `rl4co/models/nn/graph` folder. You may want to put your customized encoder to the same folder."
"3. In RL4CO, we put the graph neural network encoders in the `rl4co/models/nn/graph` folder. You may want to put your customized encoder to the same folder. Feel free to send a PR to add your encoder to RL4CO!"
]
},
{
Expand Down Expand Up @@ -420,7 +528,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.11"
},
"orig_nbformat": 4
},
Expand Down
14 changes: 10 additions & 4 deletions notebooks/tutorials/4-search-methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"source": [
"# EAS\n",
"\n",
"Efficient Active Search"
"In this notebook, we will showcase how to use the Efficient Active Search (EAS) algorithm to find better solutions to existing problems!"
]
},
{
Expand Down Expand Up @@ -145,7 +145,7 @@
"env.num_loc = 200\n",
"\n",
"dataset = env.dataset(batch_size=[2])\n",
"# eas_model = EASEmb(env, policy, dataset, batch_size=2, max_iters=20, save_path=\"eas_sols.pt\")\n",
"# eas_model = EASEmb(env, policy, dataset, batch_size=2, max_iters=20, save_path=\"eas_sols.pt\") # alternative\n",
"eas_model = EASLay(env, policy, dataset, batch_size=2, max_iters=20, save_path=\"eas_sols.pt\")\n",
"\n",
"eas_model.setup()"
Expand Down Expand Up @@ -368,8 +368,14 @@
"actions = actions[:torch.count_nonzero(actions, dim=-1)] # remove trailing zeros\n",
"state = td_dataset.cpu()[0]\n",
"\n",
"env.render(state, actions)\n",
"# small graphic bug"
"env.render(state, actions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even with few iterations, the search method can clearly find better solutions than the initial ones!"
]
}
],
Expand Down

0 comments on commit 7d47178

Please sign in to comment.