From e7da7d31c711df13f3edead143559912819ea79c Mon Sep 17 00:00:00 2001 From: Kevin Vizhalil Date: Tue, 15 Aug 2023 16:39:45 -0400 Subject: [PATCH 1/2] Merging Inferred and lookup results #2022 --- code/ARAX/ARAXQuery/ARAX_expander.py | 26 ++++++++++++++++++++------ code/ARAX/test/test_ARAX_expand.py | 18 ++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/code/ARAX/ARAXQuery/ARAX_expander.py b/code/ARAX/ARAXQuery/ARAX_expander.py index 2e1671063..86ee56b23 100644 --- a/code/ARAX/ARAXQuery/ARAX_expander.py +++ b/code/ARAX/ARAXQuery/ARAX_expander.py @@ -388,7 +388,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): infer_input_parameters = {"action": "drug_treatment_graph_expansion",'node_curie': object_curie, 'qedge_id': inferred_qedge_key} inferer = ARAXInfer() infer_response = inferer.apply(response, infer_input_parameters) - return infer_response + # return infer_response + response = infer_response + overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph) + elif set(['biolink:regulates']).intersection(set(qedge.predicates)): # Figure out if this is a "regulates" query, then use call XCRG models # Call XCRG models and simply return whatever it returns # Get the subject and object of this edge @@ -425,7 +428,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): infer_input_parameters = {"action": "chemical_gene_regulation_graph_expansion", 'object_qnode_id' : qedge.object, 'object_curie': object_curie, 'qedge_id': inferred_qedge_key, 'regulation_type': regulation_type} inferer = ARAXInfer() infer_response = inferer.apply(response, infer_input_parameters) - return infer_response + response = infer_response + overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(message.knowledge_graph) else: log.info(f"Qedge {inferred_qedge_key} has knowledge_type == inferred, but the query is not " f"DTD-related (e.g., 'biolink:ameliorates', 'biolink:treats') or CRG-related ('biolink:regulates') according to the specified predicate. Will answer using the normal 'fill' strategy (not creative mode).") @@ -434,10 +438,12 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): f"the qedges has knowledge_type == inferred. Will answer using the normal 'fill' strategy " f"(not creative mode).") - # Expand any specified edges if qedge_keys_to_expand: query_sub_graph = self._extract_query_subgraph(qedge_keys_to_expand, query_graph, log) + if inferred_qedge_keys and len(query_graph.edges) == 1: + for edge in query_sub_graph.edges.keys(): + query_sub_graph.edges[edge].knowledge_type = 'lookup' if log.status != 'OK': return response log.debug(f"Query graph for this Expand() call is: {query_sub_graph.to_dict()}") @@ -473,7 +479,9 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): # Create a query graph for this edge (that uses curies found in prior steps) one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log) - + if inferred_qedge_keys and len(query_graph.edges) == 1: + for edge in one_hop_qg.edges.keys(): + one_hop_qg.edges[edge].knowledge_type = 'lookup' # Figure out the prune threshold (use what user provided or otherwise do something intelligent) if parameters.get("prune_threshold"): pre_prune_threshold = parameters["prune_threshold"] @@ -486,7 +494,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): for qnode_key in fulfilled_qnode_keys: num_kg_nodes = len(overarching_kg.nodes_by_qg_id[qnode_key]) if num_kg_nodes > pre_prune_threshold: - overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log) + if inferred_qedge_keys and len(inferred_qedge_keys) == 1: + overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, message.query_graph, log) + else: + overarching_kg = self._prune_kg(qnode_key, pre_prune_threshold, overarching_kg, query_graph, log) # Re-formulate the QG for this edge now that the KG has been slimmed down one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log) if log.status != 'OK': @@ -650,7 +661,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): self._apply_any_kryptonite_edges(overarching_kg, message.query_graph, message.encountered_kryptonite_edges_info, response) # Remove any paths that are now dead-ends - overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response) + if inferred_qedge_keys and len(inferred_qedge_keys) == 1: + overarching_kg = self._remove_dead_end_paths(message.query_graph, overarching_kg, response) + else: + overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response) if response.status != 'OK': return response diff --git a/code/ARAX/test/test_ARAX_expand.py b/code/ARAX/test/test_ARAX_expand.py index d7864bbe3..a3cbc835c 100644 --- a/code/ARAX/test/test_ARAX_expand.py +++ b/code/ARAX/test/test_ARAX_expand.py @@ -1114,12 +1114,18 @@ def test_xdtd_expand(): nodes_by_qg_id, edges_by_qg_id, message = _run_query_and_do_standard_testing(json_query=query, return_message=True) assert message.auxiliary_graphs for edge in edges_by_qg_id["t_edge"].values(): - assert edge.attributes - support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"] - assert support_graph_attributes - assert len(support_graph_attributes) == 1 - support_graph_attribute = support_graph_attributes[0] - assert support_graph_attribute.value[0] in message.auxiliary_graphs + inferred_edge = False + for source in edge.sources: + if source.resource_role == "primary_knowledge_source" and source.resource_id == "infores:arax": + inferred_edge = True + # Perform Tests only for inferred edges + if inferred_edge: + assert edge.attributes + support_graph_attributes = [attribute for attribute in edge.attributes if attribute.attribute_type_id == "biolink:support_graphs"] + assert support_graph_attributes + assert len(support_graph_attributes) == 1 + support_graph_attribute = support_graph_attributes[0] + assert support_graph_attribute.value[0] in message.auxiliary_graphs @pytest.mark.slow From 164bfedad006faa237cfa9dec970db351894b1a9 Mon Sep 17 00:00:00 2001 From: Kevin Vizhalil Date: Tue, 15 Aug 2023 18:38:12 -0400 Subject: [PATCH 2/2] Inferred Results scores are the probabilities computed by xDTD model --- code/ARAX/ARAXQuery/ARAX_ranker.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/code/ARAX/ARAXQuery/ARAX_ranker.py b/code/ARAX/ARAXQuery/ARAX_ranker.py index 324e4d0a8..8db91d0a4 100644 --- a/code/ARAX/ARAXQuery/ARAX_ranker.py +++ b/code/ARAX/ARAXQuery/ARAX_ranker.py @@ -625,8 +625,29 @@ def aggregate_scores_dmk(self, response): #print(float(len(ranks_list))) result_scores = sum(ranks_list)/float(len(ranks_list)) #print(result_scores) + + # Replace Inferred Results Score with Probability score calculated by xDTD model + inferred_qedge_keys = [qedge_key for qedge_key, qedge in message.query_graph.edges.items() + if qedge.knowledge_type == "inferred"] for result, score in zip(results, result_scores): result.analyses[0].score = score # For now we only ever have one Analysis per Result + if inferred_qedge_keys: + inferred_qedge_key = inferred_qedge_keys[0] + edge_bindings = result.analyses[0].edge_bindings + inferred_edge_bindings = [] + if edge_bindings: + inferred_edge_bindings = edge_bindings.get(inferred_qedge_key,[]) + for edge_name in inferred_edge_bindings: + edge_id = edge_name.id + edge_attributes = message.knowledge_graph.edges[edge_id].attributes + if edge_attributes is not None: + for edge_attribute in edge_attributes: + if edge_attribute.original_attribute_name == 'probability_treats' and edge_attribute.value is not None: + result.analyses[0].score = float(edge_attribute.value) + + + + # for result in message.results: # self.result_confidence_maker(result)