@@ -70,63 +70,87 @@ class SamplerQNN(NeuralNetwork):
7070 from the :class:`~qiskit_machine_learning.circuit.library.QNNCircuit`.
7171
7272 The output can be set up in different formats, and an optional post-processing step
73- can be used to interpret the sampler's output in a particular context (e.g. mapping the
74- resulting bitstring to match the number of classes).
73+ can be used to interpret or map the sampler's raw output in a particular context (e.g. mapping
74+ the resulting bitstring to match the number of classes) via an ``interpret`` function .
7575
76- In this example the network maps the output of the quantum circuit to two classes via a custom
77- ` interpret` function:
76+ The ``output_shape`` parameter defines the shape of the output array after applying the
77+ interpret function, and can be set following the guidelines below.
7878
79- .. code-block::
79+ * **Default behavior:** if no interpret function is provided, the default output_shape is
80+ ``2**num_qubits``, which corresponds to the number of possible bit-strings for the given
81+ number of qubits.
82+ * **Custom interpret function:** when using a custom interpret function, you must specify
83+ ``output_shape`` to match the expected output of the interpret function. For instance, if
84+ your interpret function maps bit-strings to two classes, you should set ``output_shape=2``.
85+ * **Number of classical registers:** if you want to reshape the output by the number of
86+ classical registers, set ``output_shape=2**circuit.num_clbits``. This is useful when
87+ the number of classical registers differs from the number of qubits.
88+ * **Tuple shape:** if the interpret function returns a tuple, ``output_shape`` should be a
89+ ``tuple`` that matches the dimensions of the interpreted output.
90+
91+ In this example, the network maps the output of the quantum circuit to two classes via a custom
92+ ``interpret`` function:
93+
94+
95+ .. code-block:: python
8096
8197 from qiskit import QuantumCircuit
8298 from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
8399 from qiskit_machine_learning.circuit.library import QNNCircuit
84-
85100 from qiskit_machine_learning.neural_networks import SamplerQNN
86101
87102 num_qubits = 2
88103
104+ # Define a custom interpret function that calculates the parity of the bitstring
89105 def parity(x):
90106 return f"{bin(x)}".count("1") % 2
91107
92- # Using the QNNCircuit:
93- # Create a parameterized 2 qubit circuit composed of the default ZZFeatureMap feature map
94- # and RealAmplitudes ansatz.
108+ # Example 1: Using the QNNCircuit class
109+ # QNNCircuit automatically combines a feature map and an ansatz into a single circuit
95110 qnn_qc = QNNCircuit(num_qubits)
96111
97112 qnn = SamplerQNN(
98- circuit=qnn_qc,
113+ circuit=qnn_qc, # Note that this is a QNNCircuit instance
99114 interpret=parity,
100- output_shape=2
115+ output_shape=2 # Reshape by the number of classical registers
101116 )
102117
118+ # Do a forward pass with input data and custom weights
103119 qnn.forward(input_data=[1, 2], weights=[1, 2, 3, 4, 5, 6, 7, 8])
104120
105- # Explicitly specifying the ansatz and feature map:
121+ # Example 2: Explicitly specifying the feature map and ansatz
122+ # Create a feature map and an ansatz separately
106123 feature_map = ZZFeatureMap(feature_dimension=num_qubits)
107124 ansatz = RealAmplitudes(num_qubits=num_qubits)
108125
126+ # Compose the feature map and ansatz manually (otherwise done within QNNCircuit)
109127 qc = QuantumCircuit(num_qubits)
110128 qc.compose(feature_map, inplace=True)
111129 qc.compose(ansatz, inplace=True)
112130
113131 qnn = SamplerQNN(
114- circuit=qc,
132+ circuit=qc, # Note that this is a QuantumCircuit instance
115133 input_params=feature_map.parameters,
116134 weight_params=ansatz.parameters,
117135 interpret=parity,
118- output_shape=2
136+ output_shape=2 # Reshape by the number of classical registers
119137 )
120138
139+ # Perform a forward pass with input data and weights
121140 qnn.forward(input_data=[1, 2], weights=[1, 2, 3, 4, 5, 6, 7, 8])
122141
142+
123143 The following attributes can be set via the constructor but can also be read and
124144 updated once the SamplerQNN object has been constructed.
125145
126146 Attributes:
127147
128- sampler (BaseSampler): The sampler primitive used to compute the neural network's results.
129- gradient (BaseSamplerGradient): A sampler gradient to be used for the backward pass.
148+ sampler (BaseSampler): The sampler primitive used to compute the neural network's
149+ results. If not provided, a default instance of the reference sampler defined by
150+ :class:`~qiskit.primitives.Sampler` will be used.
151+ gradient (BaseSamplerGradient): An optional sampler gradient used for the backward
152+ pass. If not provided, a default instance of
153+ :class:`~qiskit_machine_learning.gradients.ParamShiftSamplerGradient` will be used.
130154 """
131155
132156 def __init__ (
@@ -173,8 +197,8 @@ def __init__(
173197 sparse: Returns whether the output is sparse or not.
174198 interpret: A callable that maps the measured integer to another unsigned integer or tuple
175199 of unsigned integers. These are used as new indices for the (potentially sparse)
176- output array. If no interpret function is passed , then an identity function will be
177- used by this neural network.
200+ output array. If the interpret function is ``None`` , then an identity function will be
201+ used by this neural network: ``lambda x: x`` (default) .
178202 output_shape: The output shape of the custom interpretation. For SamplerV1, it is ignored
179203 if no custom interpret method is provided where the shape is taken to be
180204 ``2^circuit.num_qubits``.
@@ -190,7 +214,7 @@ def __init__(
190214 Raises:
191215 QiskitMachineLearningError: Invalid parameter values.
192216 """
193- # set primitive, provide default
217+ # Set primitive, provide default
194218 if sampler is None :
195219 sampler = Sampler ()
196220
@@ -226,8 +250,10 @@ def __init__(
226250 if sparse :
227251 _optionals .HAS_SPARSE .require_now ("DOK" )
228252
253+ self ._interpret = interpret
229254 self .set_interpret (interpret , output_shape )
230- # set gradient
255+
256+ # Set gradient
231257 if gradient is None :
232258 if isinstance (sampler , BaseSamplerV1 ):
233259 gradient = ParamShiftSamplerGradient (sampler = self .sampler )
@@ -283,7 +309,7 @@ def set_interpret(
283309 interpret : Callable [[int ], int | tuple [int , ...]] | None = None ,
284310 output_shape : int | tuple [int , ...] | None = None ,
285311 ) -> None :
286- """Change ' interpret' and corresponding ' output_shape' .
312+ """Change `` interpret`` and corresponding `` output_shape`` .
287313
288314 Args:
289315 interpret: A callable that maps the measured integer to another unsigned integer or
@@ -308,13 +334,13 @@ def _compute_output_shape(
308334 QiskitMachineLearningError: If an invalid ``sampler``provided.
309335 """
310336
311- # this definition is required by mypy
337+ # This definition is required by mypy
312338 output_shape_ : tuple [int , ...] = (- 1 ,)
313339
314340 if interpret is not None :
315341 if output_shape is None :
316342 raise QiskitMachineLearningError (
317- "No output shape given; it's required when using custom interpret! "
343+ "No output shape given, but it's required when using custom interpret function. "
318344 )
319345 if isinstance (output_shape , Integral ):
320346 output_shape = int (output_shape )
@@ -354,6 +380,7 @@ def _postprocess(self, num_samples: int, result: SamplerResult) -> np.ndarray |
354380 else :
355381 # Fallback to 'c' if 'meas' is not available.
356382 bitstring_counts = result [i ].data .c .get_counts ()
383+
357384 # Normalize the counts to probabilities
358385 total_shots = sum (bitstring_counts .values ())
359386 probabilities = {k : v / total_shots for k , v in bitstring_counts .items ()}
0 commit comments