Skip to content
This repository was archived by the owner on Jul 9, 2022. It is now read-only.

Commit 697d672

Browse files
committed
add ddp
1 parent ebf5573 commit 697d672

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

README.md

+1-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
![PyPI](https://img.shields.io/pypi/v/pytorch-pqrnn?style=plastic) ![Maintenance](https://img.shields.io/maintenance/yes/2021?style=plastic) ![PyPI - License](https://img.shields.io/pypi/l/pytorch-pqrnn?style=plastic)
44

5-
<<<<<<< HEAD
65
## Installation
76

87
```bash
@@ -15,17 +14,6 @@ poetry install
1514
## Environment
1615

1716
Because of [this issue](https://github.com/salesforce/pytorch-qrnn/issues/29), `pytorch-qrnn` is no longer compatible with recent PyTorch releases (`torch >= 1.7`), and it is not actively maintained. If you want to use a QRNN layer in this model, you have to install `pytorch-qrnn` with `torch <= 1.4` first.
18-
=======
19-
## Note
20-
21-
Because of [this issue](https://github.com/salesforce/pytorch-qrnn/issues/29), [QRNN](https://github.com/salesforce/pytorch-qrnn) is not supported with `torch >= 1.7`. If you want to use a QRNN layer with this repo, please follow the instructions [here](https://github.com/salesforce/pytorch-qrnn) to install `pytorch-qrnn` first with a downgraded `torch <= 1.4`. Otherwise, you can directly run
22-
23-
```
24-
pip install -r requirements.txt
25-
```
26-
27-
to set up the env.
28-
>>>>>>> d83b7c7e27e32583a585d93e463d7f82192622c4
2917

3018
## Usage
3119

@@ -97,7 +85,7 @@ Datasets
9785
| ------------------------ | ---------- | -------------------------- | ----------------- | --------------------------- | ---------------------------------------------------------------- |
9886
| ~~PQRNN (this repo)~~<sup>0</sup> | ~~78K~~ | ~~6.3~~ | ~~70.4~~ | ~~TODO~~ | `--b 128 --d 64 --num_layers 4 --rnn_type QRNN` |
9987
| PRNN (this repo) | 90K | 5.5 | **70.7** | 95.57 | `--b 128 --d 64 --num_layers 1 --rnn_type GRU` |
100-
| PTransformer (this repo) | 617K | 10.8 | 68 | 86.5 | `--b 128 --d 64 --num_layers 1 --rnn_type Transformer --nhead 2` |
88+
| PTransformer (this repo) | 618K | 10.8 | 68 | 92.4 | `--b 128 --d 64 --num_layers 1 --rnn_type Transformer --nhead 8` |
10189
| PRADO<sup>1</sup> | 175K | | 65.9 | | |
10290
| BERT | 335M | **1.81** | 70.58 | **98.856**<sup>2</sup> | |
10391
0. Not supported with `torch >= 1.7`

pytorch_pqrnn/model.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
import pytorch_lightning as pl
1616
import torch
1717
import torch.nn as nn
18-
from pytorch_lightning.metrics.functional import f1_score
19-
from pytorch_lightning.metrics.functional.classification import (
20-
accuracy,
21-
auroc,
22-
)
18+
from pytorch_lightning.metrics.functional import accuracy, auroc
19+
from pytorch_lightning.metrics.functional import f1 as f1_score
2320
from torch.nn import TransformerEncoder, TransformerEncoderLayer
2421
from torch.optim.lr_scheduler import ReduceLROnPlateau
2522

@@ -184,7 +181,14 @@ def validation_epoch_end(self, outputs):
184181
"val_auroc",
185182
np.mean(
186183
[
187-
auroc(logits[:, i], labels[:, i]).detach().cpu().item()
184+
auroc(
185+
torch.sigmoid(logits[:, i]),
186+
labels[:, i],
187+
pos_label=1,
188+
)
189+
.detach()
190+
.cpu()
191+
.item()
188192
for i in range(logits.shape[1])
189193
]
190194
),

run.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from pytorch_lightning import loggers as pl_loggers
55
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
6+
from pytorch_lightning.plugins import DeepSpeedPlugin
67
from pytorch_pqrnn.dataset import create_dataloaders
78
from pytorch_pqrnn.model import PQRNN
89
from rich.console import Console
@@ -46,6 +47,36 @@ def train(
4647
data_path: str,
4748
):
4849

50+
deepspeed_config = {
51+
"zero_allow_untested_optimizer": True,
52+
"optimizer": {
53+
"type": "Adam",
54+
"params": {
55+
"lr": lr,
56+
"betas": [0.998, 0.999],
57+
"eps": 1e-5,
58+
"weight_decay": 1e-9,
59+
},
60+
},
61+
"scheduler": {
62+
"type": "WarmupLR",
63+
"params": {
64+
"last_batch_iteration": -1,
65+
"warmup_min_lr": 0,
66+
"warmup_max_lr": 3e-5,
67+
"warmup_num_steps": 100,
68+
},
69+
},
70+
"zero_optimization": {
71+
"stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
72+
"cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU
73+
"contiguous_gradients": True, # Reduce gradient fragmentation.
74+
"overlap_comm": True, # Overlap reduce/backward operation of gradients for speed.
75+
"allgather_bucket_size": 2e8, # Number of elements to all gather at once.
76+
"reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once.
77+
},
78+
}
79+
4980
train_dataloader, dev_dataloader = create_dataloaders(
5081
task,
5182
batch_size=batch_size,
@@ -69,16 +100,18 @@ def train(
69100

70101
trainer = pl.Trainer(
71102
logger=pl_loggers.TensorBoardLogger("lightning_logs", log_graph=False),
72-
callbacks=[EarlyStopping(monitor="val_loss", patience=10)],
103+
callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
73104
checkpoint_callback=ModelCheckpoint(
74105
"./checkpoints/", monitor="val_loss"
75106
),
76107
min_epochs=2,
77108
deterministic=True,
78109
val_check_interval=0.2,
79-
gpus=[0] if torch.cuda.is_available() else None,
110+
gpus=list(range(torch.cuda.device_count()))
111+
if torch.cuda.is_available()
112+
else None,
80113
gradient_clip_val=1.0,
81-
plugins="deepspeed" if torch.cuda.is_available() else None,
114+
accelerator="ddp" if torch.cuda.is_available() else None,
82115
precision=16 if torch.cuda.is_available() else 32,
83116
accumulate_grad_batches=2 if rnn_type == "Transformer" else 1,
84117
)

0 commit comments

Comments
 (0)