Skip to content

Commit

Permalink
Merge pull request #397 from roedoejet/dev.ej/network-lite
Browse files Browse the repository at this point in the history
Refactor: replace networkx by a very lightweight custom DiGraph class
  • Loading branch information
joanise authored Sep 12, 2024
2 parents 966a057 + d1b3437 commit 54d4e18
Show file tree
Hide file tree
Showing 13 changed files with 390 additions and 48 deletions.
15 changes: 5 additions & 10 deletions g2p/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def make_g2p( # noqa: C901
NoPath: if there is path between in_lang and out_lang
"""
# Defer expensive imports
from networkx import shortest_path # type: ignore
from networkx.exception import NetworkXNoPath # type: ignore

from g2p.log import LOGGER
from g2p.mappings import Mapping
from g2p.mappings.langs import LANGS_NETWORK
Expand Down Expand Up @@ -100,13 +97,13 @@ def make_g2p( # noqa: C901

# Try to find the shortest path between the nodes
try:
path = shortest_path(LANGS_NETWORK, in_lang, out_lang)
except NetworkXNoPath as e:
path = LANGS_NETWORK.shortest_path(in_lang, out_lang)
except ValueError:
LOGGER.error(
f"Sorry, we couldn't find a way to convert {in_lang} to {out_lang}. "
"Please update your langs by running `g2p update` and try again."
)
raise NoPath(in_lang, out_lang) from e
raise NoPath(in_lang, out_lang)

# Find all mappings needed
mappings_needed = []
Expand Down Expand Up @@ -162,8 +159,6 @@ def get_arpabet_langs():
LANG_NAMES maps each code to its full language name and is ordered by codes
"""
# Defer expensive imports
from networkx import has_path

from g2p.mappings import LANGS
from g2p.mappings.langs import LANGS_NETWORK

Expand Down Expand Up @@ -203,8 +198,8 @@ def get_arpabet_langs():
and not x.endswith("-equiv")
and not x.endswith("-no-symbols")
and x not in ["und-ascii", "moh-festival"]
and LANGS_NETWORK.has_node(x)
and has_path(LANGS_NETWORK, x, "eng-arpabet")
and x in LANGS_NETWORK
and LANGS_NETWORK.has_path(x, "eng-arpabet")
]

# Sort LANGS so the -h messages list them alphabetically
Expand Down
5 changes: 2 additions & 3 deletions g2p/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from fastapi import FastAPI, HTTPException, Path, Query
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, PlainTextResponse
from networkx.algorithms.dag import ancestors, descendants # type: ignore

from g2p import make_g2p
from g2p.exceptions import NoPath
Expand Down Expand Up @@ -71,7 +70,7 @@ def get_all_ancestors_of_node(
"""Get the valid ancestors in the network's path to a given node. These
are all the mappings that you can convert from in order to get the
given node."""
return sorted(ancestors(LANGS_NETWORK, node.name))
return sorted(LANGS_NETWORK.ancestors(node.name))


@api.get(
Expand All @@ -84,7 +83,7 @@ def get_all_ancestors_of_node(
def get_all_descendants_of_node(
node: Lang = Path(description="language node name"),
) -> List[str]:
return sorted(descendants(LANGS_NETWORK, node.name))
return sorted(LANGS_NETWORK.descendants(node.name))


@api.get(
Expand Down
11 changes: 4 additions & 7 deletions g2p/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@

from fastapi import Body, FastAPI, HTTPException, Path
from fastapi.middleware.cors import CORSMiddleware
from networkx import shortest_path # type: ignore
from networkx.algorithms.dag import ancestors, descendants # type: ignore
from networkx.exception import NetworkXNoPath # type: ignore
from pydantic import BaseModel, Field

import g2p
Expand Down Expand Up @@ -385,7 +382,7 @@ def get_possible_output_conversions_for_a_writing_system(
are all the phonetic or orthographic systems into which you can convert
this input.
"""
return sorted(descendants(g2p_langs.LANGS_NETWORK, lang.name))
return sorted(g2p_langs.LANGS_NETWORK.descendants(lang.name))


@api.get(
Expand All @@ -399,7 +396,7 @@ def get_writing_systems_that_can_be_converted_to_an_output(
are all the phonetic or orthographic systems that you can convert
into this output.
"""
return sorted(ancestors(g2p_langs.LANGS_NETWORK, lang.name))
return sorted(g2p_langs.LANGS_NETWORK.ancestors(lang.name))


@api.get(
Expand All @@ -412,8 +409,8 @@ def get_path_from_one_language_to_another(
) -> List[str]:
"""Get the sequence of intermediate forms used to convert from {in_lang} to {out_lang}."""
try:
return shortest_path(g2p_langs.LANGS_NETWORK, in_lang.name, out_lang.name)
except NetworkXNoPath:
return g2p_langs.LANGS_NETWORK.shortest_path(in_lang.name, out_lang.name)
except ValueError:
raise HTTPException(
status_code=400, detail=f"No path from {in_lang} to {out_lang}"
)
3 changes: 1 addition & 2 deletions g2p/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Dict, List, Union

import socketio # type: ignore
from networkx import shortest_path # type: ignore
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse
Expand Down Expand Up @@ -282,7 +281,7 @@ async def change_table(sid, message):
namespace="/table",
)
else:
path = shortest_path(LANGS_NETWORK, message["in_lang"], message["out_lang"])
path = LANGS_NETWORK.shortest_path(message["in_lang"], message["out_lang"])
mappings: List[Mapping] = []
for lang1, lang2 in zip(path[:-1], path[1:]):
transducer = make_g2p(lang1, lang2, tokenize=False)
Expand Down
8 changes: 3 additions & 5 deletions g2p/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,6 @@ def convert( # noqa: C901
fra->fra-ipa, fra-ipa->eng-ipa and eng-ipa->eng-arpabet.
"""
# Defer expensive imports
from networkx import has_path # type: ignore

from g2p.log import LOGGER
from g2p.mappings import MAPPINGS_AVAILABLE, Mapping, MappingConfig
from g2p.mappings.langs import LANGS_NETWORK
Expand Down Expand Up @@ -535,7 +533,7 @@ def convert( # noqa: C901
if out_lang not in LANGS_NETWORK.nodes:
raise click.UsageError(f"'{out_lang}' is not a valid value for 'OUT_LANG'")
# Check if path exists
if not has_path(LANGS_NETWORK, in_lang, out_lang):
if not LANGS_NETWORK.has_path(in_lang, out_lang):
raise click.UsageError(
f"Path between '{in_lang}' and '{out_lang}' does not exist"
)
Expand Down Expand Up @@ -689,7 +687,7 @@ def update(in_dir, out_dir):
reload_db()
network_to_echart(
outfile=os.path.join(os.path.dirname(static_file), "languages-network.json")
) # updates g2p/status/languages-network.json
) # updates g2p/static/languages-network.json


@click.option(
Expand Down Expand Up @@ -851,4 +849,4 @@ def show_mappings(lang1, lang2, verbose, csv):
print()
else:
for i, m in enumerate(mappings):
print(f"{i+1}: {m.in_lang}{m.out_lang}")
print(f"{i+1}: {m.in_lang}{m.out_lang} ({m.display_name})")
10 changes: 5 additions & 5 deletions g2p/mappings/langs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import json
import os

import networkx # type: ignore

from g2p.constants import LANGS_DIR, LANGS_FILE_NAME, NETWORK_FILE_NAME
from g2p.log import LOGGER

from .network_lite import DiGraph, node_link_graph

assert LANGS_DIR == os.path.dirname(__file__)
LANGS_PKL = os.path.join(LANGS_DIR, LANGS_FILE_NAME)
LANGS_NWORK_PATH = os.path.join(LANGS_DIR, NETWORK_FILE_NAME)
Expand All @@ -25,14 +25,14 @@ def load_langs(path: str = LANGS_PKL):
return {}


def load_network(path: str = LANGS_NWORK_PATH):
def load_network(path: str = LANGS_NWORK_PATH) -> DiGraph[str]:
try:
with gzip.open(path, "rt", encoding="utf8") as f:
data = json.load(f)
return networkx.node_link_graph(data)
return node_link_graph(data)
except Exception as e:
LOGGER.warning(f"Failed to read language network from {path}: {e}")
return networkx.DiGraph()
return DiGraph()


def get_available_languages(langs: dict) -> list:
Expand Down
198 changes: 198 additions & 0 deletions g2p/mappings/langs/network_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
from collections import deque
from typing import (
Any,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
List,
Set,
Tuple,
TypeVar,
Union,
)

from typing_extensions import TypedDict

T = TypeVar("T", bound=Hashable)


class DiGraph(Generic[T]):
"""A simple directed graph class
Most functions raise KeyError if called with a node u or v not in the graph.
"""

def __init__(self) -> None:
"""Contructor, empty if no data, else load from data"""
self._edges: Dict[T, List[T]] = {}

def clear(self):
"""Clear the graph"""
self._edges.clear()

def update(self, edges: Iterable[Tuple[T, T]], nodes: Iterable[T]):
"""Update the graph with new edges and nodes"""
for node in nodes:
self.add_node(node)
for u, v in edges:
self.add_edge(u, v)

def add_node(self, u: T):
"""Add a node to the graph"""
if u not in self._edges:
self._edges[u] = []

def add_edge(self, u: T, v: T):
"""Add a directed edge from u to v"""
self.add_node(u)
self.add_node(v)
if v not in self._edges[u]:
self._edges[u].append(v)

def add_edges_from(self, edges: Iterable[Tuple[T, T]]):
"""Add edges from a list of tuples"""
for u, v in edges:
self.add_edge(u, v)

@property # read-only
def nodes(self):
"""Return the nodes"""
return self._edges.keys()

@property # read-only
def edges(self) -> Iterator[Tuple[T, T]]:
"""Iterate over all edges"""
for u, neighbours in self._edges.items():
for v in neighbours:
yield u, v

def __contains__(self, u: T) -> bool:
"""Check if a node is in the graph"""
return u in self._edges

def has_path(self, u: T, v: T) -> bool:
"""Check if there is a path from u to v"""
if v not in self._edges:
raise KeyError(f"Node {v} not in graph")
visited: Set[T] = set()
return self._has_path(u, v, visited)

def _has_path(self, u: T, v: T, visited: Set[T]) -> bool:
"""Helper function for has_path"""
visited.add(u)
if u == v:
return True
for neighbour in self._edges[u]:
if neighbour not in visited:
if self._has_path(neighbour, v, visited):
return True
return False

def successors(self, u: T) -> Iterator[T]:
"""Return the successors of u"""
return iter(self._edges[u])

def descendants(self, u: T) -> Set[T]:
"""Return the descendants of u"""
visited: Set[T] = set()
self._descendants(u, visited)
visited.remove(u)
return visited

def _descendants(self, u: T, visited: Set[T]):
"""Helper function for descendants"""
visited.add(u)
for neighbour in self._edges[u]:
if neighbour not in visited:
self._descendants(neighbour, visited)

def ancestors(self, u: T) -> Set[T]:
"""Return the ancestors of u"""
reversed_graph: DiGraph[T] = DiGraph()
reversed_graph.add_edges_from((v, u) for u, v in self.edges)
for node in self.nodes:
reversed_graph.add_node(node)
return reversed_graph.descendants(u)

def shortest_path(self, u: T, v: T) -> List[T]:
"""Return the shortest path from u to v
Algorithm: Dijsktra's algorithm for unweighted graphs, which is just BFS
Returns:
list: the shortest path from u to v
Raises:
KeyError: if u or v is not in the graph
ValueError: if there is no path from u to v
"""

if v not in self._edges:
raise KeyError(f"Node {v} not in graph")
visited: Dict[T, Union[T, None]] = {
u: None
} # dict of {node: predecessor on shortest path from u}
queue: deque[T] = deque([u])
while queue:
u = queue.popleft()
if u == v:
rev_path: List[T] = []
nextu: Union[T, None] = u
while nextu is not None:
rev_path.append(nextu)
nextu = visited[nextu]
return list(reversed(rev_path))
for neighbour in self._edges[u]:
if neighbour not in visited:
visited[neighbour] = u
queue.append(neighbour)
raise ValueError(f"No path from {u} to {v}")


NodeDict = TypedDict("NodeDict", {"id": Any})


class NodeLinkDict(TypedDict, Generic[T]):
source: T
target: T


class NodeLinkDataDict(TypedDict, Generic[T]):
directed: bool
graph: Dict
links: List[NodeLinkDict[T]]
multigraph: bool
nodes: List[NodeDict]


def node_link_graph(data: NodeLinkDataDict[T]) -> DiGraph[T]:
"""Replacement for networkx.node_link_graph"""
if not data.get("directed", False):
raise ValueError("Graph must be directed")
if data.get("multigraph", True):
raise ValueError("Graph must not be a multigraph")
if not isinstance(data.get("nodes", None), list):
raise ValueError('data["nodes"] must be a list')
if not isinstance(data.get("links", None), list):
raise ValueError('data["links"] must be a list')

graph: DiGraph[T] = DiGraph()
for node in data["nodes"]:
graph.add_node(node["id"])
for edge in data["links"]:
graph.add_edge(edge["source"], edge["target"])
return graph


def node_link_data(graph: DiGraph[T]) -> NodeLinkDataDict[T]:
"""Replacement for networkx.node_link_data"""
nodes: List[NodeDict] = [{"id": node} for node in graph.nodes]
links: List[NodeLinkDict[T]] = [{"source": u, "target": v} for u, v in graph.edges]
return {
"directed": True,
"graph": {},
"links": links,
"multigraph": False,
"nodes": nodes,
}
Loading

0 comments on commit 54d4e18

Please sign in to comment.