-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
138 lines (119 loc) · 4.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
import os
from pathlib import Path
import requests
from pymed import PubMed
# Create the shared PubMed client that main() uses to run its searches.
# The `tool` parameter is not required, but PubMed Central kindly asks API
# clients to identify themselves:
# https://www.ncbi.nlm.nih.gov/pmc/tools/developers/
pubmed = PubMed(tool="PolyMind")
# Read the settings from the config.json file that sits next to this script.
script_dir = Path(os.path.abspath(__file__)).parent
conf_path = script_dir / "config.json"
# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(conf_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)
# Maximum number of articles requested from PubMed per search (default 5).
max_results = config.get("max_results", 5)
# Fraction of the model context length that the search results may occupy
# (default 0.3); the special value -1 disables the limit.
ctx_alloc = config.get("ctx_alloc", 0.3)
# Maps a lower-cased category name to the PubMed "Broad" filter term used for it.
_CATEGORY_FILTERS = {
    "therapy": "(Therapy/Broad[filter])",
    "diagnosis": "(Diagnosis/Broad[filter])",
    "etiology": "(Etiology/Broad[filter])",
    "prognosis": "(Prognosis/Broad[filter])",
    "clinical prediction guides": "(Clinical Prediction Guides/Broad[filter])",
}


def _build_query(category, keywords):
    """Return the PubMed search expression, or "" when there is nothing to search for."""
    terms = []
    category_filter = _CATEGORY_FILTERS.get((category or "").strip().lower())
    if category_filter:
        terms.append(category_filter)
    for chunk in (keywords or "").split(","):
        # A double quote inside a keyword would end the quoted phrase early,
        # and a blank chunk would produce a meaningless ""[tw] term.
        phrase = chunk.replace('"', "").strip()
        if phrase:
            terms.append(f'"{phrase}"[tw]')
    if not terms:
        return ""
    terms.extend(["medline[sb]", '"has abstract"[filter]'])
    return " AND ".join(terms)


def _format_article(record):
    """Render one article dict (as returned by ``toDict()``) as a plain-text block."""
    lines = []
    if record.get("title"):
        lines.append(record["title"])
    if record.get("pubmed_id"):
        # Only the first line of the ID field is used to build the link.
        pmid = record["pubmed_id"].split("\n")[0]
        lines.append(f"URL: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
    if record.get("publication_date"):
        lines.append(f"Publication Date: {record['publication_date']}")
    if record.get("journal"):
        lines.append(f"Journal: {record['journal']}")
    if record.get("doi"):
        # Only the first line of the DOI field is kept.
        doi = record["doi"].split("\n")[0]
        lines.append(f"DOI: {doi}")
    if record.get("abstract"):
        # Only the abstract is included; the separate methods/results/
        # conclusions fields are not added to the text.
        lines.append(record["abstract"])
    return "".join(f"{line}\n" for line in lines)


def main(params, memory, infer, ip, Shared_vars):
    """Search PubMed and return the matching articles as text for the model.

    Args:
        params: Mapping read with ``.get``; ``"category"`` optionally selects
            one of the category filters (therapy, diagnosis, etiology,
            prognosis, clinical prediction guides) and ``"keywords"`` is a
            comma-separated string of search phrases. Either may be missing.
        memory: Not used by this function.
        infer: Not used by this function.
        ip: Not used by this function.
        Shared_vars: Object providing ``API_ENDPOINT_URI``, ``TABBY``,
            ``API_KEY`` and ``config.ctxlen``; used to count tokens through
            the inference backend.

    Returns:
        The formatted articles wrapped in ``<search_results>`` markers, or a
        fixed instruction string when nothing was found (or nothing could be
        searched for).

    Raises:
        requests.HTTPError: If the tokenizer endpoint answers with an error
            status.
    """
    # Build the tokenizer URL straight from the base URI so that no other part
    # of the URI can be rewritten by a substring replacement.
    tokenize_url = Shared_vars.API_ENDPOINT_URI + (
        "v1/token/encode" if Shared_vars.TABBY else "tokenize"
    )

    def tokenize(text):
        """Return the token count the backend reports for ``text``."""
        payload = {
            "add_bos_token": "true",
            "encode_special_tokens": "true",
            "decode_special_tokens": "true",
            # The same string is sent under both keys - presumably because the
            # two supported backends expect different field names.
            "text": text,
            "content": text,
        }
        response = requests.post(
            tokenize_url,
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
                "Authorization": f"Bearer {Shared_vars.API_KEY}",
            },
            json=payload,
            timeout=360,
        )
        # Fail with a clear HTTP error instead of a confusing JSON/KeyError.
        response.raise_for_status()
        body = response.json()
        return body["length"] if Shared_vars.TABBY else len(body["tokens"])

    # Build the PubMed search expression as plain text
    query = _build_query(params.get("category"), params.get("keywords"))

    # Create message containing RAG content
    message = ""
    if query:
        # Execute the query against the API
        for article in pubmed.query(query, max_results=max_results):
            text = _format_article(article.toDict())
            # Add a separator unless this is the first result
            candidate = f"{message}***\n{text}" if message else text
            # Prevent RAG content from taking up too much of the context;
            # ctx_alloc == -1 disables the limit (and the tokenizer calls).
            if ctx_alloc != -1 and tokenize(candidate) >= (
                Shared_vars.config.ctxlen * ctx_alloc
            ):
                break
            message = candidate

    # Handle unsuccessful search
    if not message:
        print("No results from PubMed")
        return "No search results found on PubMed, notify the user of this and respond based on your knowledge"
    print(message)
    return "<search_results>:\n" + message + "</search_results>"
if __name__ == "__main__":
    # main() needs params, memory, infer, ip and Shared_vars from its caller.
    # None of these names exist at module level, so calling main() here could
    # only fail with a NameError; exit with an explicit message instead.
    raise SystemExit(
        "This module cannot be run directly: main() must be called with "
        "params, memory, infer, ip and Shared_vars supplied by the caller."
    )