Skip to content

Commit

Permalink
Merge pull request #34 from joamatab/add_docstrings
Browse files Browse the repository at this point in the history
add docstrings
  • Loading branch information
flaport authored Jun 5, 2024
2 parents 0821c2e + 6c6df6b commit 6eaa2b6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 35 deletions.
44 changes: 26 additions & 18 deletions sax/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,16 @@ def circuit(
return_type: str = "sdict",
ignore_missing_ports: bool = False,
) -> Tuple[Model, CircuitInfo]:
"""create a circuit function for a given netlist"""
"""Create a circuit function for a given netlist.
Args:
netlist: The netlist to create a circuit for.
models: A dictionary of models to use in the circuit.
backend: The backend to use for the circuit.
return_type: The type of the circuit function to return.
ignore_missing_ports: Ignore missing ports in the netlist.
"""
netlist = _ensure_recursive_netlist_dict(netlist)

# TODO: do the following two steps *after* recursive netlist parsing.
Expand All @@ -71,7 +80,7 @@ def circuit(
continue

flatnet = recnet.__root__[model_name]
current_models.update(new_models)
current_models |= new_models
new_models = {}

current_models[model_name] = circuit = _flat_circuit(
Expand Down Expand Up @@ -141,7 +150,7 @@ def _patch_path():
os_paths = {p: None for p in os.environ.get("PATH", "").split(os.pathsep)}
sys_paths = {p: None for p in sys.path}
other_paths = {os.path.dirname(sys.executable): None}
os.environ["PATH"] = os.pathsep.join({**os_paths, **sys_paths, **other_paths})
os.environ["PATH"] = os.pathsep.join(os_paths | sys_paths | other_paths)


def _my_dag_pos(dag):
Expand All @@ -167,13 +176,11 @@ def _my_dag_pos(dag):


def _find_root(g):
nodes = [n for n, d in g.in_degree() if d == 0]
return nodes
return [n for n, d in g.in_degree() if d == 0]


def _find_leaves(g):
nodes = [n for n, d in g.out_degree() if d == 0]
return nodes
return [n for n, d in g.out_degree() if d == 0]


def _validate_models(models, dag):
Expand Down Expand Up @@ -242,11 +249,9 @@ def _circuit(**settings: Settings) -> SType:


def _forward_global_settings(instances, settings):
global_settings = {}
for k in list(settings.keys()):
if k in instances:
continue
global_settings[k] = settings.pop(k)
global_settings = {
k: settings.pop(k) for k in list(settings.keys()) if k not in instances
}
if global_settings:
settings = update_settings(settings, **global_settings)
return settings
Expand All @@ -255,10 +260,7 @@ def _forward_global_settings(instances, settings):
def _port_modes_dict(port_modes):
result = {}
for port_mode in port_modes:
if "@" in port_mode:
port, mode = port_mode.split("@")
else:
port, mode = port_mode, None
port, mode = port_mode.split("@") if "@" in port_mode else (port_mode, None)
if port not in result:
result[port] = set()
if mode is not None:
Expand Down Expand Up @@ -407,7 +409,13 @@ def get_required_circuit_models(
netlist: Union[Netlist, NetlistDict, RecursiveNetlist, RecursiveNetlistDict],
models: Optional[Dict[str, Model]] = None,
) -> List:
"""Figure out which models are needed for a given netlist"""
"""Figure out which models are needed for a given netlist.
Args:
netlist: The netlist to create a circuit for.
models: A dictionary of models to use in the circuit.
"""
if models is None:
models = {}
assert isinstance(models, dict)
Expand All @@ -434,7 +442,7 @@ def get_required_circuit_models(
else:
component = instance["component"]
if (component not in missing_models) and (component not in models):
missing_models[component] = models.get(component, None)
missing_models[component] = models.get(component)
missing_model_names.append(component)
g.add_node(component)
g.add_edge(model_name, component)
Expand Down
65 changes: 48 additions & 17 deletions sax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,49 @@ def straight(
length: float = 10.0,
loss: float = 0.0,
) -> SDict:
"""a simple straight waveguide model"""
"""A simple straight waveguide model.
Args:
wl: wavelength in microns.
wl0: reference wavelength in microns.
neff: effective index.
ng: group index.
length: length of the waveguide in microns.
loss: loss in dB/cm.
"""
dwl = wl - wl0
dneff_dwl = (ng - neff) / wl0
_neff = neff - dwl * dneff_dwl
phase = 2 * jnp.pi * _neff * length / wl
amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
transmission = amplitude * jnp.exp(1j * phase)
sdict = reciprocal(
return reciprocal(
{
("in0", "out0"): transmission,
}
)
return sdict


def coupler(*, coupling: float = 0.5) -> SDict:
"""a simple coupler model"""
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5
sdict = reciprocal(
return reciprocal(
{
("in0", "out0"): tau,
("in0", "out1"): 1j * kappa,
("in1", "out0"): 1j * kappa,
("in1", "out1"): tau,
}
)
return sdict


def _validate_ports(
ports, num_inputs, num_outputs, diagonal
) -> Tuple[Tuple[str, ...], Tuple[str, ...], int, int]:
"""Validate the ports and return the input and output ports."""

if ports is None:
if num_inputs is None or num_outputs is None:
raise ValueError(
Expand Down Expand Up @@ -103,6 +113,17 @@ def unitary(
reciprocal=True,
diagonal=False,
) -> Model:
"""A unitary model.
Args:
num_inputs: number of input ports.
num_outputs: number of output ports.
ports: tuple of input and output ports.
jit: whether to jit the model.
reciprocal: whether the model is reciprocal.
diagonal: whether the model is diagonal.
"""
input_ports, output_ports, num_inputs, num_outputs = _validate_ports(
ports, num_inputs, num_outputs, diagonal
)
Expand Down Expand Up @@ -156,9 +177,7 @@ def func(wl: float = 1.5) -> SCoo:

func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
if jit:
return jax.jit(func)
return func
return jax.jit(func) if jit else func


@cache
Expand All @@ -171,6 +190,16 @@ def copier(
reciprocal=True,
diagonal=False,
) -> Model:
"""A copier model.
Args:
num_inputs: number of input ports.
num_outputs: number of output ports.
ports: tuple of input and output ports.
jit: whether to jit the model.
reciprocal: whether the model is reciprocal.
diagonal: whether the model is diagonal.
"""
input_ports, output_ports, num_inputs, num_outputs = _validate_ports(
ports, num_inputs, num_outputs, diagonal
)
Expand Down Expand Up @@ -212,9 +241,7 @@ def func(wl: float = 1.5) -> SCoo:

func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
if jit:
return jax.jit(func)
return func
return jax.jit(func) if jit else func


@cache
Expand All @@ -225,14 +252,20 @@ def passthru(
jit=True,
reciprocal=True,
) -> Model:
"""A passthru model.
Args:
num_links: number of links.
ports: tuple of input and output ports.
jit: whether to jit the model.
reciprocal: whether the model is reciprocal.
"""
passthru = unitary(
num_links, num_links, ports, jit=jit, reciprocal=reciprocal, diagonal=True
)
passthru.__name__ = f"passthru_{num_links}_{num_links}"
passthru.__qualname__ = f"passthru_{num_links}_{num_links}"
if jit:
return jax.jit(passthru)
return passthru
return jax.jit(passthru) if jit else passthru


models = {
Expand All @@ -245,6 +278,4 @@ def passthru(


def get_models(copy: bool = True):
if copy:
return {**models}
return models
return {**models} if copy else models

0 comments on commit 6eaa2b6

Please sign in to comment.