Skip to content

Commit

Permalink
Refactor of old version of ES classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
grzegorzZ1 committed Nov 9, 2023
1 parent 74fc744 commit 38a7d6b
Showing 1 changed file with 79 additions and 43 deletions.
122 changes: 79 additions & 43 deletions pipeline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import requests
import time
import click
import os
import json
from rdflib import Graph, URIRef, BNode, Namespace

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk


def extract_title_from_graph(graph: Graph):
Expand All @@ -19,7 +18,7 @@ def extract_abstract_from_graph(graph: Graph):
description_uri = URIRef("http://purl.org/spar/datacite/hasDescription")
abstract_node = list(graph.triples((None, description_uri, None)))[0][2]
abstract = list(graph.triples((BNode(abstract_node), None, None)))[1][2]
return abstract
return str(abstract)


def extract_embedding_from_graph(graph: Graph):
Expand All @@ -29,6 +28,37 @@ def extract_embedding_from_graph(graph: Graph):
return embedding


def store_records_bulk(es_object, index, data):
    """Index every row of *data* into *index* with a single bulk request.

    Parameters
    ----------
    es_object : Elasticsearch
        Connected client used to execute the bulk call.
    index : str
        Name of the target Elasticsearch index.
    data : iterable of dict
        Documents to index; each row becomes one index action.
    """
    # Build a shallow copy of each row so the caller's dicts are not
    # mutated, and name the list ``actions`` instead of ``requests``,
    # which shadowed the module-level ``import requests``.
    actions = [
        {**row, "_op_type": "index", "_index": index}
        for row in data
    ]
    bulk(es_object, actions)


def find_n_best(result, n, label_colname):
    """Return up to *n* top-scoring hits from an Elasticsearch response.

    Parameters
    ----------
    result : dict
        Raw search response containing ``result["hits"]["hits"]``.
    n : int
        Maximum number of hits to return.
    label_colname : str
        Name of the ``_source`` field to copy into each returned entry.

    Returns
    -------
    list of dict
        One dict per hit: ``{label_colname: <label>, "score": <_score>}``.
    """
    # Slice defensively: the response may hold fewer than n hits, in which
    # case the original indexed loop raised IndexError.
    return [
        {
            label_colname: hit["_source"][label_colname],
            "score": hit["_score"],
        }
        for hit in result["hits"]["hits"][:n]
    ]


def connect_elasticsearch(es_config):
    """Create an Elasticsearch client and block until the cluster responds.

    Parameters
    ----------
    es_config : dict
        Host configuration, e.g.
        ``{"host": "localhost", "port": 9200, "scheme": "http"}``.

    Returns
    -------
    Elasticsearch
        A client whose ``ping()`` has succeeded.
    """
    # The original assigned ``_es = None`` and immediately overwrote it;
    # the dead assignment is removed.
    es = Elasticsearch([es_config])
    # Retry indefinitely: the pipeline cannot proceed without the cluster.
    while not es.ping():
        print("Could not connect to Elastic Search, retrying in 3s...")
        time.sleep(3)
    return es


def get_query(title, abstract, embedding):
query = {
"query": {
Expand Down Expand Up @@ -66,51 +96,57 @@ def main():
mappings = {
"mappings": {
"properties": {
"prefLabel": {"type": "keyword"},
"broader": {"type": "keyword"},
"related": {"type": "keyword"},
"prefLabel": {"type": "text"},
"broader": {"type": "text"},
"related": {"type": "text"},
"embedding": {"type": "float"},
}
}
}
with Elasticsearch(
[{"host": "localhost", "port": 9200, "scheme": "http"}]
with connect_elasticsearch(
{"host": "localhost", "port": 9200, "scheme": "http"}
) as es:
es.indices.create(index=index_name, body=mappings)

for file in os.listdir("/home/concepts.json"):
print(f"Indexing {file} file...")
concepts_batch = json.load(file)
for key, value in concepts_batch.items():
concept = {
"prefLabel": value["prefLabel"],
"related": value["related"],
"broader": value["broader"],
"embedding": value["embedding"],
}
es.index(index=index_name, id=key, body=concept)

archives = ["csis", "scpe"]
input_path = f"{os.getcwd()}/output"
for archive in archives:
root_dir = os.path.join(input_path, archive)
for dir in os.listdir(root_dir):
dir_path = os.path.join(root_dir, dir)
for ttl_file in os.listdir(dir_path):
print(f"Finding best matching concepts for file {ttl_file}...")
graph = Graph()
graph.parse(ttl_file)
file_title = extract_title_from_graph(graph)
file_abstract = extract_abstract_from_graph(graph)
file_abstract_embedding = extract_abstract_from_graph(graph)
query = get_query(
file_title, file_abstract, file_abstract_embedding
)
results = es.search(index=index_name, body=query)
for hit in results["hits"]["hits"]:
print(hit["_score"])
print(hit["_source"])
break
print(f"Creating index {index_name}...")
try:
es.indices.create(index=index_name, body=mappings)
print("Index created")
concept_json_dir = "/home/concepts_json"
concepts = []
for file_path in os.listdir(concept_json_dir):
print(f"Indexing {file_path} file...")
with open(os.path.join(concept_json_dir, file_path), "r") as file:
concepts_batch = json.load(file)
for key, value in concepts_batch.items():
prefLabel = value.get("prefLabel", None)
broader = value.get("broader", None)
related = value.get("related", None)
concept = {
"prefLabel": prefLabel if prefLabel else [],
"broader": broader if broader else [],
"related": related if related else [],
}
concepts.append(concept)
store_records_bulk(es, index_name, concepts)
except:
print(f"Index {index_name} already exists!")
archives = ["csis", "scpe"]
input_path = f"{os.getcwd()}/output"
for archive in archives:
root_dir = os.path.join(input_path, archive)
for dir in os.listdir(root_dir):
dir_path = os.path.join(root_dir, dir)
for ttl_file in os.listdir(dir_path):
print(f"Finding best matching concepts for file {ttl_file}...")
graph = Graph()
graph.parse(os.path.join(dir_path, ttl_file))
file_title = extract_title_from_graph(graph)
file_abstract = extract_abstract_from_graph(graph)
file_abstract_embedding = extract_abstract_from_graph(graph)
query = get_query(
file_title, file_abstract, file_abstract_embedding
)
results = es.search(index=index_name, body=query)
print(find_n_best(results, 3, "prefLabel"))

except:
raise
Expand Down

0 comments on commit 38a7d6b

Please sign in to comment.