@@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None:
220220 perturbations_per_eval = (1 , 2 , 3 ),
221221 )
222222
223+ def test_multi_input_ablation_with_int_input_tensor_and_float_baseline (
224+ self ,
225+ ) -> None :
226+ def sum_forward (* inps : torch .Tensor ) -> torch .Tensor :
227+ flattened = [torch .flatten (inp , start_dim = 1 ) for inp in inps ]
228+ return torch .cat (flattened , dim = 1 ).sum (1 )
229+
230+ ablation_algo = FeatureAblation (sum_forward )
231+ inp1 = torch .tensor ([[0 , 1 ], [3 , 4 ]])
232+ inp2 = torch .tensor (
233+ [
234+ [[0.1 , 0.2 ], [0.3 , 0.2 ]],
235+ [[0.4 , 0.5 ], [0.3 , 0.2 ]],
236+ ]
237+ )
238+ inp3 = torch .tensor ([[0 ], [1 ]])
239+
240+ expected = (
241+ torch .tensor ([[- 0.2 , 0.8 ], [2.8 , 3.8 ]]),
242+ torch .tensor (
243+ [
244+ [[- 3.0 , - 2.9 ], [- 2.8 , - 2.9 ]],
245+ [[- 2.7 , - 2.6 ], [- 2.8 , - 2.9 ]],
246+ ]
247+ ),
248+ torch .tensor ([[- 0.4 ], [0.6 ]]),
249+ )
250+ self ._ablation_test_assert (
251+ ablation_algo ,
252+ (inp1 , inp2 , inp3 ),
253+ expected ,
254+ target = None ,
255+ baselines = (0.2 , 3.1 , 0.4 ),
256+ test_enable_cross_tensor_attribution = [False , True ],
257+ )
258+
223259 def test_multi_input_ablation_with_mask_weighted (self ) -> None :
224260 ablation_algo = FeatureAblation (BasicModel_MultiLayer_MultiInput ())
225261 ablation_algo .use_weights = True
0 commit comments