Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Returning Lookup results also when Inferred is called #2099

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions code/ARAX/ARAXQuery/ARAX_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
infer_input_parameters = {"action": "drug_treatment_graph_expansion",'node_curie': object_curie, 'qedge_id': inferred_qedge_key}
inferer = ARAXInfer()
infer_response = inferer.apply(response, infer_input_parameters)
return infer_response
# return infer_response
response = infer_response
overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph)

elif set(['biolink:regulates']).intersection(set(qedge.predicates)): # Figure out if this is a "regulates" query, then use call XCRG models
# Call XCRG models and simply return whatever it returns
# Get the subject and object of this edge
Expand Down Expand Up @@ -425,7 +428,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
infer_input_parameters = {"action": "chemical_gene_regulation_graph_expansion", 'object_qnode_id' : qedge.object, 'object_curie': object_curie, 'qedge_id': inferred_qedge_key, 'regulation_type': regulation_type}
inferer = ARAXInfer()
infer_response = inferer.apply(response, infer_input_parameters)
return infer_response
response = infer_response
overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph)
else:
log.info(f"Qedge {inferred_qedge_key} has knowledge_type == inferred, but the query is not "
f"DTD-related (e.g., 'biolink:ameliorates', 'biolink:treats') or CRG-related ('biolink:regulates') according to the specified predicate. Will answer using the normal 'fill' strategy (not creative mode).")
Expand All @@ -434,10 +438,12 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
f"the qedges has knowledge_type == inferred. Will answer using the normal 'fill' strategy "
f"(not creative mode).")


# Expand any specified edges
if qedge_keys_to_expand:
query_sub_graph = self._extract_query_subgraph(qedge_keys_to_expand, query_graph, log)
if inferred_qedge_keys and len(query_graph.edges) == 1:
for edge in query_sub_graph.edges.keys():
query_sub_graph.edges[edge].knowledge_type = 'lookup'
if log.status != 'OK':
return response
log.debug(f"Query graph for this Expand() call is: {query_sub_graph.to_dict()}")
Expand Down Expand Up @@ -473,7 +479,9 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):

# Create a query graph for this edge (that uses curies found in prior steps)
one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log)

if inferred_qedge_keys and len(query_graph.edges) == 1:
for edge in one_hop_qg.edges.keys():
one_hop_qg.edges[edge].knowledge_type = 'lookup'
# Figure out the prune threshold (use what user provided or otherwise do something intelligent)
if parameters.get("prune_threshold"):
pre_prune_threshold = parameters["prune_threshold"]
Expand All @@ -486,7 +494,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
for qnode_key in fulfilled_qnode_keys:
num_kg_nodes = len(overarching_kg.nodes_by_qg_id[qnode_key])
if num_kg_nodes > pre_prune_threshold:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log)
if inferred_qedge_keys and len(inferred_qedge_keys) == 1:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, message.query_graph, log)
else:
overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log)
# Re-formulate the QG for this edge now that the KG has been slimmed down
one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log)
if log.status != 'OK':
Expand Down Expand Up @@ -650,7 +661,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
self._apply_any_kryptonite_edges(overarching_kg, message.query_graph,
message.encountered_kryptonite_edges_info, response)
# Remove any paths that are now dead-ends
overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response)
if inferred_qedge_keys and len(inferred_qedge_keys) == 1:
overarching_kg = self._remove_dead_end_paths(message.query_graph, overarching_kg, response)
else:
overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response)
if response.status != 'OK':
return response

Expand Down
21 changes: 21 additions & 0 deletions code/ARAX/ARAXQuery/ARAX_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,29 @@ def aggregate_scores_dmk(self, response):
#print(float(len(ranks_list)))
result_scores = sum(ranks_list)/float(len(ranks_list))
#print(result_scores)

# Replace Inferred Results Score with Probability score calculated by xDTD model
inferred_qedge_keys = [qedge_key for qedge_key, qedge in message.query_graph.edges.items()
if qedge.knowledge_type == "inferred"]
for result, score in zip(results, result_scores):
result.analyses[0].score = score # For now we only ever have one Analysis per Result
if inferred_qedge_keys:
inferred_qedge_key = inferred_qedge_keys[0]
edge_bindings = result.analyses[0].edge_bindings
inferred_edge_bindings = []
if edge_bindings:
inferred_edge_bindings = edge_bindings.get(inferred_qedge_key,[])
for edge_name in inferred_edge_bindings:
edge_id = edge_name.id
edge_attributes = message.knowledge_graph.edges[edge_id].attributes
if edge_attributes is not None:
for edge_attribute in edge_attributes:
if edge_attribute.original_attribute_name == 'probability_treats' and edge_attribute.value is not None:
result.analyses[0].score = float(edge_attribute.value)





# for result in message.results:
# self.result_confidence_maker(result)
Expand Down
18 changes: 12 additions & 6 deletions code/ARAX/test/test_ARAX_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,12 +1114,18 @@ def test_xdtd_expand():
nodes_by_qg_id, edges_by_qg_id, message = _run_query_and_do_standard_testing(json_query=query, return_message=True)
assert message.auxiliary_graphs
for edge in edges_by_qg_id["t_edge"].values():
assert edge.attributes
support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"]
assert support_graph_attributes
assert len(support_graph_attributes) == 1
support_graph_attribute = support_graph_attributes[0]
assert support_graph_attribute.value[0] in message.auxiliary_graphs
inferred_edge = False
for source in edge.sources:
if source.resource_role == "primary_knowledge_source" and source.resource_id == "infores:arax":
inferred_edge = True
# Perform Tests only for inferred edges
if inferred_edge:
assert edge.attributes
support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"]
assert support_graph_attributes
assert len(support_graph_attributes) == 1
support_graph_attribute = support_graph_attributes[0]
assert support_graph_attribute.value[0] in message.auxiliary_graphs


@pytest.mark.slow
Expand Down
Loading