@@ -118,18 +118,45 @@ def __init__(
         self.output_transform = output_transform
 
     def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
-        blocks = []
-        shape = self.raw_output_shape
-        ndim = len(shape)
+        # Collect the (possibly broadcast) feature tensors in ``blocks``.
+        # ``self.raw_output_shape`` encodes the *final* shape of the concatenated
+        # feature map. For each individual ``feature_map`` we therefore need to
+        # (1) obtain its dense representation, (2) broadcast it so that its
+        # trailing dimensions match ``self.raw_output_shape``, and (3) rescale it
+        # if we replicate the same tensor multiple times (to keep ‖ϕ‖ roughly
+        # invariant).
+
+        blocks: list[Tensor] = []
+
+        shape = self.raw_output_shape  # target output shape
+        ndim = len(shape)  # number of feature dimensions, incl. batch
+
         for feature_map in self:
+            # 1. Evaluate (dense) features for the current sub-map.
             block = feature_map(x, **kwargs).to_dense()
+
             block_ndim = len(feature_map.output_shape)
+
+            # 2. If this map has fewer *feature* dimensions than the direct sum
+            #    (e.g. a vector-valued sub-map in a matrix-valued direct sum), we
+            #    have to *tile* it along the missing leading feature dimensions so
+            #    that shapes line up for concatenation.
             if block_ndim < ndim:
+                # ``tile_shape`` tells us how many copies we need along every
+                # missing feature dimension (can be > 1, e.g. for Kronecker sums).
                 tile_shape = shape[-ndim:-block_ndim]
+
+                # Rescale by 1/√k when we replicate the same block *k* times to
+                # avoid artificially inflating its norm (motivated by the fact
+                # that direct sums of orthogonal features preserve inner products
+                # only up to such a scaling).
                 num_copies = prod(tile_shape)
                 if num_copies > 1:
-                    block = block * (num_copies ** -0.5)
+                    block = block * (num_copies ** -0.5)
 
+                # ``multi_index`` inserts ``None`` entries (i.e. new singleton axes
+                # in slicing syntax) so that broadcasting expands the tensor along
+                # the new axes without additional memory allocations.
                 multi_index = (
                     ...,
                     *repeat(None, ndim - block_ndim),
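The `num_copies ** -0.5` rescaling above can be sanity-checked in isolation. A minimal standalone sketch (not part of this patch; the feature vector and copy count are made up) showing that naive tiling inflates the norm by √k while the rescaled copies keep it unchanged:

```python
import torch

phi = torch.randn(16)                     # one feature block, E[||phi||] ≈ 4
k = 3                                     # number of replicated copies

tiled = torch.cat([phi] * k)              # naive tiling: norm grows by sqrt(k)
scaled = torch.cat([phi * k**-0.5] * k)   # rescaled tiling: norm is preserved

print(torch.linalg.norm(phi))             # e.g. ~4.0
print(torch.linalg.norm(tiled))           # ~sqrt(3) * 4.0
print(torch.linalg.norm(scaled))          # ~4.0 again
```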
@@ -138,12 +165,17 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
                 block = block[multi_index].expand(
                     *block.shape[:-block_ndim], *tile_shape, *block.shape[-block_ndim:]
                 )
+
+            # 3. Append the (now correctly shaped) block to be concatenated later.
             blocks.append(block)
 
+        # Concatenate along the *last* axis (the feature dimension).
         return torch.concat(blocks, dim=-1)
 
     @property
     def raw_output_shape(self) -> Size:
+        # If the container is empty (e.g. ``DirectSumFeatureMap([])``), treat the
+        # output as 0-dimensional until feature maps are added.
         if not self:
             return Size([])
 
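The `[..., None, :]`-style indexing plus `.expand(...)` in the hunk above only creates broadcast views, so the tiled copies share storage with the original block. A toy illustration with assumed shapes (a batch of 4, a 16-dimensional sub-map, one missing feature dimension tiled three times), not taken from the library itself:

```python
import torch

block = torch.randn(4, 16)                    # (batch, features) from one sub-map
tile_shape = (3,)                             # one missing leading feature dim

view = block[..., None, :]                    # (4, 1, 16): None inserts a singleton axis
expanded = view.expand(4, *tile_shape, 16)    # (4, 3, 16): still a view, no data copied

assert expanded.data_ptr() == block.data_ptr()  # same underlying storage
```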
@@ -204,13 +236,23 @@ def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
             block = feature_map(x, **kwargs)
             block_ndim = len(feature_map.output_shape)
             if block_ndim == ndim:
+                # Case 1: this sub-map already has the *max* feature rank we're
+                # going to emit. We simply make sure it is stored sparsely:
+                #   – Convert ``LinearOperator`` → dense so that ``.to_sparse()``
+                #     is available.
+                #   – If it is still dense, call ``.to_sparse()``; otherwise keep
+                #     the sparse representation it already has.
                 block = block.to_dense() if isinstance(block, LinearOperator) else block
                 block = block if block.is_sparse else block.to_sparse()
             else:
+                # Case 2: a lower-rank feature map. Bring it up to ``ndim`` by
+                # slicing with ``None`` (adds singleton axes) so that broadcasting
+                # can later expand it. We stay dense here because we'll stuff the
+                # result into a block-diagonal sparse matrix at the very end.
                 multi_index = (
                     ...,
-                    *repeat(None, ndim - block_ndim),
-                    *repeat(slice(None), block_ndim),
+                    *repeat(None, ndim - block_ndim),  # adds missing dims
+                    *repeat(slice(None), block_ndim),  # keeps existing dims
                 )
                 block = block.to_dense()[multi_index]
             blocks.append(block)
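For reference, a rough sketch of the case-1 conversion path above. It assumes the `LinearOperator` referenced in the patch comes from the `linear_operator` package and uses a `DiagLinearOperator` as a stand-in for whatever operator type a sub-map actually returns:

```python
import torch
from linear_operator.operators import DiagLinearOperator, LinearOperator

block = DiagLinearOperator(torch.randn(5))     # operator-valued sub-map output
block = block.to_dense() if isinstance(block, LinearOperator) else block
block = block if block.is_sparse else block.to_sparse()   # now a sparse COO tensor

print(block.layout)                            # torch.sparse_coo
```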