Skip to content

Commit

Permalink
coerce nets into connections
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jun 17, 2024
1 parent 4c086a6 commit 429e337
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
10 changes: 7 additions & 3 deletions sax/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ def circuit(
ignore_missing_ports: Ignore missing ports in the netlist.
"""
backend = _validate_circuit_backend(backend)

instance_models = _extract_instance_models(netlist)
recnet: RecursiveNetlist = parse_netlist(netlist, remove_unused_instances=True)
dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models))
dependency_dag: nx.DiGraph[str] = _create_dag(recnet, models, validate=True)
models = _validate_models(models, dependency_dag, extra_models=instance_models)
backend = _validate_circuit_backend(backend)

circuit = None
new_models = {}
Expand Down Expand Up @@ -85,6 +86,7 @@ def circuit(
def _create_dag(
netlist: RecursiveNetlist,
models: dict[str, Model] | None = None,
validate: bool = False,
) -> nx.DiGraph[str]:
if models is None:
models = {}
Expand All @@ -111,6 +113,8 @@ def _create_dag(
nodes = [parent_node, *nx.descendants(g, parent_node)]
g = nx.induced_subgraph(g, nodes)
assert isinstance(g, nx.DiGraph)
if validate:
g = _validate_dag(g)
return g


Expand Down Expand Up @@ -140,7 +144,7 @@ def get_required_circuit_models(
"""
instance_models = _extract_instance_models(netlist)
recnet: RecursiveNetlist = parse_netlist(netlist, remove_unused_instances=True)
dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models))
dependency_dag: nx.DiGraph[str] = _create_dag(recnet, models, validate=True)
_, required, _ = _find_missing_models(
models, dependency_dag, extra_models=instance_models
)
Expand Down
74 changes: 63 additions & 11 deletions sax/netlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from natsort import natsorted
from pydantic import AfterValidator
from pydantic import BaseModel as _BaseModel
from pydantic import BeforeValidator, ConfigDict, Field, RootModel
from pydantic import BeforeValidator, ConfigDict, Field, RootModel, model_validator
from typing_extensions import Annotated

from .utils import clean_string, hash_dict
Expand Down Expand Up @@ -83,16 +83,6 @@ class Placement(BaseModel):
port: str | PortPlacement | None = None


def _coerce_component(obj: Any) -> Component:
if isinstance(obj, str):
return Component(component=obj)
elif isinstance(obj, partial):
return _component_from_partial(obj)
elif callable(obj):
return _coerce_component(obj.__name__)
return Component.model_validate(obj)


def _component_from_partial(p: partial):
settings = {}
f: Any = p
Expand All @@ -109,6 +99,20 @@ def _component_from_partial(p: partial):
return Component(component=f.__name__, settings=settings)


def _coerce_component(obj: Any) -> Component:
if isinstance(obj, str):
return Component(component=obj)
elif isinstance(obj, partial):
return _component_from_partial(obj)
elif callable(obj):
return _coerce_component(obj.__name__)
elif isinstance(obj, dict) and "info" in obj:
info = obj.pop("info", {})
settings = obj.pop("settings", {})
obj["settings"] = {**settings, **info}
return Component.model_validate(obj)


CoercingComponent = Annotated[Component, BeforeValidator(_coerce_component)]


Expand All @@ -133,12 +137,60 @@ def _validate_instance_port_str(s: str):
InstancePortStr = Annotated[str, AfterValidator(_validate_instance_port_str)]


def _nets_to_connections(nets: list[dict], connections: dict):
connections = {k: v for k, v in connections.items()}
inverse_connections = {v: k for k, v in connections.items()}

def _is_connected(p):
return (p in connections) or (p in inverse_connections)

def _add_connection(p, q):
connections[p] = q
inverse_connections[q] = p

def _get_connected_port(p):
if p in connections:
return connections[p]
else:
return inverse_connections[p]

for net in nets:
p = net["p1"]
q = net["p2"]
if _is_connected(p):
_q = _get_connected_port(p)
raise ValueError(
"SAX currently does not support multiply connected ports. "
f"Got {p}<->{q} and {p}<->{_q}"
)
if _is_connected(q):
_p = _get_connected_port(q)
raise ValueError(
"SAX currently does not support multiply connected ports. "
f"Got {p}<->{q} and {_p}<->{q}"
)
_add_connection(p, q)
return connections


class Netlist(BaseModel):
instances: dict[InstanceStr, CoercingComponent] = Field(default_factory=dict)
connections: dict[InstancePortStr, InstancePortStr] = Field(default_factory=dict)
ports: dict[PortStr, InstancePortStr] = Field(default_factory=dict)
placements: dict[InstanceStr, Placement] = Field(default_factory=dict)

@model_validator(mode="before")
@classmethod
def coerce_nets_into_connections(cls, netlist: dict):
if not isinstance(netlist, dict):
return netlist
if "nets" in netlist:
nets = netlist.pop("nets", [])
connections = netlist.pop("connections", {})
connections = _nets_to_connections(nets, connections)
netlist["connections"] = connections
return netlist


class RecursiveNetlist(RootModel):
root: dict[str, Netlist]
Expand Down

0 comments on commit 429e337

Please sign in to comment.