Skip to content
This repository was archived by the owner on Jul 9, 2022. It is now read-only.

Commit 1825542

Browse files
committed
update readme
1 parent 697d672 commit 1825542

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

README.md

+17-6
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,31 @@ model = PQRNN(
3333
nhead=2, # used when rnn_type == "Transformer"
3434
)
3535

36+
# Or load the model from your checkpoint
37+
# model = PQRNN.load_from_checkpoint(checkpoint_path="example.ckpt")
3638

37-
# Text data has to been pre-processed with DummyDataset
38-
dataset = {
39-
"train": train[["text", "label"]].to_dict("records"),
40-
"test": test[["text", "label"]].to_dict("records"),
41-
}
39+
# Text data has to be pre-processed with DummyDataset
4240
dataset = DummyDataset(
43-
dataset["train"],
41+
df[["text", "label"]].to_dict("records"),
42+
has_label=True,
4443
feature_size=128 * 2,
4544
add_eos_tag=True,
4645
add_bos_tag=True,
4746
max_seq_len=512,
4847
label2index={"pos": 1, "neg": 0},
4948
)
49+
50+
# Explicit train/val loop
51+
# Add model.eval() when necessary
52+
dataloader = create_dataloaders(dataset)
53+
for batch in dataloader:
54+
# labels could be an empty tensor if has_label is False when creating the dataset.
55+
# To change what are included in a batch, feel free to change the collate_fn function
56+
# in dataset.py
57+
projections, lengths, labels = batch
58+
logits = model.forward(projections)
59+
60+
# do your magic
5061
```
5162

5263
## CLI Usage

0 commit comments

Comments
 (0)