- 
                Notifications
    
You must be signed in to change notification settings  - Fork 547
 
Adding the Latent Shift attribution method #1024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Open
      
      
            ieee8023
  wants to merge
  72
  commits into
  meta-pytorch:master
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
ieee8023:master
  
      
      
   
  
    
  
  
  
 
  
      
    base: master
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
  
     Open
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            72 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      32b38e0
              
                Add Latent Shift
              
              
                ieee8023 0e57fe4
              
                video
              
              
                ieee8023 7c5025c
              
                align text
              
              
                ieee8023 3a89340
              
                cleanup
              
              
                ieee8023 d0c833a
              
                clean up docs
              
              
                ieee8023 4dc25bf
              
                add support for colab version
              
              
                ieee8023 12e78d3
              
                cleanup
              
              
                ieee8023 2cf44ba
              
                add more docs
              
              
                ieee8023 0a74565
              
                Merge branch 'master' into master
              
              
                ieee8023 0ceb34e
              
                cleanup format and add test
              
              
                ieee8023 9043790
              
                more cleanup
              
              
                ieee8023 2963ae5
              
                cleanup and add more docs
              
              
                ieee8023 d4320d2
              
                fix flake8 errors
              
              
                ieee8023 a039159
              
                fixing flake8 for real
              
              
                ieee8023 907d2d7
              
                fix format and add opion to limit printing
              
              
                ieee8023 01bb3b2
              
                fix type error
              
              
                ieee8023 222128e
              
                flake8
              
              
                ieee8023 4c588b9
              
                autopep8
              
              
                ieee8023 c1dd756
              
                make mypy happy
              
              
                ieee8023 77963d8
              
                ufmt format
              
              
                ieee8023 b0a08d7
              
                I really think flake8 will pass now
              
              
                ieee8023 8432ff3
              
                match reference to other references
              
              
                ieee8023 5597155
              
                small change to kick off tests again
              
              
                ieee8023 5407086
              
                Merge branch 'master' into master
              
              
                ieee8023 9b5272e
              
                Merge branch 'master' into master
              
              
                ieee8023 7a19759
              
                Merge branch 'master' into master
              
              
                ieee8023 7d64a75
              
                Merge branch 'master' into master
              
              
                ieee8023 e245cde
              
                Merge branch 'master' into master
              
              
                ieee8023 7245048
              
                Merge branch 'master' into master
              
              
                ieee8023 03fe557
              
                Merge branch 'master' into master
              
              
                ieee8023 298c9e8
              
                add options for extra loops and the cmap value
              
              
                ieee8023 04f16ca
              
                Merge branch 'master' into master
              
              
                ieee8023 3d6f842
              
                fix flake8
              
              
                ieee8023 17bc3af
              
                Add Latent Shift
              
              
                ieee8023 7615dcc
              
                video
              
              
                ieee8023 558b429
              
                align text
              
              
                ieee8023 430888e
              
                cleanup
              
              
                ieee8023 bf434a9
              
                clean up docs
              
              
                ieee8023 96e8b42
              
                add support for colab version
              
              
                ieee8023 34f48f6
              
                cleanup
              
              
                ieee8023 554db30
              
                add more docs
              
              
                ieee8023 cefc673
              
                cleanup format and add test
              
              
                ieee8023 42c2c36
              
                more cleanup
              
              
                ieee8023 77c574b
              
                cleanup and add more docs
              
              
                ieee8023 8aa3fec
              
                fix flake8 errors
              
              
                ieee8023 90ffd8e
              
                fixing flake8 for real
              
              
                ieee8023 67a576c
              
                fix format and add opion to limit printing
              
              
                ieee8023 2a9cab7
              
                fix type error
              
              
                ieee8023 a0f156a
              
                flake8
              
              
                ieee8023 435bee8
              
                autopep8
              
              
                ieee8023 2f618ff
              
                make mypy happy
              
              
                ieee8023 3f9bbdd
              
                ufmt format
              
              
                ieee8023 cec2237
              
                I really think flake8 will pass now
              
              
                ieee8023 b387097
              
                match reference to other references
              
              
                ieee8023 29951a0
              
                small change to kick off tests again
              
              
                ieee8023 653a67a
              
                add options for extra loops and the cmap value
              
              
                ieee8023 2292efa
              
                fix flake8
              
              
                ieee8023 390fee0
              
                Merge branch 'master' of github.com:ieee8023/captum
              
              
                ieee8023 74af5c8
              
                refactor image writing
              
              
                ieee8023 e9196ed
              
                refactor for batches and just returning heatmaps
              
              
                ieee8023 8a24e9b
              
                pep8
              
              
                ieee8023 f873da3
              
                ufmt
              
              
                ieee8023 92e93a7
              
                format errors
              
              
                ieee8023 8f9d8a2
              
                fix typing
              
              
                ieee8023 7779497
              
                reduce string length
              
              
                ieee8023 cd2c2f5
              
                remove usage of torchvision in tests
              
              
                ieee8023 6855d55
              
                Merge branch 'master' into master
              
              
                ieee8023 10a43c0
              
                add sigmoid param
              
              
                ieee8023 1d8e9dd
              
                Merge branch 'master' of github.com:ieee8023/captum
              
              
                ieee8023 c0dfdfb
              
                Merge branch 'master' into master
              
              
                ieee8023 4eb4c6d
              
                Merge branch 'master' into master
              
              
                ieee8023 d63a803
              
                Merge branch 'master' into master
              
              
                ieee8023 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,267 @@ | ||
| #!/usr/bin/env python3 | ||
| 
     | 
||
| from typing import Any, Callable, Dict, List, Tuple, Union | ||
| 
     | 
||
| import numpy as np | ||
| import torch | ||
| from captum.attr._utils.attribution import GradientAttribution | ||
| from captum.log import log_usage | ||
| from torch import Tensor | ||
| 
     | 
||
| 
     | 
||
| class LatentShift(GradientAttribution): | ||
| r"""An implementation of the Latent Shift method to generate | ||
| counterfactual explanations. This method uses an autoencoder to restrict | ||
| the possible adversarial examples to remain in the data space by | ||
| adjusting the latent space of the autoencoder using dy/dz instead of | ||
| dy/dx in order to change the classifier's prediction. | ||
| 
     | 
||
| This class implements a search strategy to determine the lambda needed to | ||
| change the prediction of the classifier by a specific amount as well as | ||
| the code to generate a video and construct a heatmap representing the | ||
| image changes for viewing as an image. | ||
| 
     | 
||
| More details regarding the latent shift method can be found in the | ||
| original paper: | ||
| https://arxiv.org/abs/2102.09475 | ||
| And the original code repository: | ||
| https://github.com/mlmed/gifsplanation | ||
| """ | ||
| 
     | 
||
| def __init__(self, forward_func: Callable, autoencoder) -> None: | ||
| r""" | ||
| Args: | ||
| forward_func (callable): The forward function of the model or | ||
| any modification of it | ||
| autoencoder: An object with an encode and decode function which | ||
| maintains a gradient tape. | ||
| """ | ||
| GradientAttribution.__init__(self, forward_func) | ||
| self.ae = autoencoder | ||
| 
     | 
||
| # check if ae has encode and decode | ||
| assert hasattr(self.ae, "encode") | ||
| assert hasattr(self.ae, "decode") | ||
| 
     | 
||
| @log_usage() | ||
| def attribute( | ||
| self, | ||
| inputs: Tensor, | ||
| target: int, | ||
| fix_range: Union[Tuple, None] = None, | ||
| search_pred_diff: float = 0.8, | ||
| search_step_size: float = 10.0, | ||
| search_max_steps: int = 3000, | ||
| search_max_pixel_diff_pct: float = 0.05, | ||
| lambda_sweep_steps: int = 10, | ||
| heatmap_method: str = "int", | ||
| apply_sigmoid: bool = True, | ||
| verbose: bool = True, | ||
| return_dicts: bool = False, | ||
| ) -> Union[Tensor, List[Dict[str, Any]]]: | ||
| r""" | ||
| This method performs a search in order to determine the correct lambda | ||
| values to generate the shift. The search starts by stepping by | ||
| `search_step_size` in the negative direction while trying to determine | ||
| if the output of the classifier has changed by `search_pred_diff` or | ||
| when the change in the predict in stops going down. In order to avoid | ||
| artifacts if the shift is too large or in the wrong direction an extra | ||
| stop conditions is added `search_max_pixel_diff` if the change in the | ||
| image is too large. To avoid the search from taking too long a | ||
| `search_max_steps` will prevent the search from going on endlessly. | ||
| 
     | 
||
| 
     | 
||
| Args: | ||
| 
     | 
||
| inputs (tensor): Input for which the counterfactual is computed. | ||
| target (int): Output indices for which dydz is computed (for | ||
| classification cases, this is usually the target class). | ||
| fix_range (tuple): Overrides searching and directly specifies the | ||
| lambda range to use. e.g. [-100,0]. | ||
| search_pred_diff (float): The desired change in the classifiers | ||
| prediction. For example if the classifer predicts 0.9 | ||
| and pred_diff=0.8 the search will try to generate a | ||
| counterfactual where the prediction is 0.1. | ||
| search_step_size (float): When searching for the right lambda to use | ||
| this will be the initial step size. This is similar to | ||
| a learning rate. Smaller values avoid jumping over the | ||
| ideal lambda but the search may take a long time. | ||
| search_max_steps (int): The max steps to take when doing the search. | ||
| Sometimes steps make a tiny improvement and can go on | ||
| forever. This just bounds the time and gives up the | ||
| search. | ||
| search_max_pixel_diff_pct (float): When searching, stop if the pixel | ||
| difference is larger than this amount. This will | ||
| prevent large artifacts being introduced into the | ||
| image. |img0 - imgx| > |img0|*pct | ||
| lambda_sweep_steps (int): How many frames to generate for the video. | ||
| heatmap_method: Default: 'int'. Possible methods: 'int': Average | ||
| per frame differences. 'mean' : Average difference | ||
| between 0 and other lambda frames. 'mm': Difference | ||
| between first and last frames. 'max': Max difference | ||
| from lambda 0 frame | ||
| apply_sigmoid: Default: True. Apply a sigmoid to the output of the | ||
| model. Set to false to work with regression models or | ||
| if the model already applies a sigmoid. | ||
| verbose: True to print debug text | ||
| return_dicts (bool): Return a list of dicts containing information | ||
| from each image processed. Default False | ||
| 
     | 
||
| Returns: | ||
| attributions or (if return_dict=True) a list of dicts containing the | ||
| follow keys: | ||
| generated_images: A list of images generated at each step along | ||
| the dydz vector from the smallest lambda to the largest. By | ||
| default the smallest lambda represents the counterfactual | ||
| image and the largest lambda is 0 (representing no change). | ||
| lambdas: A list of the lambda values for each generated image. | ||
| preds: A list of the predictions of the model for each generated | ||
| image. | ||
| heatmap: A heatmap indicating the pixels which change in the | ||
| video sequence of images. | ||
| 
     | 
||
| 
     | 
||
| Example:: | ||
| 
     | 
||
| >>> # Load classifier and autoencoder | ||
| >>> model = classifiers.FaceAttribute() | ||
| >>> ae = autoencoders.VQGAN(weights="faceshq") | ||
| >>> | ||
| >>> # Load image | ||
| >>> x = torch.randn(1, 3, 1024, 1024) | ||
| >>> | ||
| >>> # Defining Latent Shift module | ||
| >>> attr = captum.attr.LatentShift(model, ae) | ||
| >>> | ||
| >>> # Computes counterfactual for class 3. | ||
| >>> output = attr.attribute(x, target=3) | ||
| 
     | 
||
| """ | ||
| 
     | 
||
| assert lambda_sweep_steps > 1, "lambda_sweep_steps must be at least 2" | ||
| 
     | 
||
| results = [] | ||
| # cheap batching | ||
| for idx in range(inputs.shape[0]): | ||
| inp = inputs[idx].unsqueeze(0) | ||
| z = self.ae.encode(inp).detach() | ||
| z.requires_grad = True | ||
| x_lambda0 = self.ae.decode(z) | ||
| pred = self.forward_func(x_lambda0)[:, target] | ||
| if apply_sigmoid: | ||
| pred = torch.sigmoid(pred) | ||
| dzdxp = torch.autograd.grad(pred, z)[0] | ||
| 
     | 
||
| # Cache so we can reuse at sweep stage | ||
| cache = {} | ||
| 
     | 
||
| def compute_shift(lambdax): | ||
| """Compute the shift for a specific lambda""" | ||
| if lambdax not in cache: | ||
| x_lambdax = self.ae.decode(z + dzdxp * lambdax).detach() | ||
| pred1 = self.forward_func(x_lambdax)[:, target] | ||
| if apply_sigmoid: | ||
| pred1 = torch.sigmoid(pred1) | ||
| pred1 = pred1.detach().cpu().numpy() | ||
| cache[lambdax] = x_lambdax, pred1 | ||
| return cache[lambdax] | ||
| 
     | 
||
| _, initial_pred = compute_shift(0) | ||
| 
     | 
||
| if fix_range: | ||
| lbound, rbound = fix_range | ||
| else: | ||
| # Left range | ||
| lbound = 0 | ||
| last_pred = initial_pred | ||
| pixel_sum = x_lambda0.abs().sum() # Used for pixel diff | ||
| while True: | ||
| x_lambdax, cur_pred = compute_shift(lbound) | ||
| pixel_diff = torch.abs(x_lambda0 - x_lambdax).sum().detach().cpu() | ||
| if verbose: | ||
| toprint = [ | ||
| f"Shift: {lbound}", | ||
| f"Pred: {float(cur_pred)}", | ||
| f"pixel_diff: {float(pixel_diff)}", | ||
| f"sum*diff_pct: {pixel_sum * search_max_pixel_diff_pct}", | ||
| ] | ||
| print(", ".join(toprint)) | ||
| 
     | 
||
| # If we stop decreasing the prediction | ||
| if last_pred < cur_pred: | ||
| break | ||
| # If the prediction becomes very low | ||
| if cur_pred < 0.05: | ||
| break | ||
| # If we have decreased the prediction by pred_diff | ||
| if initial_pred - search_pred_diff > cur_pred: | ||
| break | ||
| # If we are moving in the latent space too much | ||
| if lbound <= -search_max_steps: | ||
| break | ||
| # If we move too far we will distort the image | ||
| if pixel_diff > (pixel_sum * search_max_pixel_diff_pct): | ||
| break | ||
| 
     | 
||
| last_pred = cur_pred | ||
| lbound = lbound - search_step_size + lbound // 10 | ||
| 
     | 
||
| # Right range search not implemented | ||
| rbound = 0 | ||
| 
     | 
||
| if verbose: | ||
| print("Selected bounds: ", lbound, rbound) | ||
| 
     | 
||
| # Sweep over the range of lambda values to create a sequence | ||
| lambdas = np.linspace(lbound, rbound, lambda_sweep_steps) | ||
| assert lambda_sweep_steps == len( | ||
| lambdas | ||
| ), "Inconsistent number of lambda steps" | ||
| 
     | 
||
| if verbose: | ||
| print("Lambdas to compute: ", lambdas) | ||
| 
     | 
||
| preds = [] | ||
| generated_images = [] | ||
| 
     | 
||
| for lam in lambdas: | ||
| x_lambdax, pred = compute_shift(lam) | ||
| generated_images.append(x_lambdax.cpu().numpy()[0]) | ||
| preds.append(float(pred)) | ||
| 
     | 
||
| params = {} | ||
| params["generated_images"] = np.array(generated_images) | ||
| params["lambdas"] = lambdas | ||
| params["preds"] = preds | ||
| 
     | 
||
| x_lambda0 = x_lambda0.detach().cpu().numpy() | ||
| if heatmap_method == "max": | ||
| # Max difference from lambda 0 frame | ||
| heatmap = np.max(np.abs(x_lambda0 - generated_images), 0) | ||
| 
     | 
||
| elif heatmap_method == "mean": | ||
| # Average difference between 0 and other lambda frames | ||
| heatmap = np.mean(np.abs(x_lambda0 - generated_images), 0) | ||
| 
     | 
||
| elif heatmap_method == "mm": | ||
| # Difference between first and last frames | ||
| heatmap = np.abs(generated_images[0] - generated_images[-1]) | ||
| 
     | 
||
| elif heatmap_method == "int": | ||
| # Average per frame differences | ||
| image_changes = [] | ||
| for i in range(len(generated_images) - 1): | ||
| image_changes.append( | ||
| np.abs(generated_images[i] - generated_images[i + 1]) | ||
| ) | ||
| heatmap = np.mean(image_changes, 0) | ||
| else: | ||
| raise Exception("Unknown heatmap_method for 2d image") | ||
| 
     | 
||
| params["heatmap"] = heatmap | ||
| results.append(params) | ||
| 
     | 
||
| if return_dicts: | ||
| return results | ||
| else: | ||
| return torch.tensor([result["heatmap"] for result in results]) | ||
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.