Skip to content

Commit

Permalink
Merge pull request #98 from pyscal/allow_repeat_nodes
Browse files Browse the repository at this point in the history
Allow repeat nodes
  • Loading branch information
srmnitc authored Apr 29, 2024
2 parents f062811 + f8375e2 commit 0bc3f37
Show file tree
Hide file tree
Showing 10 changed files with 3,157 additions and 2,699 deletions.
12 changes: 3 additions & 9 deletions atomrdf/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,6 @@ def auto_query(
self,
source,
destination,
condition=None,
return_query=False,
enforce_types=None,
return_df=True,
Expand All @@ -968,8 +967,6 @@ def auto_query(
The source of the query.
destination : OntoTerm
The destination of the query.
condition :str, optional
The condition to be applied in the query. Defaults to None.
return_query : bool, optional
If True, returns the generated query instead of executing it. Defaults to False.
enforce_types : bool, optional
Expand All @@ -986,7 +983,7 @@ def auto_query(
if enforce_types is None:
for val in [True, False]:
query = self.ontology.create_query(
source, destination, condition=condition, enforce_types=val
source, destination, enforce_types=val
)
if return_query:
return query
Expand All @@ -995,7 +992,7 @@ def auto_query(
return res
else:
query = self.ontology.create_query(
source, destination, condition=condition, enforce_types=enforce_types
source, destination, enforce_types=enforce_types
)
if return_query:
return query
Expand All @@ -1007,7 +1004,7 @@ def auto_query(
# Methods to interact with sample
#################################
def query_sample(
self, destination, condition=None, return_query=False, enforce_types=None
self, destination, return_query=False, enforce_types=None
):
"""
Query the knowledge graph for atomic scale samples.
Expand All @@ -1016,8 +1013,6 @@ def query_sample(
----------
destination : OntoTerm
The destination of the query.
condition : str, optional
The condition to be applied in the query. Defaults to None.
return_query : bool, optional
If True, returns the generated query instead of executing it. Defaults to False.
enforce_types : bool, optional
Expand All @@ -1032,7 +1027,6 @@ def query_sample(
return self.auto_query(
self.ontology.terms.cmso.AtomicScaleSample,
destination,
condition=condition,
return_query=return_query,
enforce_types=enforce_types,
)
Expand Down
247 changes: 113 additions & 134 deletions atomrdf/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,6 @@ def extra_namespaces(self):
def __radd__(self, ontonetwork):
    """
    Reflected addition: forward ``ontonetwork`` to ``__add__``.

    Parameters
    ----------
    ontonetwork : object
        The left-hand operand of the ``+`` expression.

    Returns
    -------
    object
        Whatever ``self.__add__(ontonetwork)`` returns.
    """
    combined = self.__add__(ontonetwork)
    return combined

def get_shortest_path(self, source, target, triples=False):
    """
    Find the shortest path between two nodes of the graph ``self.g``.

    Parameters
    ----------
    source : node
        Node at which the path starts.
    target : node
        Node at which the path ends.
    triples : bool, optional
        If False (default), the path is returned as a flat list of nodes.
        If True, the path is cut into slices of up to three consecutive
        entries, where each slice starts two entries after the previous one
        (so neighbouring slices share one node).

    Returns
    -------
    list
        The list of nodes on the path, or the list of slices described above
        when `triples` is True.
    """
    nodes = nx.shortest_path(self.g, source=source, target=target)
    if not triples:
        return nodes
    # one slice per pair of entries; start offsets are 0, 2, 4, ...
    return [
        nodes[start : start + 3]
        for start in range(0, 2 * (len(nodes) // 2), 2)
    ]

def _add_class_nodes(self):
for key, val in self.onto.attributes["class"].items():
self.g.add_node(val.name, node_type="class")
Expand Down Expand Up @@ -304,6 +275,58 @@ def draw(self,
dot.edge(_replace_name(edge[0]), _replace_name(edge[1]))
return dot

def _get_shortest_path(self, source, target):
    """
    Shortest path in ``self.g`` between two terms, with renamed end points.

    The graph lookup uses the ``query_name`` of `source` and `target`; in the
    returned list the first and last entries are then replaced by the
    ``variable_name`` of `source` and `target` respectively.

    Parameters
    ----------
    source : term
        Start of the path; must provide ``query_name`` and ``variable_name``.
    target : term
        End of the path; must provide ``query_name`` and ``variable_name``.

    Returns
    -------
    list
        The path as a list of node names.
    """
    # this function will be modified to take OntoTerms directly as input,
    # and use their names.
    start_name = source.query_name
    end_name = target.query_name
    nodes = nx.shortest_path(self.g, source=start_name, target=end_name)
    # replace the start and end with their corresponding variable names
    nodes[0], nodes[-1] = source.variable_name, target.variable_name
    return nodes

def get_shortest_path(self, source, target, triples=False):
    """
    Compute the shortest path between two terms in the graph.

    When ``target._parents`` is non-empty the path is built in steps,
    ``source -> parent_1 -> ... -> target``, and the per-step paths are
    joined end to end. Each step is delegated to ``_get_shortest_path``.

    Parameters
    ----------
    source : term
        The starting term for the path.
    target : term
        The end term for the path. Its ``_parents`` sequence lists the
        intermediate terms the path has to pass through, in order; it may
        be empty.
    triples : bool, optional
        If False (default), the path is returned as a flat list.
        If True, the path is cut into slices of up to three consecutive
        entries, where each slice starts two entries after the previous one
        (so neighbouring slices share one entry).

    Returns
    -------
    path : list
        The flat path, or the list of slices described above when `triples`
        is True.
    """
    # source, then any intermediate parents, then the target itself;
    # with no parents this is simply [source, target]
    waypoints = [source, *target._parents, target]

    path = self._get_shortest_path(waypoints[0], waypoints[1])
    for step_source, step_target in zip(waypoints[1:], waypoints[2:]):
        step = self._get_shortest_path(step_source, step_target)
        # a step starts at the entry the previous step ended on; keep it once
        path.extend(step[1:])

    if not triples:
        return path
    return [path[2 * x : 2 * x + 3] for x in range(len(path) // 2)]

def get_path_from_sample(self, target):
"""
Get the shortest path from the cmso AtomicScaleSample term to the target node.
Expand All @@ -318,46 +341,13 @@ def get_path_from_sample(self, target):
list
A list of triples representing the shortest path from the cmso AtomicScaleSample term to the target node.
"""
#get the path
path = self.get_shortest_path(
source="cmso:ComputationalSample", target=target, triples=True
source=self.terms.cmso.AtomicScaleSample, target=target, triples=True
)
return path

def create_stepped_query(self, source, destinations):
    """
    Create a stepped query by creating triples in a stepped manner.

    The shortest path is computed between each pair of neighbouring nodes
    in ``[source, *destinations]`` and the resulting triples are
    concatenated in order.

    Parameters
    ----------
    source : str
        The source node for the query.
    destinations : list
        A list of destination nodes for the query, visited in order.

    Returns
    -------
    list
        A list of triples representing the stepped query path.

    Raises
    ------
    ValueError
        If `source` and `destinations` together hold fewer than 3 nodes.
    """
    waypoints = [source, *destinations]
    if len(waypoints) < 3:
        raise ValueError("Need at least 3 nodes to create a stepped query")

    triples = []
    # walk over neighbouring pairs: (source, d0), (d0, d1), ...
    for step_source, step_dest in zip(waypoints, waypoints[1:]):
        triples.extend(
            self.get_shortest_path(step_source, step_dest, triples=True)
        )
    return triples

def create_query(self, source, destinations, condition=None, enforce_types=True):
def create_query(self, source, destinations, enforce_types=True):
"""
Create a SPARQL query string based on the given source, destinations, and enforce_types. A filter condition, if any, is taken from the destination terms.
Expand All @@ -367,8 +357,6 @@ def create_query(self, source, destinations, condition=None, enforce_types=True)
The source node from which the query starts.
destinations : list or Node
The destination node(s) to which the query should reach. If a single node is provided, it will be converted to a list.
condition : Condition, optional
The condition to be applied in the query. Defaults to None.
enforce_types : bool, optional
Whether to enforce the types of the source and destination nodes in the query. Defaults to True.
Expand All @@ -382,43 +370,19 @@ def create_query(self, source, destinations, condition=None, enforce_types=True)
if not isinstance(destinations, list):
destinations = [destinations]

#query name is how its called in SPARQL query
source_name = source.query_name

#same way we have to get destination names
#here a trick is applied: if it is a data property, we have to add "value" to the end, which is done in the query_name property
#now if it is an object property, the query has to end in the target class.
destination_names = []
# check if more than one of them have an associated condition -> if so throw error
no_of_conditions = 0
for destination in destinations:
if len(destination._parents) > 0:
#this is a list, we need a stepped query
destination_list = []
for parent in destination._parents:
destination_list.append(parent.query_name)
destination_list.append(destination.query_name)
destination_names.append(destination_list)
destination._parents = []
else:
destination_names.append([destination.query_name])

# if condition is specified, and is not there, add it
if condition is not None:
found = False
for destination in destination_names:
if condition.query_name in destination:
found = True
break
if not found:
destination_names.append([condition.query_name])

# add source if not available
found = False
for destination in destination_names:
if source_name in destination:
found = True
break
if not found:
destination_names.append([source_name])
if destination._condition is not None:
no_of_conditions += 1
if no_of_conditions > 1:
raise ValueError("Only one condition is allowed")

#iterate through the list; if they have condition parents, add them explicitly
for destination in destinations:
for parent in destination._condition_parents:
if parent.name not in [d.name for d in destinations]:
destinations.append(parent)

#all names are now collected, in a list of lists
# start prefix of query
Expand All @@ -428,62 +392,77 @@ def create_query(self, source, destinations, condition=None, enforce_types=True)
for key, val in self.extra_namespaces.items():
query.append(f"PREFIX {key}: <{val}>")

# now for each destination, start adding the paths in the query
all_triplets = {}
for count, destination in enumerate(destination_names):
if len(destination) == 1:
triplets = self.get_shortest_path(source_name, destination[0], triples=True)
else:
triplets = self.create_stepped_query(source_name, destination)
all_triplets[str(count)] = triplets

#construct the select distinct command:
#add source `variable_name`
#iterate over destinations, add their `variable_name`
select_destinations = [
f"?{self.strip_name(destination[-1])}" for destination in destination_names
"?"+destination.variable_name for destination in destinations
]
#note that the -1 index above picks the end product for stepped queries
select_destinations = ["?"+source.variable_name] + select_destinations
query.append(f'SELECT DISTINCT {" ".join(select_destinations)}')
query.append("WHERE {")

# now add corresponding triples
for count, destination in enumerate(destination_names):
for triple in all_triplets[str(count)]:

#constructing the SPARQL query path triples, by iterating over destinations
#for each destination:
# - check if it has parents by looking at `._parents`
# - if it has `_parents`, call the stepped path method
# - else just get the path
# - replace the ends of the path with `variable_name`
# - if it doesn't exist in the collection of lines, add the lines
all_triplets = {}
for count, destination in enumerate(destinations):
#print(source, destination)
triplets = self.get_shortest_path(source, destination, triples=True)
for triple in triplets:
#print(triple)
line_text = " ?%s %s ?%s ."% ( self.strip_name(triple[0]),
line_text = " ?%s %s ?%s ."% ( triple[0].replace(":", "_"),
triple[1],
self.strip_name(triple[2]),
triple[2].replace(":", "_"),
)
if line_text not in query:
query.append(line_text)
query.append(line_text)


# we enforce types of the source and destination
if enforce_types:
if source.node_type == "class":
query.append(
" ?%s rdf:type %s ."
% (self.strip_name(source.query_name), source.query_name)
% (self.strip_name(source.variable_name), source.query_name)
)

for destination in destinations:
node_type = np.atleast_1d(destination)[-1].node_type
query_name = np.atleast_1d(destination)[-1].query_name

if node_type == "class":
if destination.node_type == "class":
query.append(
" ?%s rdf:type %s ."
% (
self.strip_name(query_name),
query_name,
destination.variable_name,
destination.query_name,
)
)
# now we have to add filters
# filters are only needed if it is a dataproperty
#- formulate the condition, given by the `FILTER` command:
# - extract the filter text from the term
# - loop over destinations:
# - call `replace(destination.query_name, destination.variable_name)`
filter_text = ""

# make filters; get all the unique filters from all the classes in destinations
if condition is not None:
if condition._condition is not None:
filter_text = condition._condition

for destination in destinations:
if destination._condition is not None:
filter_text = destination._condition
break

#replace the query_name with variable_name
if filter_text != "":
for destination in destinations:
filter_text = filter_text.replace(
destination.query_name, destination.variable_name
)
query.append(f"FILTER {filter_text}")
query.append("}")

#finished, clean up the terms;
for destination in destinations:
destination.refresh()

return "\n".join(query)
3 changes: 3 additions & 0 deletions atomrdf/network/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ def patch_terms(iri, rn):
elif iri == "http://purls.helmholtz-metadaten.de/cmso/hasReference":
rn = ["str"]

elif iri == 'http://purls.helmholtz-metadaten.de/cmso/hasSpaceGroupSymbol':
rn = ["str"]

return rn
Loading

0 comments on commit 0bc3f37

Please sign in to comment.