diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index e296f15d..e1c5ab81 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -252,6 +252,7 @@ def pcxmi_fn( attribution_model=attribution_model, contrast_sources=contrast_sources, contrast_target_prefixes=contrast_target_prefixes, + target_ids=target_ids, **kwargs, ) return -torch.log2(torch.div(original_probs, contrast_probs)) @@ -260,6 +261,7 @@ def pcxmi_fn( def kl_divergence_fn( attribution_model: "AttributionModel", forward_output: ModelOutput, + target_ids: TargetIdsTensor, contrast_sources: Optional[FeatureAttributionInput] = None, contrast_target_prefixes: Optional[FeatureAttributionInput] = None, top_k: int = 0,