Skip to content

Commit

Permalink
Cleanup and refactoring. Tests added.
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jun 17, 2024
1 parent 5204870 commit 88248e1
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 256 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## [0.13.0](https://github.com/flaport/sax/compare/0.12.2...0.13.0)

- Deprecate `sax.nn`.
- Remove support for pydantic v1.
- Deprecate support for pydantic v1.
- Add support for gdsfactory 8 netlists (i.e. `nets` in stead of `connections`)

## [0.10.2](https://github.com/flaport/sax/compare/0.10.1...0.10.2)
Expand Down
10 changes: 3 additions & 7 deletions sax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,8 @@

from scipy.constants import c as c

try:
from flax.core.frozen_dict import FrozenDict as FrozenDict
except ImportError:
FrozenDict = dict

from . import backends as backends
from . import models as models
from . import patched as patched
from . import saxtypes as saxtypes
from . import utils as utils
from .circuit import circuit as circuit
Expand All @@ -30,7 +24,9 @@
from .models import passthru as passthru
from .multimode import multimode as multimode
from .multimode import singlemode as singlemode
from .netlist import RecursiveNetlist, flatten_netlist
from .netlist import Netlist as Netlist
from .netlist import RecursiveNetlist as RecursiveNetlist
from .netlist import flatten_netlist as flatten_netlist
from .netlist import get_component_instances as get_component_instances
from .netlist import get_netlist_instances_by_prefix as get_netlist_instances_by_prefix
from .netlist import load_netlist as load_netlist
Expand Down
243 changes: 103 additions & 140 deletions sax/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,15 @@
import shutil
import sys
from functools import partial
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Callable, NamedTuple

import black
import networkx as nx
import numpy as np
from pydantic import ValidationError

from .backends import circuit_backends
from .netlist import (
Netlist,
NetlistDict,
RecursiveNetlist,
RecursiveNetlistDict,
remove_unused_instances,
)
from .netlist import AnyNetlist, NetlistDict, RecursiveNetlist
from .netlist import netlist as parse_netlist
from .saxtypes import Model, Settings, SType, scoo, sdense, sdict
from .utils import (
_replace_kwargs,
Expand All @@ -35,16 +29,16 @@ class CircuitInfo(NamedTuple):
"""Information about the circuit function you created."""

dag: nx.DiGraph[str]
models: Dict[str, Model]
models: dict[str, Model]


def circuit(
netlist: Union[Netlist, NetlistDict, RecursiveNetlist, RecursiveNetlistDict],
models: Optional[Dict[str, Model]] = None,
netlist: AnyNetlist,
models: dict[str, Model] | None = None,
backend: str = "default",
return_type: str = "sdict",
ignore_missing_ports: bool = False,
) -> Tuple[Model, CircuitInfo]:
) -> tuple[Model, CircuitInfo]:
"""Create a circuit function for a given netlist.
Args:
Expand All @@ -55,15 +49,10 @@ def circuit(
ignore_missing_ports: Ignore missing ports in the netlist.
"""
netlist = _ensure_recursive_netlist_dict(netlist)

# TODO: do the following two steps *after* recursive netlist parsing.
netlist = remove_unused_instances(netlist)
netlist, instance_models = _extract_instance_models(netlist)
recnet: RecursiveNetlist = _validate_net(netlist)
recnet: RecursiveNetlist = parse_netlist(netlist, remove_unused_instances=True)
recnet, instance_models = _extract_instance_models(recnet)
dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models))

models = _validate_models({**(models or {}), **instance_models}, dependency_dag)
models = _validate_models(models, dependency_dag, extra_models=instance_models)
backend = _validate_circuit_backend(backend)

circuit = None
Expand Down Expand Up @@ -95,7 +84,7 @@ def circuit(

def _create_dag(
netlist: RecursiveNetlist,
models: Optional[Dict[str, Any]] = None,
models: dict[str, Model] | None = None,
) -> nx.DiGraph[str]:
if models is None:
models = {}
Expand All @@ -121,8 +110,8 @@ def _create_dag(
parent_node = next(iter(netlist.root.keys()))
nodes = [parent_node, *nx.descendants(g, parent_node)]
g = nx.induced_subgraph(g, nodes)

return g # type: ignore
assert isinstance(g, nx.DiGraph)
return g


def draw_dag(dag, with_labels=True, **kwargs):
Expand All @@ -138,6 +127,72 @@ def draw_dag(dag, with_labels=True, **kwargs):
return nx.draw(dag, _my_dag_pos(dag), with_labels=with_labels, **kwargs)


def get_required_circuit_models(
netlist: AnyNetlist,
models: dict[str, Model] | None = None,
) -> list[str]:
"""Figure out which models are needed for a given netlist.
Args:
netlist: The netlist to create a circuit for.
models: A dictionary of models to use in the circuit.
"""
recnet: RecursiveNetlist = parse_netlist(netlist, remove_unused_instances=True)
recnet, instance_models = _extract_instance_models(recnet)
dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models))
_, required, _ = _find_missing_models(
models, dependency_dag, extra_models=instance_models
)
return required


def _flat_circuit(
instances, connections, ports, models, backend, ignore_missing_ports=False
):
analyze_insts_fn, analyze_fn, evaluate_fn = circuit_backends[backend]
dummy_instances = analyze_insts_fn(instances, models)
inst_port_mode = {
k: _port_modes_dict(get_ports(s)) for k, s in dummy_instances.items()
}
connections = _get_multimode_connections(
connections, inst_port_mode, ignore_missing_ports=ignore_missing_ports
)
ports = _get_multimode_ports(
ports, inst_port_mode, ignore_missing_ports=ignore_missing_ports
)

inst2model = {}
for k, inst in instances.items():
inst2model[k] = models[inst.component]

model_settings = {name: get_settings(model) for name, model in inst2model.items()}
netlist_settings = {
name: {
k: v for k, v in (inst.settings or {}).items() if k in model_settings[name]
}
for name, inst in instances.items()
}
default_settings = merge_dicts(model_settings, netlist_settings)
analyzed = analyze_fn(dummy_instances, connections, ports)

def _circuit(**settings: Settings) -> SType:
full_settings = merge_dicts(default_settings, settings)
full_settings = _forward_global_settings(inst2model, full_settings)
full_settings = merge_dicts(full_settings, settings)

instances: dict[str, SType] = {}
for inst_name, model in inst2model.items():
instances[inst_name] = model(**full_settings.get(inst_name, {}))

S = evaluate_fn(analyzed, instances)
return S

_replace_kwargs(_circuit, **default_settings)

return _circuit


def _patch_path():
os_paths = {p: None for p in os.environ.get("PATH", "").split(os.pathsep)}
sys_paths = {p: None for p in sys.path}
Expand Down Expand Up @@ -175,9 +230,25 @@ def _find_leaves(g):
return [n for n, d in g.out_degree() if d == 0]


def _validate_models(models, dag):
def _find_missing_models(
models: dict | None, dag: nx.DiGraph, extra_models: dict | None = None
) -> tuple[dict[str, Callable], list[str], list[str]]:
if extra_models is None:
extra_models = {}
if models is None:
models = {}
models = {**models, **extra_models}
required_models = _find_leaves(dag)
missing_models = [m for m in required_models if m not in models]
return models, required_models, missing_models


def _validate_models(
models: dict | None, dag: nx.DiGraph, extra_models: dict | None = None
) -> dict[str, Model]:
models, required_models, missing_models = _find_missing_models(
models, dag, extra_models
)
if missing_models:
model_diff = {
"Missing Models": missing_models,
Expand All @@ -188,53 +259,7 @@ def _validate_models(models, dag):
"Missing models. The following models are still missing to build "
f"the circuit:\n{black.format_str(repr(model_diff), mode=black.Mode())}"
)
return {**models} # shallow copy


def _flat_circuit(
instances, connections, ports, models, backend, ignore_missing_ports=False
):
analyze_insts_fn, analyze_fn, evaluate_fn = circuit_backends[backend]
dummy_instances = analyze_insts_fn(instances, models)
inst_port_mode = {
k: _port_modes_dict(get_ports(s)) for k, s in dummy_instances.items()
}
connections = _get_multimode_connections(
connections, inst_port_mode, ignore_missing_ports=ignore_missing_ports
)
ports = _get_multimode_ports(
ports, inst_port_mode, ignore_missing_ports=ignore_missing_ports
)

inst2model = {}
for k, inst in instances.items():
inst2model[k] = models[inst.component]

model_settings = {name: get_settings(model) for name, model in inst2model.items()}
netlist_settings = {
name: {
k: v for k, v in (inst.settings or {}).items() if k in model_settings[name]
}
for name, inst in instances.items()
}
default_settings = merge_dicts(model_settings, netlist_settings)
analyzed = analyze_fn(dummy_instances, connections, ports)

def _circuit(**settings: Settings) -> SType:
full_settings = merge_dicts(default_settings, settings)
full_settings = _forward_global_settings(inst2model, full_settings)
full_settings = merge_dicts(full_settings, settings)

instances: Dict[str, SType] = {}
for inst_name, model in inst2model.items():
instances[inst_name] = model(**full_settings.get(inst_name, {}))

S = evaluate_fn(analyzed, instances)
return S

_replace_kwargs(_circuit, **default_settings)

return _circuit
return models


def _forward_global_settings(instances, settings):
Expand Down Expand Up @@ -324,21 +349,11 @@ def _enforce_return_type(model, return_type):
return stype_func(model)


def _ensure_recursive_netlist_dict(netlist: Any) -> RecursiveNetlistDict:
if not isinstance(netlist, dict):
netlist = netlist.model_dump()
if "instances" in netlist:
netlist = {"top_level": netlist}
netlist = {**netlist}
for k, v in netlist.items():
netlist[k] = {**v}
return netlist


def _extract_instance_models(netlist):
def _extract_instance_models(netlist: RecursiveNetlist):
models = {}
for netname, net in netlist.items():
net = {**net}
recnet_dict = netlist.model_dump()
for netname, net in recnet_dict.items():
net: NetlistDict = {**net}
net["instances"] = {**net["instances"]}
for name, inst in net["instances"].items():
if callable(inst):
Expand All @@ -355,7 +370,8 @@ def _extract_instance_models(netlist):
"component": inst.__name__,
"settings": settings,
}
netlist[netname] = net
recnet_dict[netname] = net
netlist = RecursiveNetlist.model_validate(recnet_dict)
return netlist, models


Expand All @@ -370,19 +386,6 @@ def _validate_circuit_backend(backend):
return backend


def _validate_net(
netlist: Union[Netlist, RecursiveNetlist, NetlistDict, RecursiveNetlistDict]
) -> RecursiveNetlist:
if isinstance(netlist, dict):
try:
netlist = Netlist.model_validate(netlist)
except ValidationError:
netlist = RecursiveNetlist.model_validate(netlist)
if isinstance(netlist, Netlist):
netlist = RecursiveNetlist(root={"top_level": netlist})
return netlist


def _validate_dag(dag):
nodes = _find_root(dag)
if len(nodes) > 1:
Expand All @@ -392,43 +395,3 @@ def _validate_dag(dag):
if not dag.is_directed():
raise ValueError("Netlist dependency cycles detected!")
return dag


def get_required_circuit_models(
netlist: Union[Netlist, NetlistDict, RecursiveNetlist, RecursiveNetlistDict],
models: Optional[Dict[str, Model]] = None,
) -> List:
"""Figure out which models are needed for a given netlist.
Args:
netlist: The netlist to create a circuit for.
models: A dictionary of models to use in the circuit.
"""
if models is None:
models = {}
assert isinstance(models, dict)
netlist = _ensure_recursive_netlist_dict(netlist)
# TODO: do the following two steps *after* recursive netlist parsing.
netlist = remove_unused_instances(netlist)
netlist, _ = _extract_instance_models(netlist)
recnet: RecursiveNetlist = _validate_net(netlist)

missing_models = {}
missing_model_names = []
g = nx.DiGraph()

for model_name, subnetlist in recnet.model_dump().items():
if model_name not in missing_models:
missing_models[model_name] = models.get(model_name, subnetlist)
g.add_node(model_name)
if model_name in models:
continue
for instance in subnetlist["instances"].values():
component = instance["component"]
if (component not in missing_models) and (component not in models):
missing_models[component] = models.get(component)
missing_model_names.append(component)
g.add_node(component)
g.add_edge(model_name, component)
return missing_model_names
Loading

0 comments on commit 88248e1

Please sign in to comment.