@@ -275,57 +275,6 @@ def attribute_future(self) -> None:
275275 """
276276 raise NotImplementedError ("attribute_future is not implemented for Occlusion" )
277277
278- def _construct_ablated_input (
279- self ,
280- expanded_input : Tensor ,
281- input_mask : Union [None , Tensor , Tuple [Tensor , ...]],
282- baseline : Union [None , float , Tensor ],
283- start_feature : int ,
284- end_feature : int ,
285- ** kwargs : Any ,
286- ) -> Tuple [Tensor , Tensor ]:
287- r"""
288- Ablates given expanded_input tensor with given feature mask, feature range,
289- and baselines, and any additional arguments.
290- expanded_input shape is (num_features, num_examples, ...)
291- with remaining dimensions corresponding to remaining original tensor
292- dimensions and num_features = end_feature - start_feature.
293-
294- input_mask is None for occlusion, and the mask is constructed
295- using sliding_window_tensors, strides, and shift counts, which are provided in
296- kwargs. baseline is expected to
297- be broadcastable to match expanded_input.
298-
299- This method returns the ablated input tensor, which has the same
300- dimensionality as expanded_input as well as the corresponding mask with
301- either the same dimensionality as expanded_input or second dimension
302- being 1. This mask contains 1s in locations which have been ablated (and
303- thus counted towards ablations for that feature) and 0s otherwise.
304- """
305- input_mask = torch .stack (
306- [
307- self ._occlusion_mask (
308- expanded_input ,
309- j ,
310- kwargs ["sliding_window_tensors" ],
311- kwargs ["strides" ],
312- kwargs ["shift_counts" ],
313- is_expanded_input = True ,
314- )
315- for j in range (start_feature , end_feature )
316- ],
317- dim = 0 ,
318- ).long ()
319- assert baseline is not None , "baseline should not be None"
320- ablated_tensor = (
321- expanded_input
322- * (
323- torch .ones (1 , dtype = torch .long , device = expanded_input .device )
324- - input_mask
325- ).to (expanded_input .dtype )
326- ) + (baseline * input_mask .to (expanded_input .dtype ))
327- return ablated_tensor , input_mask
328-
329278 def _occlusion_mask (
330279 self ,
331280 input : Tensor ,
@@ -380,21 +329,6 @@ def _occlusion_mask(
380329 )
381330 return padded_tensor .reshape ((1 ,) + tuple (padded_tensor .shape ))
382331
383- def _get_feature_range_and_mask (
384- self , input : Tensor , input_mask : Optional [Tensor ], ** kwargs : Any
385- ) -> Tuple [int , int , Union [None , Tensor , Tuple [Tensor , ...]]]:
386- feature_max = int (np .prod (kwargs ["shift_counts" ]))
387- return 0 , feature_max , None
388-
389- def _get_feature_counts (
390- self ,
391- inputs : TensorOrTupleOfTensorsGeneric ,
392- feature_mask : Tuple [Tensor , ...],
393- ** kwargs : Any ,
394- ) -> Tuple [int , ...]:
395- """return the numbers of possible input features"""
396- return tuple (np .prod (counts ).astype (int ) for counts in kwargs ["shift_counts" ])
397-
398332 def _get_feature_idx_to_tensor_idx (
399333 self , formatted_feature_mask : Tuple [Tensor , ...], ** kwargs : Any
400334 ) -> Dict [int , List [int ]]:
0 commit comments