-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_markuplm.py
35 lines (26 loc) · 945 Bytes
/
test_markuplm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from markuplm_qa import MarkupLM, QADataset
import pandas as pd
import getopt
import sys
def main():
#Command to run the code
"""
python3 test_markuplm.py --device="cuda" --test_ds="dataset.json" --output_csv="test_set_results.csv"
"""
args = sys.argv[1:]
oplist, args = getopt.getopt(args, "", ["test_ds=","device=","output_csv="])
for op in oplist:
if op[0] == "--test_ds":
test_ds_path = op[1]
elif op[0] == "--device":
device = op[1]
elif op[0] == "--output_csv":
output_csv_path = op[1]
markup_lm = MarkupLM("microsoft/markuplm-base-finetuned-websrc",device)
dataset_dict = markup_lm.load_dataset(test_ds_path)
#Code for evaluating the model on dataset
answers_list = markup_lm.test(dataset_dict)
df = pd.DataFrame.from_records(answers_list)
df.to_csv(output_csv_path)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()