Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DB-Connection/Multi-Query Feature #45

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 49 additions & 25 deletions Carrot-Assistant/omop/OMOP_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,27 @@ class OMOPMatcher:
"""
This class retrieves matches from an OMOP database and returns the best
"""
_instance = None  # Cached singleton; populated on first get_instance() call.

@classmethod
def get_instance(cls, logger: Optional[Logger] = None):
    """
    Return the singleton instance of the OMOPMatcher class, creating
    it on first use.

    Parameters
    ----------
    logger: Optional[Logger]
        A logger for logging runs of the tool. Only consulted when the
        instance is first created; ignored on every later call.

    Returns
    -------
    OMOPMatcher
        The singleton instance of the OMOPMatcher class.
    """
    # NOTE(review): not thread-safe — two concurrent first calls could
    # each construct an instance; confirm single-threaded startup.
    if cls._instance is None:
        cls._instance = cls(logger)
    return cls._instance

def __init__(self, logger: Optional[Logger] = None):
# Connect to database
Expand All @@ -27,34 +48,37 @@ def __init__(self, logger: Optional[Logger] = None):
self.logger = logger
load_dotenv()

try:
self.logger.info(
"Initialize the PostgreSQL connection based on the environment variables"
)
DB_HOST = environ["DB_HOST"]
DB_USER = environ["DB_USER"]
DB_PASSWORD = quote_plus(environ["DB_PASSWORD"])
DB_NAME = environ["DB_NAME"]
DB_PORT = environ["DB_PORT"]
DB_SCHEMA = environ["DB_SCHEMA"]

connection_string = (
f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = create_engine(connection_string)
logger.info(f"Connected to PostgreSQL database {DB_NAME} on {DB_HOST}")
if not hasattr(self, 'engine'):

try:
self.logger.info(
"Initialize the PostgreSQL connection based on the environment variables"
)
DB_HOST = environ["DB_HOST"]
DB_USER = environ["DB_USER"]
DB_PASSWORD = quote_plus(environ["DB_PASSWORD"])
DB_NAME = environ["DB_NAME"]
DB_PORT = environ["DB_PORT"]
DB_SCHEMA = environ["DB_SCHEMA"]

connection_string = (
f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = create_engine(connection_string)
logger.info(f"Connected to PostgreSQL database {DB_NAME} on {DB_HOST}")

except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
raise ValueError(f"Failed to connect to PostgreSQL: {e}")
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
raise ValueError(f"Failed to connect to PostgreSQL: {e}")

self.engine = engine
self.schema = DB_SCHEMA
self.engine = engine
self.schema = DB_SCHEMA

def close(self):
    """
    Dispose of the SQLAlchemy engine's connection pool, if one exists.

    The ``hasattr`` guard makes this safe to call even when ``__init__``
    failed before ``self.engine`` was assigned (e.g. missing environment
    variables), in which case the call is a no-op.
    """
    if hasattr(self, 'engine'):
        self.engine.dispose()
        self.logger.info("PostgreSQL connection closed.")

def calculate_best_matches(
self,
Expand Down Expand Up @@ -193,7 +217,7 @@ def fetch_OMOP_concepts(
session = Session()
results = session.execute(query).fetchall()
results = pd.DataFrame(results)
session.close()

if not results.empty:
# Define a function to calculate similarity score using the provided logic
def calculate_similarity(row):
Expand Down Expand Up @@ -481,5 +505,5 @@ def run(opt: argparse.Namespace, search_term:str, logger: Logger):
max_separation_descendant,
max_separation_ancestor,
)
omop_matcher.close()

return res
203 changes: 122 additions & 81 deletions Carrot-Assistant/routers/pipeline_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter
from fastapi import APIRouter, Request
import asyncio
from collections.abc import AsyncGenerator
import json
Expand All @@ -9,6 +9,7 @@

import assistant
from omop import OMOP_match
from omop.OMOP_match import OMOPMatcher
from options.base_options import BaseOptions
from components.embeddings import Embeddings
from options.pipeline_options import PipelineOptions, parse_pipeline_args
Expand All @@ -18,6 +19,7 @@

logger = Logger().make_logger()


class PipelineRequest(BaseModel):
"""
This class takes the format of a request to the API
Expand All @@ -34,111 +36,151 @@ class PipelineRequest(BaseModel):
pipeline_options: PipelineOptions = Field(default_factory=PipelineOptions)




async def generate_events(
    request: PipelineRequest, use_llm: bool, end_session: bool
) -> AsyncGenerator[str]:
    """
    Generate LLM output and OMOP results for a list of informal names.

    Parameters
    ----------
    request: PipelineRequest
        The request containing the list of informal names.

    use_llm: bool
        A flag to determine whether to use the LLM to find the formal name.

    end_session: bool
        A flag to determine whether to end the session and close the
        database connection without processing any names.

    Yields
    ------
    str
        JSON encoded strings of the event results ("omop_output",
        "llm_output", or "session_ended").
    """
    informal_names = request.names
    opt = BaseOptions()
    opt.initialize()
    parse_pipeline_args(opt, request.pipeline_options)
    opt = opt.parse()

    print("Received informal names:", informal_names)
    print(f"use_llm flag is set to: {use_llm}")
    print(f"end_session flag is set to: {end_session}")

    # If the user chooses to end the session, close the database
    # connection and stop before any matching work is done.
    if end_session:
        print("Final API call. Closing the database connection....")
        output = {"event": "session_ended", "message": "Session has ended."}
        yield json.dumps(output)
        OMOPMatcher.get_instance().close()
        return

    # Names for which the OMOP database produced no match (candidates
    # for the LLM pass).
    no_match_names = []

    try:
        if informal_names:

            # First pass: query OMOP directly for each informal name.
            if not use_llm:
                for informal_name in informal_names:
                    print(f"Querying OMOP for informal name: {informal_name}")
                    omop_output = OMOP_match.run(
                        opt=opt, search_term=informal_name, logger=logger
                    )

                    # If a match is found, yield the OMOP output.
                    if omop_output and any(
                        concept["CONCEPT"] for concept in omop_output
                    ):
                        print(f"OMOP match found for {informal_name}: {omop_output}")
                        output = {"event": "omop_output", "data": omop_output}
                        yield json.dumps(output)

                    # Otherwise report the miss and remember the name for
                    # a potential LLM pass.
                    else:
                        print(f"No satisfactory OMOP results found for {informal_name}")
                        output = {
                            "event": "omop_output",
                            "data": omop_output,
                            "message": f"No match found in OMOP database for {informal_name}.",
                        }
                        yield json.dumps(output)
                        no_match_names.append(informal_name)
                        print(f"\nno_match_names: {no_match_names}\n")
            else:
                # Caller asked for the LLM directly: treat every name as
                # unmatched.
                no_match_names = informal_names

            # Second pass: use the LLM to find formal names for the misses.
            if no_match_names and use_llm:
                llm_outputs = assistant.run(
                    opt=opt, informal_names=no_match_names, logger=logger
                )

                for llm_output in llm_outputs:
                    print(
                        "LLM output for",
                        llm_output["informal_name"],
                        ":",
                        llm_output["reply"],
                    )

                    output = {"event": "llm_output", "data": llm_output}
                    yield json.dumps(output)

    finally:
        # NOTE(review): the connection is closed when every name matched
        # (no_match_names empty) and kept open otherwise — confirm this
        # is the intended session lifecycle rather than the inverse.
        if not no_match_names:
            print(
                "no matches found. Closing the database connection..."
            )
            OMOPMatcher.get_instance().close()

        else:
            print("\nDatabase connection remains open.")


@router.post("/")
async def run_pipeline(request: Request) -> EventSourceResponse:
    """
    Run the name-matching pipeline for a list of informal names and
    stream the results back as server-sent events.

    Parameters
    ----------
    request: Request
        Raw request whose JSON body carries the PipelineRequest fields
        plus two optional flags:
        - use_llm: whether to send unmatched names to the LLM
          (defaults to False)
        - end_session: whether to end the session and close the
          database connection (defaults to False)

    Returns
    -------
    EventSourceResponse
        The streamed events produced by generate_events.
    """
    body = await request.json()
    pipeline_request = PipelineRequest(**body)

    # The flags are read from the raw body because they are not fields
    # of the PipelineRequest model.
    use_llm = body.get("use_llm", False)
    end_session = body.get("end_session", False)

    print(
        f"Running pipeline with use_llm: {use_llm} and end_session: {end_session}"
    )
    return EventSourceResponse(
        generate_events(pipeline_request, use_llm, end_session)
    )


@router.post("/db")
async def run_db(request: PipelineRequest) -> List[Dict[str,Any]]:
async def run_db(request: PipelineRequest) -> List[Dict[str, Any]]:
"""
Fetch OMOP concepts for a name

Expand Down Expand Up @@ -166,7 +208,8 @@ async def run_db(request: PipelineRequest) -> List[Dict[str,Any]]:
omop_outputs.append({"event": "omop_output", "content": omop_output})

return omop_outputs



@router.post("/vector_search")
async def run_vector_search(request: PipelineRequest):
    """
    Run a vector similarity search for each of the supplied names.

    Parameters
    ----------
    request: PipelineRequest
        The request containing the names to search for, with the
        embedding configuration taken from pipeline_options.

    Returns
    -------
    dict
        {"event": "vector_search_output", "content": <search results>}
    """
    search_terms = request.names
    embeddings = Embeddings(
        embeddings_path=request.pipeline_options.embeddings_path,
        force_rebuild=request.pipeline_options.force_rebuild,
        embed_vocab=request.pipeline_options.embed_vocab,
        model_name=request.pipeline_options.embedding_model,
        search_kwargs=request.pipeline_options.embedding_search_kwargs,
    )
    return {"event": "vector_search_output", "content": embeddings.search(search_terms)}
Loading