search_server.py
"""
A web search server for ParlAI, including Blenderbot2.
See README.md
"""
import html
import http.server
import json
import re
from typing import Dict, Generator, NoReturn, Optional, Tuple
import urllib.parse
import bs4
import chardet
import fire
import html2text
import googlesearch
import parlai.agents.rag.retrieve_api
import rich
import rich.markup
import requests
print = rich.print
_DEFAULT_HOST = "0.0.0.0"
_DEFAULT_PORT = 8080
_DELAY_SEARCH = 1.0 # Making this too low will get you IP banned
_STYLE_GOOD = "[green]"
_STYLE_SKIP = ""
_CLOSE_STYLE_GOOD = "[/]" if _STYLE_GOOD else ""
_CLOSE_STYLE_SKIP = "[/]" if _STYLE_SKIP else ""
_REQUESTS_GET_TIMEOUT = 5


def _parse_host(host: str) -> Tuple[str, int]:
    """ Parse the host string.
    Should be in the format HOSTNAME:PORT.
    Example: 0.0.0.0:8080
    """
    splitted = host.split(":")
    hostname = splitted[0]
    port = splitted[1] if len(splitted) > 1 else _DEFAULT_PORT
    return hostname, int(port)
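
# For example (illustrative; not exercised anywhere in this file):
#   _parse_host("0.0.0.0:8080") -> ("0.0.0.0", 8080)
#   _parse_host("0.0.0.0")      -> ("0.0.0.0", 8080)  # falls back to _DEFAULT_PORT
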
def _get_and_parse(url: str) -> Optional[Dict[str, str]]:
    """ Download a webpage and parse it. Returns None if the request fails. """

    try:
        resp = requests.get(url, timeout=_REQUESTS_GET_TIMEOUT)
    except requests.exceptions.RequestException as e:
        print(f"[!] {e} for url {url}")
        return None
    else:
        resp.encoding = resp.apparent_encoding
        page = resp.text

    ###########################################################################
    # Prepare the title
    ###########################################################################
    output_dict = dict(title="", content="", url=url)
    soup = bs4.BeautifulSoup(page, features="lxml")
    pre_rendered = soup.find("title")
    output_dict["title"] = (
        html.unescape(pre_rendered.renderContents().decode()) if pre_rendered else ""
    )
    output_dict["title"] = (
        output_dict["title"].replace("\n", "").replace("\r", "")
    )

    ###########################################################################
    # Prepare the content
    ###########################################################################
    text_maker = html2text.HTML2Text()
    text_maker.ignore_links = True
    text_maker.ignore_tables = True
    text_maker.ignore_images = True
    text_maker.ignore_emphasis = True
    text_maker.single_line = True
    output_dict["content"] = html.unescape(text_maker.handle(page).strip())

    return output_dict
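
# On success, _get_and_parse returns a dict shaped like the following (values
# abbreviated; example.com is used purely as an illustration):
#   {"title": "Example Domain", "content": "Example Domain ...", "url": "https://example.com"}
# On a request error it returns None, which callers must check for.
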
class SearchABC(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
        """ Handle POST requests from the client. (All requests are POST) """

        #######################################################################
        # Prepare and Parse
        #######################################################################
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)

        # Figure out the encoding
        if "charset=" in self.headers["Content-Type"]:
            charset = re.match(
                r".*charset=([\w_\-]+)\b.*", self.headers["Content-Type"]
            ).group(1)
        else:
            detector = chardet.UniversalDetector()
            detector.feed(post_data)
            detector.close()
            charset = detector.result["encoding"]

        post_data = post_data.decode(charset)
        parsed = urllib.parse.parse_qs(post_data)
        for v in parsed.values():
            assert len(v) == 1, len(v)
        parsed = {k: v[0] for k, v in parsed.items()}

        #######################################################################
        # Search, get the pages and parse the content of the pages
        #######################################################################
        print(f"\n[bold]Received query:[/] {parsed}")
        n = int(parsed["n"])
        q = parsed["q"]

        # Over-query a little bit in case we find useless URLs
        content = []
        dupe_detection_set = set()

        # Search until we have n valid entries
        for url in self.search(q=q, n=n):
            if len(content) >= n:
                break

            # Get the content of the pages and parse it
            maybe_content = _get_and_parse(url)

            # Check that getting the content didn't fail
            reason_empty_response = maybe_content is None
            if not reason_empty_response:
                reason_content_empty = (
                    maybe_content["content"] is None
                    or len(maybe_content["content"]) == 0
                )
                reason_already_seen_content = (
                    maybe_content["content"] in dupe_detection_set
                )
            else:
                reason_content_empty = False
                reason_already_seen_content = False

            reasons = dict(
                reason_empty_response=reason_empty_response,
                reason_content_empty=reason_content_empty,
                reason_already_seen_content=reason_already_seen_content,
            )

            if not any(reasons.values()):
                ###############################################################
                # Log the entry
                ###############################################################
                title_str = (
                    f"`{rich.markup.escape(maybe_content['title'])}`"
                    if maybe_content["title"]
                    else "<No Title>"
                )
                print(
                    f" {_STYLE_GOOD}>{_CLOSE_STYLE_GOOD} Result: Title: {title_str}\n"
                    f" {rich.markup.escape(maybe_content['url'])}"
                    # f"Content: {len(maybe_content['content'])}",
                )
                dupe_detection_set.add(maybe_content["content"])
                content.append(maybe_content)
                if len(content) >= n:
                    break
            else:
                ###############################################################
                # Log why it failed
                ###############################################################
                reason_string = ", ".join(
                    {
                        reason_name
                        for reason_name, whether_failed in reasons.items()
                        if whether_failed
                    }
                )
                print(
                    f" {_STYLE_SKIP}x{_CLOSE_STYLE_SKIP} Excluding a URL because "
                    f"`{_STYLE_SKIP}{reason_string}{_CLOSE_STYLE_SKIP}`:\n"
                    f" {url}"
                )

        ###############################################################
        # Prepare the answer and send it
        ###############################################################
        content = content[:n]
        output = json.dumps(dict(response=content)).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.send_header("Content-Length", len(output))
        self.end_headers()
        self.wfile.write(output)
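
    # Wire-format sketch (derived from the parsing above, not from an official
    # ParlAI spec): the POST body is a form-encoded query string and the reply
    # is JSON, e.g.
    #   request body:  q=who+wrote+hamlet&n=5
    #   response:      {"response": [{"title": ..., "content": ..., "url": ...}, ...]}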
    def search(self, q: str, n: int) -> Generator[str, None, None]:
        raise NotImplementedError(
            "SearchABC is an abstract base class, not meant to be directly "
            "instantiated. You should instantiate a derived class like "
            "GoogleSearchServer."
        )


class GoogleSearchServer(SearchABC):
    def search(self, q: str, n: int) -> Generator[str, None, None]:
        return googlesearch.search(q, num=n, stop=None, pause=_DELAY_SEARCH)
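

# Illustrative sketch (not part of the original file): SearchABC only requires
# `search` to yield candidate URLs, so a hypothetical backend that serves a
# fixed list of URLs could look like this. The class name and URL are made up.
class FixedURLSearchServer(SearchABC):
    """ Example-only subclass returning canned URLs instead of Google results. """

    _URLS = ["https://en.wikipedia.org/wiki/Main_Page"]

    def search(self, q: str, n: int) -> Generator[str, None, None]:
        # Ignores the query and yields at most n canned URLs.
        yield from self._URLS[:n]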


class Application:
    def serve(self, host: str = _DEFAULT_HOST) -> NoReturn:
        """ Main entry point: Start the server.
        Arguments:
            host (str):
                HOSTNAME:PORT of the server. HOSTNAME can be an IP.
                Most of the time it should be 0.0.0.0. Port 8080 doesn't work
                on Colab; other ports probably don't either, so test it out.
        """
        hostname, port = _parse_host(host)
        host = f"{hostname}:{port}"

        with http.server.ThreadingHTTPServer(
            (hostname, int(port)), GoogleSearchServer
        ) as server:
            print("Serving forever.")
            print(f"Host: {host}")
            server.serve_forever()
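
    # For example (illustrative; `fire` turns the `host` parameter into a flag):
    #   python search_server.py serve --host 0.0.0.0:1234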
    def test_parser(self, url: str) -> None:
        """ Test the webpage getter and parser.
        Tries to download the page, parses it, then displays the result.
        """
        print(_get_and_parse(url))
    def test_server(self, query: str, n: int, host: str = _DEFAULT_HOST) -> None:
        """ Creates a thin fake client to test a server that is already up.
        Expects a server to have already been started with
        `python search_server.py serve [options]`.
        Creates a retriever client the same way the ParlAI client does for its
        chat bot, then sends a query to the server.
        """
        host, port = _parse_host(host)

        print(f"Query: `{query}`")
        print(f"n: {n}")

        retriever = parlai.agents.rag.retrieve_api.SearchEngineRetriever(
            dict(
                search_server=f"{host}:{port}",
                skip_retrieval_token=False,
            )
        )
        print("Retrieving one.")
        print(retriever.retrieve([query], n))
        print("Done.")


if __name__ == "__main__":
    fire.Fire(Application)