generate_parser_test_files.py (forked from flairNLP/fundus)
import subprocess
from argparse import ArgumentParser, Namespace
from logging import WARN
from typing import List, Optional

from tqdm import tqdm

from fundus import Crawler, PublisherCollection
from fundus.logging import basic_logger
from fundus.publishers.base_objects import PublisherEnum
from fundus.scraping.article import Article
from fundus.scraping.filter import RequiresAll
from fundus.scraping.html import WebSource
from fundus.scraping.scraper import BaseScraper, WebScraper
from tests.test_parser import attributes_required_to_cover
from tests.utility import HTMLTestFile, get_test_case_json, load_html_test_file_mapping


def get_test_article(enum: PublisherEnum, url: Optional[str] = None) -> Optional[Article]:
    if url is not None:
        source = WebSource([url], publisher=enum.publisher_name)
        scraper = BaseScraper(source, parser_mapping={enum.publisher_name: enum.parser})
        return next(scraper.scrape(error_handling="suppress", extraction_filter=RequiresAll()), None)

    crawler = Crawler(enum)
    return next(crawler.crawl(max_articles=1, error_handling="suppress", only_complete=True), None)
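
# Illustrative sketch (not part of the original script): how get_test_article might be called
# directly. The publisher reference and URL below are assumptions for demonstration only;
# substitute any member of PublisherCollection.
#
#   article = get_test_article(PublisherCollection.de.DieWelt)
#   article = get_test_article(PublisherCollection.de.DieWelt, url="https://www.welt.de/...")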


def parse_arguments() -> Namespace:
    parser = ArgumentParser(
        prog="generate_parser_test_files",
        description=(
            "script to generate/update/overwrite test cases for parser unit tests. "
            "by default this will only generate files which do not exist yet. "
            "every changed/added file will automatically be added to git."
        ),
    )
    parser.add_argument(
        "-a",
        "--attributes",
        nargs="+",
        default=[],
        help=(
            "the attributes which should be used to create test cases. "
            f"default: {', '.join(attributes_required_to_cover)}"
        ),
    )
    parser.add_argument("-p", dest="publishers", metavar="P", nargs="+", help="only consider given publishers")
    parser.add_argument(
        "-u",
        "--urls",
        metavar="U",
        nargs="+",
        help="use given URL instead of searching for an article. if set the urls will be mapped to the order of -p",
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-o",
        "--overwrite",
        action="store_true",
        help="overwrite existing html and json files for the latest parser version",
    )
    group.add_argument(
        "-oj",
        "--overwrite_json",
        action="store_true",
        help="parse from existing html and overwrite existing json content",
    )

    arguments = parser.parse_args()

    if arguments.urls is not None:
        if arguments.publishers is None:
            parser.error("-u requires -p. you can only specify URLs when also specifying publishers.")
        if len(arguments.urls) != len(arguments.publishers):
            parser.error("-u and -p do not have the same argument length")

    return arguments
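
# Hedged usage sketch (not part of the original script): typical invocations, assuming the
# script is run from the repository root so the "fundus" and "tests" packages are importable.
# The script path, the publisher name "DieWelt", the attribute names, and the URL are
# illustrative assumptions; substitute your own values.
#
#   python scripts/generate_parser_test_files.py -p DieWelt
#   python scripts/generate_parser_test_files.py -p DieWelt -a title body -oj
#   python scripts/generate_parser_test_files.py -p DieWelt -u https://www.welt.de/some-article -o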


def main() -> None:
    arguments = parse_arguments()

    # sort args.attributes for consistency
    arguments.attributes = list(sorted(arguments.attributes)) or attributes_required_to_cover

    basic_logger.setLevel(WARN)

    publishers: List[PublisherEnum] = (
        list(PublisherCollection)
        if arguments.publishers is None
        else [PublisherCollection[pub] for pub in arguments.publishers]
    )

    urls = arguments.urls if arguments.urls is not None else [None] * len(publishers)

    with tqdm(total=len(publishers)) as bar:
        for url, publisher in zip(urls, publishers):
            bar.set_description(desc=publisher.name, refresh=True)

            # load json
            test_data_file = get_test_case_json(publisher)
            test_data = content if (content := test_data_file.load()) and not arguments.overwrite_json else {}

            # load html
            html_mapping = load_html_test_file_mapping(publisher) if not arguments.overwrite else {}

            if arguments.overwrite or not html_mapping.get(publisher.parser.latest_version):
                if not (article := get_test_article(publisher, url)):
                    basic_logger.error(f"Couldn't get article for {publisher.name}. Skipping")
                    continue

                html = HTMLTestFile(
                    url=article.html.responded_url,
                    content=article.html.content,
                    crawl_date=article.html.crawl_date,
                    publisher=publisher,
                )
                html.write()
                subprocess.call(["git", "add", html.path], stdout=subprocess.PIPE)
                html_mapping[publisher.parser.latest_version] = html
                test_data[publisher.parser.latest_version.__name__] = {}

            for html in html_mapping.values():
                versioned_parser = html.publisher.parser(html.crawl_date)
                extraction = versioned_parser.parse(html.content)
                new = {attr: value for attr, value in extraction.items() if attr in arguments.attributes}
                if not (entry := test_data.get(type(versioned_parser).__name__)):
                    test_data[type(versioned_parser).__name__] = new
                else:
                    entry.update(new)

            test_data_file.write(test_data)
            bar.update()

            # stage the (re)written json test case for this publisher as well
            subprocess.call(["git", "add", test_data_file.path], stdout=subprocess.PIPE)


if __name__ == "__main__":
    main()