diff --git a/README.md b/README.md index 6b12827..847c0e6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ Generative models for conditional molecular structure generation ============================== [//]: # (Badges) -[![GitHub Actions Build Status](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/molgen/workflows/CI/badge.svg)](https://github.com/REPLACE_WITH_OWNER_ACCOUNT/molgen/actions?query=workflow%3ACI) +[![GitHub Actions Build Status](https://github.com/Ferg-Lab/molgen/workflows/CI/badge.svg)](https://github.com/Ferg-Lab/molgen/actions?query=workflow%3ACI) @@ -18,6 +18,7 @@ To use `molgen`, you will need an environment with the following packages: * Python 3.7+ * [PyTorch](https://pytorch.org/get-started/locally/) * [PyTorch Lightning](https://www.pytorchlightning.ai/) +* [Einops](https://einops.rocks/#Installation) For running and visualizing examples: * [NumPy](https://numpy.org/install/) @@ -80,6 +81,14 @@ model.save('ADP.ckpt') model = WGANGP.load_from_checkpoint('ADP.ckpt') ``` +Supports generators based on both Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs). The example above uses a GAN; DDPMs support an equivalent API -- for example, + +```python +from molgen.models import DDPM + +model = DDPM(...) +``` + ### Copyright Copyright (c) 2023, Kirill Shmilovich diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 35bc967..4e4ddb5 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -12,6 +12,8 @@ dependencies: # package dependencies - pytorch - pytorch-lightning + - einops + - tqdm # Testing - pytest diff --git a/examples/ADP_all_atom_DDPM.ipynb b/examples/ADP_all_atom_DDPM.ipynb new file mode 100644 index 0000000..943b715 --- /dev/null +++ b/examples/ADP_all_atom_DDPM.ipynb @@ -0,0 +1,1201 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import mdtraj as md\n", + "from pathlib import Path\n", + "import torch\n", + "import numpy as np\n", + "import sys\n", + "sys.path.append('../')\n", + "from molgen.models import DDPM" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "pdb_fname = '/project/andrewferguson/Kirill/CMSC-35450/data_mdshare/alanine-dipeptide-nowater.pdb'\n", + "trj_fnames = [str(i) for i in Path('/project/andrewferguson/Kirill/CMSC-35450/data_mdshare').glob('alanine-dipeptide-*-250ns-nowater.xtc')]\n", + "trjs = [md.load(t, top=pdb_fname).center_coordinates() for t in trj_fnames]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([250000, 66]), torch.Size([250000, 2]))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xyz = list()\n", + "phi_psi = list()\n", + "for trj in trjs:\n", + " \n", + " t_backbone = trj.center_coordinates()\n", + " \n", + " n = trj.xyz.shape[0]\n", + " \n", + " _, phi = md.compute_phi(trj)\n", + " _, psi = md.compute_psi(trj)\n", + " \n", + " xyz.append(torch.tensor(t_backbone.xyz.reshape(n, -1)).float())\n", + " phi_psi.append(torch.tensor(np.concatenate((phi, psi), -1)).float())\n", + " \n", + "xyz[0].shape, phi_psi[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = DDPM(xyz[0].shape[1], phi_psi[0].shape[1])" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/lightning_lite/plugins/environments/slurm.py:170: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /project/andrewferguson/Kirill/class_project_env/lib ...\n", + " category=PossibleUserWarning,\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------------------\n", + "0 | model | GaussianDiffusion | 3.0 M \n", + "1 | ema_model | GaussianDiffusion | 3.0 M \n", + "2 | _feature_scaler | MinMaxScaler | 0 \n", + "3 | _condition_scaler | MinMaxScaler | 0 \n", + "--------------------------------------------------------\n", + "6.0 M Trainable params\n", + "0 Non-trainable params\n", + "6.0 M Total params\n", + "24.081 Total estimated model params size (MB)\n", + "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:229: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " category=PossibleUserWarning,\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a7a0b14caaa745228133862cff25599c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=25` reached.\n" + ] + }, + { + "data": { + "text/plain": [ + "DDPM(\n", + " (model): GaussianDiffusion(\n", + " (denoise_fn): Unet(\n", + " (time_pos_emb): SinusoidalPosEmb()\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): Mish()\n", + " (2): Linear(in_features=128, out_features=32, bias=True)\n", + " )\n", + " (downs): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(1, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " 
(2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(32, 32, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", 
+ " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Identity()\n", + " )\n", + " )\n", + " (ups): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(512, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " 
(fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(128, 128, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), 
stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (mid_attn): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block2): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (final_conv): Sequential(\n", + " (0): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (1): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (ema_model): GaussianDiffusion(\n", + " (denoise_fn): Unet(\n", + " (time_pos_emb): SinusoidalPosEmb()\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): Mish()\n", + " (2): Linear(in_features=128, out_features=32, bias=True)\n", + " )\n", + " (downs): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(1, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), 
stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(32, 32, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " 
)\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Identity()\n", + " )\n", + " )\n", + " (ups): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(512, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 
128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(128, 128, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(32, 32, kernel_size=(4,), 
stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (mid_attn): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block2): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (final_conv): Sequential(\n", + " (0): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (1): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (_feature_scaler): MinMaxScaler()\n", + " (_condition_scaler): MinMaxScaler()\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(xyz, phi_psi, max_epochs=25)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fa74bced2e2443b19daf5d4d8fa043b3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1cd1d37733864f8c99e54762a0d8a086", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "NGLWidget(max_frame=749999)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nglview as nv\n", + "trj_backbones = md.join(trjs)\n", + "v = nv.show_mdtraj(trj_backbones)\n", + "v" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "11716ea146624004a2d3d246e8f304b6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "sampling loop time step: 0%| | 0/1000 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xyz = xyz.reshape(xyz.size(0), -1, 3)\n", + "fake_trj = md.Trajectory(xyz = xyz.cpu().numpy(), topology = trj_backbones.top)\n", + "fake_trj" + ] + }, + { + 
"cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c0fc18a2d4774565bdad5869a2993ecf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "NGLWidget(max_frame=7499)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "v = nv.show_mdtraj(fake_trj)\n", + "v" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:class_project_env]", + "language": "python", + "name": "conda-env-class_project_env-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/ADP_all_atom.ipynb b/examples/ADP_all_atom_GAN.ipynb similarity index 95% rename from examples/ADP_all_atom.ipynb rename to examples/ADP_all_atom_GAN.ipynb index b183151..2ac20a5 100644 --- a/examples/ADP_all_atom.ipynb +++ b/examples/ADP_all_atom_GAN.ipynb @@ -78,14 +78,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "Auto select gpus: [0]\n", "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/lightning_lite/plugins/environments/slurm.py:170: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /project/andrewferguson/Kirill/class_project_env/lib ...\n", " category=PossibleUserWarning,\n", - "GPU available: True (cuda), used: True\n", + "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------------------\n", @@ -105,7 +103,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "88470123eb8446579873289fe3b8581c", + "model_id": "ca9208efd62741419140618908520c7b", "version_major": 2, "version_minor": 0 }, @@ -175,7 +173,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5514bfe2a09346cc97f49e22ac8f5238", + "model_id": "7e9c321345df448eae87d4eb1d190600", "version_major": 2, "version_minor": 0 }, @@ -187,7 +185,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "82929b7fdad547eda4151cf6258a8246", + "model_id": "c72dc3d5e0484843a18317f2a215725d", "version_major": 2, "version_minor": 0 }, @@ -233,7 +231,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b9b054b38526481681883f3f9696f8e5", + "model_id": "cff40ee0e673443180558cbc516abc77", "version_major": 2, "version_minor": 0 }, diff --git a/examples/ADP_backbone_DDPM.ipynb b/examples/ADP_backbone_DDPM.ipynb new file mode 100644 index 0000000..e876e15 --- /dev/null +++ b/examples/ADP_backbone_DDPM.ipynb @@ -0,0 +1,1206 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import mdtraj as md\n", + "from pathlib import Path\n", + "import torch\n", + "import numpy as np\n", + "import sys\n", + "sys.path.append('../')\n", + "from molgen.models import DDPM" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "pdb_fname = '/project/andrewferguson/Kirill/CMSC-35450/data_mdshare/alanine-dipeptide-nowater.pdb'\n", + "trj_fnames = [str(i) for i in Path('/project/andrewferguson/Kirill/CMSC-35450/data_mdshare').glob('alanine-dipeptide-*-250ns-nowater.xtc')]\n", + "trjs = [md.load(t, top=pdb_fname).center_coordinates() for t in trj_fnames]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([250000, 24]), torch.Size([250000, 2]))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xyz = list()\n", + "phi_psi = list()\n", + "for trj in trjs:\n", + " \n", + " t_backbone = trj.atom_slice(trj.top.select('backbone')).center_coordinates()\n", + " \n", + " n = trj.xyz.shape[0]\n", + " \n", + " _, phi = md.compute_phi(trj)\n", + " _, psi = md.compute_psi(trj)\n", + " \n", + " xyz.append(torch.tensor(t_backbone.xyz.reshape(n, -1)).float())\n", + " phi_psi.append(torch.tensor(np.concatenate((phi, psi), -1)).float())\n", + " \n", + "xyz[0].shape, phi_psi[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = DDPM(xyz[0].shape[1], phi_psi[0].shape[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/lightning_lite/plugins/environments/slurm.py:170: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /project/andrewferguson/Kirill/class_project_env/lib ...\n", + " category=PossibleUserWarning,\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------------------\n", + "0 | model | GaussianDiffusion | 3.0 M \n", + "1 | ema_model | GaussianDiffusion | 3.0 M \n", + "2 | _feature_scaler | MinMaxScaler | 0 \n", + "3 | _condition_scaler | MinMaxScaler | 0 \n", + "--------------------------------------------------------\n", + "6.0 M Trainable params\n", + "0 Non-trainable params\n", + "6.0 M Total params\n", + "24.081 Total estimated model params size (MB)\n", + "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:229: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 48 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " category=PossibleUserWarning,\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "907a81432c734265b7ae947cea7a9f32", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py:48: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", + " rank_zero_warn(\"Detected KeyboardInterrupt, attempting graceful shutdown...\")\n" + ] + }, + { + "data": { + "text/plain": [ + "DDPM(\n", + " (model): GaussianDiffusion(\n", + " (denoise_fn): Unet(\n", + " (time_pos_emb): SinusoidalPosEmb()\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): Mish()\n", + " (2): Linear(in_features=128, out_features=32, bias=True)\n", + " )\n", + " (downs): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(1, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(32, 32, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(32, 64, 
kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, 
out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Identity()\n", + " )\n", + " )\n", + " (ups): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(512, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(128, 128, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): 
GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (mid_attn): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block2): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): 
Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (final_conv): Sequential(\n", + " (0): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (1): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (ema_model): GaussianDiffusion(\n", + " (denoise_fn): Unet(\n", + " (time_pos_emb): SinusoidalPosEmb()\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): Mish()\n", + " (2): Linear(in_features=128, out_features=32, bias=True)\n", + " )\n", + " (downs): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(1, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(32, 32, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): 
Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Downsample(\n", + " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (3): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, 
kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Identity()\n", + " )\n", + " )\n", + " (ups): ModuleList(\n", + " (0): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(512, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=128, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 128, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(128, 128, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (1): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=64, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): 
Sequential(\n", + " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 64, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " (2): ModuleList(\n", + " (0): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(128, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=32, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (2): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(32, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 32, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (3): Upsample(\n", + " (conv): ConvTranspose1d(32, 32, kernel_size=(4,), stride=(2,), padding=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block1): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (mid_attn): Residual(\n", + " (fn): Rezero(\n", + " (fn): LinearAttention(\n", + " (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)\n", + " (to_out): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (mid_block2): ResnetBlock(\n", + " (mlp): Sequential(\n", + " (0): Mish()\n", + " (1): Linear(in_features=32, out_features=256, bias=True)\n", + " )\n", + " (block1): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + 
" (2): Mish()\n", + " )\n", + " )\n", + " (block2): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 256, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (res_conv): Identity()\n", + " )\n", + " (final_conv): Sequential(\n", + " (0): Block(\n", + " (block): Sequential(\n", + " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n", + " (1): GroupNorm(8, 32, eps=1e-05, affine=True)\n", + " (2): Mish()\n", + " )\n", + " )\n", + " (1): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (_feature_scaler): MinMaxScaler()\n", + " (_condition_scaler): MinMaxScaler()\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(xyz, phi_psi, max_epochs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b93ca3ebd6bc4adfae56b842bd53fb74", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "aa8027fe9aed4e3e974194ffa166a5aa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "NGLWidget(max_frame=749999)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nglview as nv\n", + "trj_backbones = md.join([trj.atom_slice(trj.top.select('backbone')) for trj in trjs])\n", + "v = nv.show_mdtraj(trj_backbones)\n", + "v" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d2d8b2c75ee244a2b1806d5007064099", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "sampling loop time step: 0%| | 0/1000 [00:00" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xyz = xyz.reshape(xyz.size(0), -1, 3)\n", + "fake_trj = md.Trajectory(xyz = xyz.cpu().numpy(), topology = trj_backbones.top)\n", + "fake_trj" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dc0625f6c7754d06b46be37f942b75ba", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "NGLWidget(max_frame=749)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "v = nv.show_mdtraj(fake_trj)\n", + "v" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:class_project_env]", + "language": "python", + "name": "conda-env-class_project_env-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/ADP_backbone.ipynb b/examples/ADP_backbone_GAN.ipynb similarity index 93% rename from examples/ADP_backbone.ipynb rename to examples/ADP_backbone_GAN.ipynb index d38c79d..ad2e019 100644 --- 
a/examples/ADP_backbone.ipynb +++ b/examples/ADP_backbone_GAN.ipynb @@ -80,14 +80,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "Auto select gpus: [0]\n", "/project/andrewferguson/Kirill/class_project_env/lib/python3.7/site-packages/lightning_lite/plugins/environments/slurm.py:170: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /project/andrewferguson/Kirill/class_project_env/lib ...\n", " category=PossibleUserWarning,\n", - "GPU available: True (cuda), used: True\n", + "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------------------\n", @@ -107,7 +105,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dcc1f0c59c3142f3ae09c4404369484b", + "model_id": "91bc862da6ef40d5becb9456dc094947", "version_major": 2, "version_minor": 0 }, @@ -171,13 +169,25 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d7cafc23b1eb465d83af7215296647e8", + "model_id": "4bb84c5b024640cd8d4c56c629df20c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a6494584dc984e51b6fcec03fb3be92a", "version_major": 2, "version_minor": 0 }, @@ -198,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -207,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -217,13 +227,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ee7271deab364ea289f7c75f49bc75e3", + "model_id": "7d27f78241ed4d11a1a3ee8fafc50cdf", "version_major": 2, "version_minor": 0 }, diff --git a/molgen/data.py b/molgen/data.py index a32fbe5..526881e 100644 --- a/molgen/data.py +++ b/molgen/data.py @@ -11,18 +11,19 @@ class GANDataModule(LightningDataModule): """ GANDataModule is a Pytorch Lightning DataModule for training GANs. - It takes in feature_data and condition_data and creates a DataLoader for training. - The feature_data and condition_data should be of the same type (either a float tensor for single traj or list of float tensors for multiple trajs) - and must have the same number of data points. + It takes in feature_data and condition_data and creates a DataLoader for training. + The feature_data and condition_data should be of the same type (either a float tensor + for single traj or list of float tensors for multiple trajs) and must have the same + number of data points. 
Parameters ---------- feature_data : Union[torch.Tensor, List[torch.Tensor]] feature data for the GAN, either a float tensor for single traj or list of float tensors for multiple trajs - + condition_data : Union[torch.Tensor, List[torch.Tensor]] conditioning data for the GAN, either a float tensor for single traj or list of float tensors for multiple trajs - + batch_size : int, default = 10000 batch size for the DataLoader. Default is 1000. @@ -40,6 +41,7 @@ class GANDataModule(LightningDataModule): self.c_dim: int dimention of the conditioning """ + def __init__( self, feature_data: Union[torch.Tensor, List[torch.Tensor]], @@ -92,12 +94,12 @@ def __init__( def _get_scaler(self, data): """ Helper function to get the scaler for the data - + Parameters ---------- data : Union[torch.Tensor, List[torch.Tensor]] data to be scaled - + Returns ---------- MinMaxScaler : Scaler for the data @@ -120,10 +122,60 @@ def _get_scaler(self, data): def train_dataloader(self): """ Returns the DataLoader for training the GAN - + Returns ---------- DataLoader : Pytorch DataLoader for training the GAN """ dataset = TensorDataset(*self.train_data) return DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + +class DDPMDataModule(GANDataModule): + """ + DDPMDataModule is a Pytorch Lightning DataModule for training DDPMs. + It takes in feature_data and condition_data and creates a DataLoader for training. + The feature_data and condition_data should be of the same type (either a float tensor + for single traj or list of float tensors for multiple trajs) and must have the same + number of data points. + + Parameters + ---------- + feature_data : Union[torch.Tensor, List[torch.Tensor]] + feature data for the GAN, either a float tensor for single traj or list of float tensors for multiple trajs + + condition_data : Union[torch.Tensor, List[torch.Tensor]] + conditioning data for the GAN, either a float tensor for single traj or list of float tensors for multiple trajs + + batch_size : int, default = 10000 + batch size for the DataLoader. Default is 1000. + + Attributes + ---------- + self.feature_scaler: MinMaxScaler + scaler for scaling the feature data + + self.condition_scaler: MinMaxScaler + scaler for scaling the conditioning data + + self.x_dim: int + dimention of the features + + self.c_dim: int + dimention of the conditioning + """ + + def train_dataloader(self): + """ + Returns the DataLoader for training the DDPM + + Returns data as the conditions concatenated to the features + + Returns + ---------- + DataLoader : Pytorch DataLoader for training the GAN + """ + dataset = TensorDataset( + torch.cat((self.train_data[1], self.train_data[0]), dim=-1).unsqueeze(1) + ) + return DataLoader(dataset, batch_size=self.batch_size, shuffle=True) diff --git a/molgen/models.py b/molgen/models.py index 93881b8..1402707 100644 --- a/molgen/models.py +++ b/molgen/models.py @@ -3,11 +3,10 @@ import torch from pytorch_lightning import LightningModule, Trainer from typing import Union -from molgen.modules import ( - SimpleGenerator, - SimpleDiscriminator, -) -from molgen.data import GANDataModule +from molgen.modules import SimpleGenerator, SimpleDiscriminator, GaussianDiffusion, Unet +from molgen.data import GANDataModule, DDPMDataModule +from molgen.utils import EMA, MinMaxScaler +import copy class WGANGP(LightningModule): @@ -51,6 +50,7 @@ class WGANGP(LightningModule): **kwargs: Additional keyword arguments. 
""" + def __init__( self, feature_dim: int, @@ -77,6 +77,9 @@ def __init__( hidden_dim=self.hparams.dis_hidden_dim, ) + self._feature_scaler = MinMaxScaler(feature_dim) + self._condition_scaler = MinMaxScaler(condition_dim) + self.is_fit = False def forward(self, z, c): @@ -224,7 +227,8 @@ def fit( if not hasattr(self, "trainer_"): self.trainer_ = Trainer( - auto_select_gpus=True, + devices=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", max_epochs=max_epochs, logger=False, enable_checkpointing=False, @@ -255,8 +259,8 @@ def generate(self, c: torch.Tensor): assert self.is_fit, "model must be fit to data first using `fit`" assert ( - c.size(1) == self.trainer_.datamodule.c_dim - ), f"inconsistent dimensions, expecting {self.trainer_.datamodule.c_dim} dim" + c.size(1) == self.hparams.condition_dim + ), f"inconsistent dimensions, expecting {self.hparams.condition_dim} dim" self.eval() z = torch.randn(c.size(0), self.hparams.latent_dim, device=self.device) @@ -281,3 +285,246 @@ def save(self, fname: str): assert self.is_fit, "model must be fit to data first using `fit`" self.trainer_.save_checkpoint(fname) + + def on_load_checkpoint(self, checkpoint): + self.is_fit = True + return super().on_load_checkpoint(checkpoint) + + +class DDPM(LightningModule): + """ + A Denoising Diffusion Probabilistic Model (DDPM) implementation in Pytorch Lightning. + + Credit: https://github.com/lucidrains/denoising-diffusion-pytorch + + The DDPM class is a PyTorch implementation of the Diffusion Probabilistic Models (DDPM) algorithm for generative modeling. + It is built on top of the PyTorch Lightning framework and uses the Unet architecture for the generator and the + GaussianDiffusion class for the diffusion process. The DDPM class also includes an exponential moving average (EMA) + for stabilizing the training process. + + + Parameters + ---------- + feature_dim : int + The dimension of the feature space of the data. + + condition_dim : int + The dimension of the conditional input of the data. + + hidden_dim : int, default = 32 + Hidden dimention of the UNet model + + dis_hidden_dim : int, default = 256 + The dimension of the hidden layers in the discriminator network + + loss_type : str, default = 'huber' + The type of loss function used in the diffusion process. Acceptable options are + 'l1', 'l2' and 'huber' + + beta_schedule :str, default = 'linear' + The schedule for the beta parameter in the diffusion process. Acceptable options are + 'linear' and 'cosine' + + timesteps : int, default = 1000 + The number of timesteps in the diffusion process. + + lr : float, default = 2e-5 + The learning rate for the optimizer + + ema_decay : float, default = 0.995 + Decay rate of the EMA + + step_start_ema : int, default = 2000 + The number of steps after which the EMA will begin updating + + update_ema_every : int, default = 10 + The number of steps between updates to the EMA + + **kwargs: Additional keyword arguments. 
+ + """ + + def __init__( + self, + feature_dim, + condition_dim, + hidden_dim=32, + loss_type="huber", + beta_schedule="linear", + timesteps=1000, + lr=2e-5, + ema_decay=0.995, + step_start_ema=2000, + update_ema_every=10, + ): + super().__init__() + self.save_hyperparameters() + + model = Unet(dim=hidden_dim, dim_mults=(1, 2, 4, 8), groups=8) + diffusion = GaussianDiffusion( + model, + timesteps=timesteps, + unmask_number=condition_dim, + loss_type=loss_type, + beta_schedule=beta_schedule, + ) + + self.model = diffusion + self.ema = EMA(ema_decay) + self.ema_model = copy.deepcopy(self.model) + + self._feature_scaler = MinMaxScaler(feature_dim) + self._condition_scaler = MinMaxScaler(condition_dim) + + self.reset_parameters() + + self.is_fit = False + + def configure_optimizers(self): + opt = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr) + return opt + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + def step_ema(self): + if self.global_step < self.hparams.step_start_ema: + self.reset_parameters() + return + self.ema.update_model_average(self.ema_model, self.model) + + def training_step(self, batch, batch_idx): + loss = self.model(batch[0]) + return loss + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + if self.global_step % self.hparams.update_ema_every == 0: + self.step_ema() + + def fit( + self, + feature_data: Union[torch.Tensor, list], + condition_data: Union[torch.Tensor, list], + batch_size: int = 1000, + max_epochs: int = 100, + **kwargs, + ): + """ + Fit the DDPM on provided data + + Parameters + ---------- + feature_data : torch.Tensor (single traj) or list[torch.Tensor] (multi traj) + tensor with dimentions dim 0 = steps, dim 1 = features of features representing the real data that + is strived to be recapitulated by the generative model + + condition_data : torch.Tensor (single traj) or list[torch.Tensor] (multi traj) + list of tensors with dimentions dim 0 = steps, dim 1 = features of features representing the conditioning + variables associated with each data point in feature space + + batch_size : int, default = 1000 + training batch size + + max_epochs : int, default = 100 + maximum number of epochs to train for + + **kwargs: + additional keyword arguments to be passed to the the Lightning `Trainer` + """ + datamodule = DDPMDataModule( + feature_data=feature_data, + condition_data=condition_data, + batch_size=batch_size, + **kwargs, + ) + if self.is_fit: + raise Warning( + """The `fit` method was called more than once on the same `DDPM` instance, + recreating data scaler on dataset from the most recent `fit` invocation. 
This warning + can be safely ignored if the `DDPM` is being fit on the same data""" + ) + self._feature_scaler = datamodule.feature_scaler + self._condition_scaler = datamodule.condition_scaler + + if not hasattr(self, "trainer_"): + self.trainer_ = Trainer( + devices=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + max_epochs=max_epochs, + logger=False, + enable_checkpointing=False, + **kwargs, + ) + self.trainer_.fit(self, datamodule) + else: + self.trainer_.fit(self, datamodule) + + self.is_fit = True + return self + + def generate(self, c: torch.Tensor): + """ + Generate samples based on conditioning variables + + Parameters + ---------- + c : torch.Tensor + Conditioning variables, float tensor of shape (n_samples, conditioning_dim) + + Returns + ------- + gen: torch.Tensor + Generated samples, float tensor of shape (n_samples, feature_dim) + + """ + + assert self.is_fit, "model must be fit to data first using `fit`" + assert ( + c.size(1) == self.hparams.condition_dim + ), f"inconsistent dimensions, expecting {self.hparams.condition_dim} dim" + + self.eval() + c = self._condition_scaler.transform(c.to(self.device)).float().unsqueeze(1) + c = torch.cat( + ( + c, + torch.zeros( + c.shape[0], + c.shape[1], + self.hparams.feature_dim + self.hparams.condition_dim - c.shape[2], + dtype=float, + device=c.device, + ), + ), + -1, + ).float() + + gen = self.ema_model.sample( + self.hparams.feature_dim + self.hparams.condition_dim, + batch_size=c.shape[0], + samples=c, + ) + gen = gen[:, 0, self.hparams.condition_dim :] + + gen = self._feature_scaler.inverse_transform(gen) + + return gen + + def save(self, fname: str): + """ + Saves a checkpoint of the fitted model + + Parameters + ---------- + fname : str + file name for saving a model checkpoint + """ + + assert self.is_fit, "model must be fit to data first using `fit`" + + self.trainer_.save_checkpoint(fname) + + def on_load_checkpoint(self, checkpoint): + self.is_fit = True + return super().on_load_checkpoint(checkpoint) diff --git a/molgen/modules.py b/molgen/modules.py index 0635e90..46154ba 100644 --- a/molgen/modules.py +++ b/molgen/modules.py @@ -1,4 +1,21 @@ +import torch from torch import nn +from einops import rearrange +from molgen.utils import ( + SinusoidalPosEmb, + Mish, + Residual, + default, + exists, + linear_schedule, + cosine_beta_schedule, + extract, + generate_inprint_mask, + noise_like, +) +import numpy as np +from tqdm.autonotebook import tqdm +from functools import partial class SimpleGenerator(nn.Module): @@ -7,7 +24,7 @@ class SimpleGenerator(nn.Module): It takes a latent dimension and an output dimension as input, and has a hidden dimension (default = 256) It is implemented as a sequential model with 3 linear layers, batch normalization and SiLU activation functions. It is a sub-class of nn.Module. - + Parameters ---------- latent_dim : int @@ -18,8 +35,9 @@ class SimpleGenerator(nn.Module): hidden_dim : int, default=256 dimension of the hidden layers - + """ + def __init__(self, latent_dim, output_dim, hidden_dim=256): super(SimpleGenerator, self).__init__() @@ -48,7 +66,7 @@ class SimpleDiscriminator(nn.Module): It takes an output dimension as input, and has a hidden dimension (default = 256) It is implemented as a sequential model with 3 linear layers and SiLU activation functions. It is a sub-class of nn.Module. 
- + Parameters ---------- output_dim : int @@ -56,8 +74,9 @@ class SimpleDiscriminator(nn.Module): hidden_dim : int, default=256 dimension of the hidden layers - + """ + def __init__(self, output_dim, hidden_dim=256): super(SimpleDiscriminator, self).__init__() @@ -74,3 +93,389 @@ def __init__(self, output_dim, hidden_dim=256): def forward(self, x): validity = self.model(x) return validity + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = nn.Sequential( + nn.Conv1d(dim, dim_out, 3, padding=1), nn.GroupNorm(groups, dim_out), Mish() + ) + + def forward(self, x): + return self.block(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim, groups=8): + super().__init__() + self.mlp = nn.Sequential(Mish(), nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb): + h = self.block1(x) + h += self.mlp(time_emb)[:, :, None] + + h = self.block2(h) + return h + self.res_conv(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv1d(hidden_dim, dim, 1) + + def forward(self, x): + # b, c, l = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) l -> qkv b heads c l", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c l -> b (heads c) l", heads=self.heads) + return self.to_out(out) + + +class Upsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +class Unet(nn.Module): + """From: https://github.com/lucidrains/denoising-diffusion-pytorch""" + + def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), groups=8): + super().__init__() + dims = [1, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + self.feature_dim = dim + self.dim_mults = dim_mults + self.time_pos_emb = SinusoidalPosEmb(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), Mish(), nn.Linear(dim * 4, dim) + ) + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append( + nn.ModuleList( + [ + ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups), + ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups), + Residual(Rezero(LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity(), + ] + ) + ) + + mid_dim = dims[-1] + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups) + self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) + self.mid_block2 = ResnetBlock(mid_dim, 
mid_dim, time_emb_dim=dim, groups=groups) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (num_resolutions - 1) + + self.ups.append( + nn.ModuleList( + [ + ResnetBlock( + dim_out * 2, dim_in, time_emb_dim=dim, groups=groups + ), + ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups), + Residual(Rezero(LinearAttention(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity(), + ] + ) + ) + + out_dim = default(out_dim, 1) + self.final_conv = nn.Sequential( + Block(dim, dim, groups=groups), nn.Conv1d(dim, out_dim, 1) + ) + + def forward(self, x, time): + t = self.time_pos_emb(time) + t = self.mlp(t) + + h = [] + size_list = [] + + for resnet, resnet2, attn, downsample in self.downs: + x = resnet(x, t) + x = resnet2(x, t) + x = attn(x) + h.append(x) + size_list.append(x.shape[-1]) + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for resnet, resnet2, attn, upsample in self.ups: + + x = torch.cat((x[:, :, : size_list.pop()], h.pop()), dim=1) + x = resnet(x, t) + x = resnet2(x, t) + x = attn(x) + x = upsample(x) + + return self.final_conv(x[:, :, : size_list.pop()]) + + +class GaussianDiffusion(nn.Module): + """From: https://github.com/lucidrains/denoising-diffusion-pytorch""" + + def __init__( + self, + denoise_fn, + timesteps=1000, + loss_type="l1", + betas=None, + beta_schedule="linear", + unmask_number=0, + ): + super().__init__() + self.denoise_fn = denoise_fn + + if exists(betas): + betas = ( + betas.detach().cpu().numpy() + if isinstance(betas, torch.Tensor) + else betas + ) + + # which beta scheduler to use + else: + if beta_schedule == "linear": + betas = linear_schedule(timesteps) + elif beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + self.unmask_number = unmask_number + if unmask_number == 0: + self.unmask_index = None + else: + self.unmask_index = [*range(unmask_number)] + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + + # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x, t)) + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, _, l, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + denosied_x = ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + ) + inprint_mask = generate_inprint_mask(b, l, self.unmask_index).to(device) + + denosied_x[inprint_mask] = x[inprint_mask] + + return denosied_x + + @torch.no_grad() + def p_sample_loop(self, shape, samples=None): + device = self.betas.device + + b = shape[0] + state = torch.randn(shape, device=device) + + # if not samples == None: + if samples is not None: + assert shape == samples.shape + + inprint_mask = generate_inprint_mask(b, shape[2], self.unmask_index).to( + device + ) + state[inprint_mask] = samples[inprint_mask] + + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="sampling loop time step", + total=self.num_timesteps, + ): + state = self.p_sample( + state, torch.full((b,), i, device=device, dtype=torch.long) + ) + + return state + + @torch.no_grad() + def sample(self, op_number, batch_size=16, samples=None): + return self.p_sample_loop((batch_size, 1, op_number), samples) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, 
x_start.shape) * noise + ) + + # if not self.unmask_index == None: + if self.unmask_index is not None: + b, c, l = x_start.shape + inprint_mask = generate_inprint_mask(b, l, self.unmask_index).to( + x_start.device + ) + x_start[inprint_mask] + x_noisy[inprint_mask] = x_start[inprint_mask] + else: + inprint_mask = None + return x_noisy, inprint_mask + + def p_losses(self, x_start, t, noise=None): + b, c, l = x_start.shape + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy, inprint_mask = self.q_sample(x_start=x_start, t=t, noise=noise) + + x_recon = self.denoise_fn(x_noisy, t) + + # if not inprint_mask == None: + if inprint_mask is not None: + noise = torch.masked_select(noise, ~inprint_mask) + x_recon = torch.masked_select(x_recon, ~inprint_mask) + + if self.loss_type == "l1": + loss = torch.nn.functional.l1_loss(noise, x_recon) + elif self.loss_type == "l2": + loss = torch.nn.functional.mse_loss(noise, x_recon) + elif self.loss_type == "huber": + loss = torch.nn.functional.smooth_l1_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + b, *_, device = *x.shape, x.device + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + return self.p_losses(x, t, *args, **kwargs) diff --git a/molgen/tests/test_models.py b/molgen/tests/test_models.py index a34c094..b32f842 100644 --- a/molgen/tests/test_models.py +++ b/molgen/tests/test_models.py @@ -1,13 +1,37 @@ import torch import pytest -from molgen.models import WGANGP - -@pytest.mark.parametrize("feature_dim,condition_dim,gen_hidden_dim,dis_hidden_dim,lambda_gp,n_critic,latent_dim,lr,opt", [ - (100, 10, 256, 256, 10.0, 5, 128, 5e-5, "rmsprop"), - (100, 10, 256, 256, 10.0, 5, 128, 1e-4, "adam") -]) -def test_wgangp(feature_dim, condition_dim, gen_hidden_dim, dis_hidden_dim, lambda_gp, n_critic, latent_dim, lr, opt): - model = WGANGP(feature_dim, condition_dim, gen_hidden_dim, dis_hidden_dim, lambda_gp, n_critic, latent_dim, lr, opt) +from molgen.models import WGANGP, DDPM + + +@pytest.mark.parametrize( + "feature_dim,condition_dim,gen_hidden_dim,dis_hidden_dim,lambda_gp,n_critic,latent_dim,lr,opt", + [ + (100, 10, 256, 256, 10.0, 5, 128, 5e-5, "rmsprop"), + (100, 10, 256, 256, 10.0, 5, 128, 1e-4, "adam"), + ], +) +def test_wgangp( + feature_dim, + condition_dim, + gen_hidden_dim, + dis_hidden_dim, + lambda_gp, + n_critic, + latent_dim, + lr, + opt, +): + model = WGANGP( + feature_dim, + condition_dim, + gen_hidden_dim, + dis_hidden_dim, + lambda_gp, + n_critic, + latent_dim, + lr, + opt, + ) assert model.hparams.feature_dim == feature_dim assert model.hparams.condition_dim == condition_dim assert model.hparams.gen_hidden_dim == gen_hidden_dim @@ -23,9 +47,40 @@ def test_wgangp(feature_dim, condition_dim, gen_hidden_dim, dis_hidden_dim, lamb max_epochs = 2 # test fitting - model.fit(fake_feature_data, fake_condition_data, max_epochs) + model.fit(fake_feature_data, fake_condition_data, max_epochs=max_epochs) # test generation model.generate(fake_condition_data) - + +@pytest.mark.parametrize( + "feature_dim,condition_dim,hidden_dim,loss_type,beta_schedule,timesteps,lr", + [ + (100, 10, 32, "l1", "linear", 1000, 2e-5), + (100, 10, 32, "l2", "cosine", 1000, 2e-5), + (100, 10, 32, "huber", "cosine", 1000, 2e-5), + ], +) +def test_ddpm( + feature_dim, condition_dim, hidden_dim, loss_type, beta_schedule, timesteps, lr +): + model = DDPM( + feature_dim, condition_dim, hidden_dim, loss_type, beta_schedule, timesteps, lr + ) + assert 
model.hparams.feature_dim == feature_dim + assert model.hparams.condition_dim == condition_dim + assert model.hparams.hidden_dim == hidden_dim + assert model.hparams.loss_type == loss_type + assert model.hparams.beta_schedule == beta_schedule + assert model.hparams.timesteps == timesteps + assert model.hparams.lr == lr + + fake_feature_data = torch.randn(10, feature_dim) + fake_condition_data = torch.randn(10, condition_dim) + max_epochs = 2 + + # test fitting + model.fit(fake_feature_data, fake_condition_data, max_epochs=max_epochs) + + # test generation + model.generate(fake_condition_data) diff --git a/molgen/utils.py b/molgen/utils.py index a227e09..e43279b 100644 --- a/molgen/utils.py +++ b/molgen/utils.py @@ -1,6 +1,9 @@ """Some utility functions""" import torch +import math +from inspect import isfunction +import numpy as np class MinMaxScaler(torch.nn.Module): @@ -75,3 +78,110 @@ def inverse_transform(self, X): """ self._check_if_fit() return self.min + (X - self.feature_range[0]) / self.range + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +class Residual(torch.nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def linear_schedule(timesteps, s=0.008): + """ + linear schedule + """ + betas = np.linspace(0.0001, 0.02, timesteps, dtype=np.float64) + return np.clip(betas, a_min=0, a_max=0.999) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + + betas = betas.numpy() + return np.clip(betas, a_min=0, a_max=0.999) + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def generate_inprint_mask(n_batch, op_num, unmask_index=None): + """ + The mask will be True where we keep the true value and false where we want to infer the value + So far it only supporting masking the right side of images + """ + + mask = torch.zeros((n_batch, 1, op_num), dtype=bool) + # if not unmask_index == None: + if unmask_index is not None: + mask[:, :, unmask_index] = True + return mask + + +class EMA: + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip( + current_model.parameters(), 
ma_model.parameters() + ): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new
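
The `DDPM` class in `molgen/models.py` follows the same `fit` / `generate` / `save` workflow as `WGANGP` and is exercised end to end in `examples/ADP_all_atom_DDPM.ipynb`. A minimal sketch of that workflow, using random stand-in tensors in place of the alanine dipeptide trajectories (the shapes 66 and 2 and the checkpoint filename simply mirror the notebook and are not required values):

```python
import torch
from molgen.models import DDPM

feature_dim, condition_dim = 66, 2            # flattened xyz coordinates and (phi, psi) in the notebook
xyz = torch.randn(5000, feature_dim)          # stand-in for real trajectory features
phi_psi = torch.randn(5000, condition_dim)    # stand-in for the conditioning variables

# Unet denoiser wrapped in GaussianDiffusion, plus an EMA copy used for sampling
model = DDPM(feature_dim, condition_dim, timesteps=1000)
model.fit(xyz, phi_psi, batch_size=1000, max_epochs=10)

samples = model.generate(phi_psi[:100])       # -> tensor of shape (100, feature_dim)

model.save('ADP_DDPM.ckpt')
model = DDPM.load_from_checkpoint('ADP_DDPM.ckpt')
```

As in the GAN case, `fit` also accepts lists of tensors for multi-trajectory datasets.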
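
Conditioning in the DDPM pathway works by inpainting rather than through a separate conditioning branch: `DDPMDataModule.train_dataloader` concatenates the scaled conditions in front of the features to form a single one-channel 1D signal, and at sampling time `GaussianDiffusion` keeps the first `condition_dim` entries clamped to the supplied conditions via its `unmask_number` / inprint mask while only the feature entries are denoised. A small sketch of the tensor shapes involved (values are arbitrary and the Min-Max scaling applied by the real pipeline is skipped):

```python
import torch
from molgen.utils import generate_inprint_mask

batch, condition_dim, feature_dim = 8, 2, 66
c = torch.randn(batch, condition_dim)
x = torch.randn(batch, feature_dim)

# What the DDPM datamodule feeds to GaussianDiffusion: [conditions | features] with a channel axis
signal = torch.cat((c, x), dim=-1).unsqueeze(1)            # shape (8, 1, 68)

# The mask is True on the positions that stay fixed during the reverse diffusion
mask = generate_inprint_mask(batch, condition_dim + feature_dim,
                             unmask_index=list(range(condition_dim)))
assert mask.shape == (batch, 1, condition_dim + feature_dim)
assert mask[..., :condition_dim].all() and not mask[..., condition_dim:].any()
```

This is also why `DDPM.generate` zero-pads the conditions up to `feature_dim + condition_dim` before calling `ema_model.sample(...)` and then slices off the first `condition_dim` entries of the result.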
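
`GaussianDiffusion` accepts either of the two noise schedules defined in `molgen/utils.py`. A quick way to compare them is through the cumulative product of `(1 - beta_t)`, which measures how much signal survives at each timestep; the sketch below only inspects the arrays and the commented values are expectations, not output validated against a trained model:

```python
import numpy as np
from molgen.utils import linear_schedule, cosine_beta_schedule

T = 1000
betas_linear = linear_schedule(T)       # np.linspace(1e-4, 0.02, T), clipped to [0, 0.999]
betas_cosine = cosine_beta_schedule(T)  # derived from the squared-cosine alpha-bar schedule

alpha_bar_linear = np.cumprod(1.0 - betas_linear)
alpha_bar_cosine = np.cumprod(1.0 - betas_cosine)

print(betas_linear.shape, betas_cosine.shape)      # (1000,) (1000,)
print(alpha_bar_linear[-1], alpha_bar_cosine[-1])  # both close to zero by the final timestep
```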
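
The `loss_type` options of `GaussianDiffusion.p_losses` map directly onto standard `torch.nn.functional` losses between the injected noise and the Unet's noise prediction (positions covered by the inprint mask are excluded first). A toy illustration with dummy tensors standing in for the real noise and prediction:

```python
import torch
import torch.nn.functional as F

noise = torch.randn(16, 1, 68)        # stand-in for the noise added in q_sample
noise_pred = torch.randn(16, 1, 68)   # stand-in for the denoise_fn output

losses = {
    "l1": F.l1_loss(noise, noise_pred),
    "l2": F.mse_loss(noise, noise_pred),
    "huber": F.smooth_l1_loss(noise, noise_pred),  # the DDPM default (loss_type="huber")
}
```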
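
Sampling in `DDPM.generate` uses `ema_model`, the exponentially smoothed copy of the diffusion weights that `step_ema` starts updating after `step_start_ema` optimizer steps and refreshes every `update_ema_every` steps thereafter. The update rule is the usual `old * beta + (1 - beta) * new`; a toy sketch of the helper with a throwaway linear layer standing in for the diffusion module:

```python
import copy
import torch
from molgen.utils import EMA

net = torch.nn.Linear(4, 4)        # stand-in for the GaussianDiffusion module
ema_net = copy.deepcopy(net)

ema = EMA(beta=0.995)
# After an optimizer step on `net`, pull the smoothed copy towards the new weights
ema.update_model_average(ema_net, net)
```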