diff --git a/AdjointPlugin11CircuitMZI.ipynb b/AdjointPlugin11CircuitMZI.ipynb index a0c9cb67..69b4e490 100644 --- a/AdjointPlugin11CircuitMZI.ipynb +++ b/AdjointPlugin11CircuitMZI.ipynb @@ -43,6 +43,7 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import functools\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -922,14 +923,18 @@ } ], "source": [ - "import functools\n", + "def component1x2(params=params0, beta=1.0):\n", + " return component(params=params, beta=beta, shape=(1,2))\n", + " \n", + "def component2x2(params=params0, beta=1.0):\n", + " return component(params=params, beta=beta, shape=(2,2))\n", "\n", "circuit_fn, _ = sax.circuit(\n", " netlist={\n", " \"instances\": {\n", - " \"splitter\": functools.partial(component, shape=(1,2)),\n", + " \"splitter\": component1x2,\n", " \"phase_shifter\": phase_shifter,\n", - " \"combiner\": functools.partial(component, shape=(2,2)),\n", + " \"combiner\": component2x2,\n", " },\n", " \"connections\": {\n", " \"splitter,out0\": \"phase_shifter,in\",\n", @@ -943,6 +948,7 @@ " },\n", " }\n", ")\n", + "\n", "circuit_fn" ] },