Skip to content

Commit

Permalink
Merge pull request #2099 from RTXteam/issue2022
Browse files Browse the repository at this point in the history
Returning Lookup results also when request with knowledge_type=Inferred is called
  • Loading branch information
kvnthomas98 authored Aug 15, 2023
2 parents af5482d + 164bfed commit 88e0261
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
26 changes: 20 additions & 6 deletions code/ARAX/ARAXQuery/ARAX_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
infer_input_parameters = {"action": "drug_treatment_graph_expansion",'node_curie': object_curie, 'qedge_id': inferred_qedge_key}
inferer = ARAXInfer()
infer_response = inferer.apply(response, infer_input_parameters)
return infer_response
# return infer_response
response = infer_response
overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph)

elif set(['biolink:regulates']).intersection(set(qedge.predicates)): # Figure out if this is a "regulates" query, then use call XCRG models
# Call XCRG models and simply return whatever it returns
# Get the subject and object of this edge
Expand Down Expand Up @@ -425,7 +428,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
infer_input_parameters = {"action": "chemical_gene_regulation_graph_expansion", 'object_qnode_id' : qedge.object, 'object_curie': object_curie, 'qedge_id': inferred_qedge_key, 'regulation_type': regulation_type}
inferer = ARAXInfer()
infer_response = inferer.apply(response, infer_input_parameters)
return infer_response
response = infer_response
overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph)
else:
log.info(f"Qedge {inferred_qedge_key} has knowledge_type == inferred, but the query is not "
f"DTD-related (e.g., 'biolink:ameliorates', 'biolink:treats') or CRG-related ('biolink:regulates') according to the specified predicate. Will answer using the normal 'fill' strategy (not creative mode).")
Expand All @@ -434,10 +438,12 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
f"the qedges has knowledge_type == inferred. Will answer using the normal 'fill' strategy "
f"(not creative mode).")


# Expand any specified edges
if qedge_keys_to_expand:
query_sub_graph = self._extract_query_subgraph(qedge_keys_to_expand, query_graph, log)
if inferred_qedge_keys and len(query_graph.edges) == 1:
for edge in query_sub_graph.edges.keys():
query_sub_graph.edges[edge].knowledge_type = 'lookup'
if log.status != 'OK':
return response
log.debug(f"Query graph for this Expand() call is: {query_sub_graph.to_dict()}")
Expand Down Expand Up @@ -473,7 +479,9 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):

# Create a query graph for this edge (that uses curies found in prior steps)
one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log)

if inferred_qedge_keys and len(query_graph.edges) == 1:
for edge in one_hop_qg.edges.keys():
one_hop_qg.edges[edge].knowledge_type = 'lookup'
# Figure out the prune threshold (use what user provided or otherwise do something intelligent)
if parameters.get("prune_threshold"):
pre_prune_threshold = parameters["prune_threshold"]
Expand All @@ -486,7 +494,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
for qnode_key in fulfilled_qnode_keys:
num_kg_nodes = len(overarching_kg.nodes_by_qg_id[qnode_key])
if num_kg_nodes > pre_prune_threshold:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log)
if inferred_qedge_keys and len(inferred_qedge_keys) == 1:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, message.query_graph, log)
else:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log)
# Re-formulate the QG for this edge now that the KG has been slimmed down
one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log)
if log.status != 'OK':
Expand Down Expand Up @@ -650,7 +661,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
self._apply_any_kryptonite_edges(overarching_kg, message.query_graph,
message.encountered_kryptonite_edges_info, response)
# Remove any paths that are now dead-ends
overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response)
if inferred_qedge_keys and len(inferred_qedge_keys) == 1:
overarching_kg = self._remove_dead_end_paths(message.query_graph, overarching_kg, response)
else:
overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response)
if response.status != 'OK':
return response

Expand Down
21 changes: 21 additions & 0 deletions code/ARAX/ARAXQuery/ARAX_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,29 @@ def aggregate_scores_dmk(self, response):
#print(float(len(ranks_list)))
result_scores = sum(ranks_list)/float(len(ranks_list))
#print(result_scores)

# Replace Inferred Results Score with Probability score calculated by xDTD model
inferred_qedge_keys = [qedge_key for qedge_key, qedge in message.query_graph.edges.items()
if qedge.knowledge_type == "inferred"]
for result, score in zip(results, result_scores):
result.analyses[0].score = score # For now we only ever have one Analysis per Result
if inferred_qedge_keys:
inferred_qedge_key = inferred_qedge_keys[0]
edge_bindings = result.analyses[0].edge_bindings
inferred_edge_bindings = []
if edge_bindings:
inferred_edge_bindings = edge_bindings.get(inferred_qedge_key,[])
for edge_name in inferred_edge_bindings:
edge_id = edge_name.id
edge_attributes = message.knowledge_graph.edges[edge_id].attributes
if edge_attributes is not None:
for edge_attribute in edge_attributes:
if edge_attribute.original_attribute_name == 'probability_treats' and edge_attribute.value is not None:
result.analyses[0].score = float(edge_attribute.value)





# for result in message.results:
# self.result_confidence_maker(result)
Expand Down
18 changes: 12 additions & 6 deletions code/ARAX/test/test_ARAX_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,12 +1114,18 @@ def test_xdtd_expand():
nodes_by_qg_id, edges_by_qg_id, message = _run_query_and_do_standard_testing(json_query=query, return_message=True)
assert message.auxiliary_graphs
for edge in edges_by_qg_id["t_edge"].values():
assert edge.attributes
support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"]
assert support_graph_attributes
assert len(support_graph_attributes) == 1
support_graph_attribute = support_graph_attributes[0]
assert support_graph_attribute.value[0] in message.auxiliary_graphs
inferred_edge = False
for source in edge.sources:
if source.resource_role == "primary_knowledge_source" and source.resource_id == "infores:arax":
inferred_edge = True
# Perform Tests only for inferred edges
if inferred_edge:
assert edge.attributes
support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"]
assert support_graph_attributes
assert len(support_graph_attributes) == 1
support_graph_attribute = support_graph_attributes[0]
assert support_graph_attribute.value[0] in message.auxiliary_graphs


@pytest.mark.slow
Expand Down

0 comments on commit 88e0261

Please sign in to comment.