Merge pull request #141 from bgyori/unicode-fix
Unicode fix
bgyori authored Sep 5, 2023
2 parents c669d2f + 215f15f commit 84e6a41
Showing 7 changed files with 291 additions and 152 deletions.
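In short, this PR deduplicates two helpers that had been copy-pasted across modules: the unicode-unescaping logic in apps/utils.py and the load_statement_json function repeated in three sources/indra_db modules. Both now live in the shared indra_cogex.util module (as unicode_escape with UnicodeEscapeError, and load_stmt_json_str), and all call sites are updated to import from there.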
34 changes: 1 addition & 33 deletions src/indra_cogex/apps/utils.py
@@ -1,4 +1,3 @@
import codecs
import json
import numpy
import logging
@@ -21,6 +20,7 @@
from indra.assemblers.html.assembler import _format_evidence_text, _format_stmt_text
from indra.statements import Statement
from indra.util.statement_presentation import _get_available_ev_source_counts
from indra_cogex.util import unicode_escape, UnicodeEscapeError
from indra_cogex.apps.constants import VUE_SRC_JS, VUE_SRC_CSS, sources_dict
from indra_cogex.apps.curation_cache.curation_cache import Curations
from indra_cogex.apps.proxies import curation_cache
@@ -127,38 +127,6 @@ def render_statements(
)


class UnicodeEscapeError(Exception):
pass


def unicode_escape(s: str, attempt: int = 1, max_attempts: int = 5) -> str:
"""Remove extra escapes from unicode characters in a string
Parameters
----------
s :
A string to remove extra escapes in unicode characters from
attempt :
The current attempt number.
max_attempts :
The maximum number of attempts to remove extra escapes.
Returns
-------
:
The string with extra escapes removed.
"""
escaped = codecs.escape_decode(s)[0].decode()
# No more escaping needed
if escaped.count('\\\\u') == 0:
return bytes(escaped, "utf-8").decode("unicode_escape")
# Too many attempts, return the input
if attempt >= max_attempts:
raise UnicodeEscapeError(f"Could not remove extra escapes from {s}")
# Try again
return unicode_escape(escaped, attempt + 1, max_attempts)


def format_stmts(
stmts: Iterable[Statement],
evidence_counts: Optional[Mapping[int, int]] = None,
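The unicode_escape helper deleted above now lives in indra_cogex.util, one of the changed files not rendered in this view. A minimal usage sketch, assuming the moved helper keeps the signature shown in the deleted code:

from indra_cogex.util import unicode_escape, UnicodeEscapeError

# A doubly-escaped string as it might appear in a raw database dump:
# the literal characters \\u03b1 rather than the character "α".
raw = "IFN-\\\\u03b1"  # Python literal for the text IFN-\\u03b1

try:
    print(unicode_escape(raw))  # -> IFN-α once the extra escapes are removed
except UnicodeEscapeError:
    # Raised when max_attempts rounds of unescaping still leave \\u sequences
    print("could not unescape:", raw)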
31 changes: 4 additions & 27 deletions src/indra_cogex/sources/indra_db/__init__.py
@@ -2,7 +2,6 @@

"""Processor for the INDRA database."""

import codecs
import csv
import gzip
import json
@@ -15,7 +14,6 @@
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

from indra.databases.identifiers import ensure_prefix_if_needed
from indra.statements import (
Agent,
default_ns_order,
@@ -37,6 +35,7 @@
processed_stmts_fname,
stmts_from_json,
)
from indra_cogex.util import load_stmt_json_str

logger = logging.getLogger(__name__)

@@ -86,7 +85,7 @@ def get_nodes(self): # noqa:D102
batch_iter(reader, batch_size=batch_size, return_func=list),
desc="Getting BioEntity nodes",
):
sj_list = [load_statement_json(sjs) for _, sjs in batch]
sj_list = [load_stmt_json_str(sjs) for _, sjs in batch]
stmts = stmts_from_json(sj_list)
for stmt in stmts:
for agent in stmt.real_agent_list():
@@ -125,7 +124,7 @@ def get_relations(self, max_complex_members: int = 3): # noqa:D102
f"statement hash {stmt_hash}. Are the source files updated?"
)
continue
stmt_json = load_statement_json(stmt_json_str)
stmt_json = load_stmt_json_str(stmt_json_str)
if stmt_json["evidence"][0]["source_api"] == "medscan":
stmt_json["evidence"] = []
data = {
@@ -237,11 +236,7 @@ def get_nodes(self, num_rows: Optional[int] = None) -> Iterable[Node]:
stmt_hash = int(stmt_hash_str)
if stmt_hash not in included_hashes:
continue
try:
stmt_json = load_statement_json(stmt_json_str)
except StatementJSONDecodeError as e:
logger.warning(e)
continue
stmt_json = load_stmt_json_str(stmt_json_str)

# Loop all evidences
# NOTE: there should be a single evidence for each
@@ -367,10 +362,6 @@ def _get_node_paths(cls, node_type: str) -> Path:
)


class StatementJSONDecodeError(Exception):
pass


def get_ag_ns_id(ag: Agent) -> Tuple[str, str]:
"""Return a namespace, identifier tuple for a given agent.
@@ -390,20 +381,6 @@ def get_ag_ns_id(ag: Agent) -> Tuple[str, str]:
return None, None


def load_statement_json(json_str: str, attempt: int = 1, max_attempts: int = 5) -> json:
try:
return json.loads(json_str)
except json.JSONDecodeError:
if attempt < max_attempts:
json_str = codecs.escape_decode(json_str)[0].decode()
return load_statement_json(
json_str, attempt=attempt + 1, max_attempts=max_attempts
)
raise StatementJSONDecodeError(
f"Could not decode statement JSON after " f"{attempt} attempts: {json_str}"
)


def load_text_refs_for_reading_dict(fname: str):
text_refs = {}
for line in tqdm(
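The shared replacement, load_stmt_json_str, is defined in indra_cogex/util.py, which is also not rendered in this view. Judging from the three local copies this PR deletes and from call sites such as load_stmt_json_str(sjs, remove_evidence=True) in assembly.py below, a plausible reconstruction looks like the following sketch (an illustration, not the verified contents of util.py):

import codecs
import json


class StatementJSONDecodeError(Exception):
    """Raised when a statement JSON string cannot be decoded."""


def load_stmt_json_str(
    stmt_json_str: str,
    attempt: int = 1,
    max_attempts: int = 5,
    remove_evidence: bool = False,
) -> dict:
    """Parse statement JSON, removing layers of extra escaping as needed."""
    try:
        stmt_json = json.loads(stmt_json_str)
    except json.JSONDecodeError:
        if attempt < max_attempts:
            # Strip one layer of backslash escaping and retry
            stmt_json_str = codecs.escape_decode(stmt_json_str)[0].decode()
            return load_stmt_json_str(
                stmt_json_str,
                attempt=attempt + 1,
                max_attempts=max_attempts,
                remove_evidence=remove_evidence,
            )
        raise StatementJSONDecodeError(
            f"Could not decode statement JSON after {attempt} attempts: "
            f"{stmt_json_str}"
        )
    # Unlike the deleted copy in assembly.py, which only stripped evidence
    # on the escape-retry path, apply remove_evidence on success as well.
    if remove_evidence:
        stmt_json["evidence"] = []
    return stmt_json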
76 changes: 6 additions & 70 deletions src/indra_cogex/sources/indra_db/assembly.py
@@ -2,7 +2,6 @@
import gzip
import logging
import math
import json
import pickle
import itertools
from pathlib import Path
@@ -11,9 +10,7 @@
import networkx as nx
import numpy as np
import tqdm
import codecs
import pystow
import sqlite3
from collections import defaultdict, Counter

from indra.belief import BeliefEngine
@@ -27,6 +24,7 @@
unique_stmts_fname,
source_counts_fname,
)
from indra_cogex.util import load_stmt_json_str

StmtList = List[Statement]

@@ -36,10 +34,6 @@
refinement_cycles_fname = base_folder.join(name="refinement_cycles.pkl")


class StatementJSONDecodeError(Exception):
pass


logger = logging.getLogger(__name__)


@@ -86,7 +80,7 @@ def get_refinement_graph() -> nx.DiGraph:
try:
_, sjs = next(reader1)
stmt = stmt_from_json(
load_statement_json(sjs, remove_evidence=True)
load_stmt_json_str(sjs, remove_evidence=True)
)
stmts1.append(stmt)
except StopIteration:
@@ -118,7 +112,8 @@ def get_refinement_graph() -> nx.DiGraph:
for _, sjs in batch:
try:
stmt = stmt_from_json(
load_statement_json(sjs, remove_evidence=True)
load_stmt_json_str(sjs,
remove_evidence=True)
)
stmts2.append(stmt)
except StopIteration:
@@ -173,37 +168,6 @@ def get_refinement_graph() -> nx.DiGraph:
return ref_graph


def load_statement_json(
json_str: str,
attempt: int = 1,
max_attempts: int = 5,
remove_evidence: bool = False,
):
try:
return json.loads(json_str)
except json.JSONDecodeError:
if attempt < max_attempts:
json_str = codecs.escape_decode(json_str)[0].decode()
sj = load_statement_json(
json_str, attempt=attempt + 1, max_attempts=max_attempts
)
if remove_evidence:
sj["evidence"] = []
return sj
raise StatementJSONDecodeError(
f"Could not decode statement JSON after " f"{attempt} attempts: {json_str}"
)


def get_stmts(db, limit, offset):
cur = db.execute("select * from processed limit %s offset %s" % (limit, offset))
stmts = [
stmt_from_json(load_statement_json(sjs, remove_evidence=True))
for _, sjs in tqdm.tqdm(cur.fetchall(), total=limit, desc="Loading statements")
]
return stmts


def get_related(stmts: StmtList) -> Set[Tuple[int, int]]:
stmts_by_type = defaultdict(list)
for stmt in stmts:
@@ -232,34 +196,6 @@ def get_related_split(stmts1: StmtList, stmts2: StmtList) -> Set[Tuple[int, int]
return refinements


def sqlite_approach():
"""
Assembly notes:
Step 1: Create a SQLITE DB
sqlite3 -batch statements.db "create table processed (hash integer, stmt text);"
zcat < unique_statements.tsv.gz | sqlite3 -cmd ".mode tabs" -batch statements.db ".import '|cat -' processed"
sqlite3 -batch statements.db "create index processed_idx on processed (hash);"
"""
db = sqlite3.connect(base_folder.join(name="statements.db"))

cur = db.execute("select count(1) from processed")
num_rows = cur.fetchone()[0]

offset0 = 0
num_batches = math.ceil(num_rows / batch_size)
refinements = set()
for i in tqdm.tqdm(range(num_batches)):
offset1 = i * batch_size
stmts1 = get_stmts(db, batch_size, offset1)
refinements |= get_related(stmts1)
for j in tqdm.tqdm(range(i + 1, num_batches)):
offset2 = j * batch_size
stmts2 = get_stmts(db, batch_size, offset2)
refinements |= get_related_split(stmts1, stmts2)


def sample_unique_stmts(
num: int = 100000, n_rows: Optional[int] = None
) -> List[Tuple[int, Statement]]:
@@ -293,7 +229,7 @@ def sample_unique_stmts(
reader = csv.reader(f, delimiter="\t")
for index, (sh, sjs) in enumerate(reader):
if index in indices:
stmts.append((int(sh), stmt_from_json(load_statement_json(sjs))))
stmts.append((int(sh), stmt_from_json(load_stmt_json_str(sjs))))
t.update()
if len(stmts) == num:
break
@@ -390,7 +326,7 @@ def _add_belief_scores_for_batch(batch: List[Tuple[int, Statement]]):
try:
stmt_hash_string, statement_json_string = next(reader)
statement = stmt_from_json(
load_statement_json(
load_stmt_json_str(
statement_json_string, remove_evidence=True
)
)
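Also removed from assembly.py are the exploratory sqlite_approach and its get_stmts helper. The docstring's shell commands recorded how the statements.db table was built, and the function body computed refinements without returning or persisting them, so nothing depended on it. If that path is ever resurrected, a parameterized query would be safer than the %s string interpolation get_stmts used; a sketch assuming the schema given in the removed docstring, processed(hash integer, stmt text):

import sqlite3


def fetch_stmt_batch(db_path: str, limit: int, offset: int):
    """Fetch one batch of (hash, stmt_json_str) rows from the processed table."""
    db = sqlite3.connect(db_path)
    try:
        cur = db.execute(
            "SELECT hash, stmt FROM processed LIMIT ? OFFSET ?",
            (limit, offset),
        )
        return cur.fetchall()
    finally:
        db.close()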
25 changes: 4 additions & 21 deletions src/indra_cogex/sources/indra_db/raw_export.py
@@ -11,8 +11,9 @@
import pystow
from adeft.download import get_available_models
from indra.util import batch_iter
from indra.statements import stmts_from_json
from indra.statements import stmts_from_json, stmt_from_json
from indra.tools import assemble_corpus as ac
from indra_cogex.util import load_stmt_json_str

base_folder = pystow.module("indra", "db")
reading_text_content_fname = base_folder.join(name="reading_text_content_meta.tsv.gz")
@@ -30,24 +31,6 @@
logger = logging.getLogger(__name__)


class StatementJSONDecodeError(Exception):
pass


def load_statement_json(json_str: str, attempt: int = 1, max_attempts: int = 5):
try:
return json.loads(json_str)
except json.JSONDecodeError:
if attempt < max_attempts:
json_str = codecs.escape_decode(json_str)[0].decode()
return load_statement_json(
json_str, attempt=attempt + 1, max_attempts=max_attempts
)
raise StatementJSONDecodeError(
f"Could not decode statement JSON after " f"{attempt} attempts: {json_str}"
)


def reader_prioritize(reader_contents):
drop = set()
# We first organize the contents by source/text type
@@ -322,7 +305,7 @@ def get_update(start_date):
text_ref_id = reading_id_to_text_ref_id.get(int(reading_id))
if text_ref_id:
refs = text_refs.get(text_ref_id)
stmt_json = load_statement_json(stmt_json_raw)
stmt_json = load_stmt_json_str(stmt_json_raw)
if refs:
stmt_json["evidence"][0]["text_refs"] = refs
if refs.get("PMID"):
@@ -366,7 +349,7 @@ def get_update(start_date):
for sh, stmt_json_str in tqdm.tqdm(
reader, total=60405451, desc="Gathering grounded and unique statements"
):
stmt = stmts_from_json([load_statement_json(stmt_json_str)])[0]
stmt = stmt_from_json(load_stmt_json_str(stmt_json_str))
if len(stmt.real_agent_list()) < 2:
continue
if all(
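One detail worth noting in the final hunk: the single-statement parse no longer goes through the list-based stmts_from_json API only to unwrap the result with [0]; it calls stmt_from_json directly, which is why stmt_from_json joins the import at the top of this file.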
3 more changed files not shown.
