This repository was archived by the owner on Jul 9, 2022. It is now read-only.
Commit 1825542 1 parent 697d672 commit 1825542 Copy full SHA for 1825542
File tree 1 file changed +17
-6
lines changed
1 file changed +17
-6
lines changed Original file line number Diff line number Diff line change @@ -33,20 +33,31 @@ model = PQRNN(
33
33
nhead = 2 , # used when rnn_type == "Transformer"
34
34
)
35
35
36
+ # Or load the model from your checkpoint
37
+ # model = PQRNN.load_from_checkpoint(checkpoint_path="example.ckpt")
36
38
37
- # Text data has to been pre-processed with DummyDataset
38
- dataset = {
39
- " train" : train[[" text" , " label" ]].to_dict(" records" ),
40
- " test" : test[[" text" , " label" ]].to_dict(" records" ),
41
- }
39
+ # Text data has to be pre-processed with DummyDataset
42
40
dataset = DummyDataset(
43
- dataset[" train" ],
41
+ df[[" text" , " label" ]].to_dict(" records" ),
42
+ has_label = True ,
44
43
feature_size = 128 * 2 ,
45
44
add_eos_tag = True ,
46
45
add_bos_tag = True ,
47
46
max_seq_len = 512 ,
48
47
label2index = {" pos" : 1 , " neg" : 0 },
49
48
)
49
+
50
+ # Explicit train/val loop
51
+ # Add model.eval() when necessary
52
+ dataloader = create_dataloaders(dataset)
53
+ for batch in dataloader:
54
+ # labels could be an empty tensor if has_label is False when creating the dataset.
55
+ # To change what are included in a batch, feel free to change the collate_fn function
56
+ # in dataset.py
57
+ projections, lengths, labels = batch
58
+ logits = model.forward(projections)
59
+
60
+ # do your magic
50
61
```
51
62
52
63
## CLI Usage
You can’t perform that action at this time.
0 commit comments