
Counterfactual Explanation and Instance-Generation using Cycle-Consistent Generative Adversarial Networks

This repository contains the code for our paper, Counterfactual Explanation and Instance-Generation using Cycle-Consistent Generative Adversarial Networks, available at https://arxiv.org/abs/2301.08939. In the paper, we present two separate methods for counterfactual explanation (CX). A counterfactual explanation explicates a causal reasoning process of the form: "if X had not happened, then Y would not have happened". However, existing CX approaches [CAM, Grad-CAM] are deficient at supplementing counterfactual explanations with plausible counterfactual instances (CIs). To address this issue, we develop a novel CX/CI generation method in which we view CI generation as unpaired image-to-image translation and CX as an image-to-image conversion map. The method is built on generative adversarial networks (GANs) with a cycle-consistency loss. We first develop a Cascaded Model to learn CX and CI generation individually, and then an Integrated End-to-End Model for joint learning of both CX and CI. We evaluate the proposed method on three different datasets: Synthetic, Tuberculosis and BraTS. All experiments confirm the efficacy of the proposed method in generating accurate CX and CI.

Proposed Models and Results

Cascaded Model:

In our Cascaded Model we aim to achieve two cascaded objectives:

  • Learning to generate CIs
  • Learning to produce a CX w.r.t. the generated CI.

We view CI generation as unpaired image-to-image translation and CX as an image-to-image conversion mapping. We represent the input domain as X, consisting of N images, and the counterfactual domain as Y, comprising M images. For CI generation, we aim to learn a mapping function G such that the distribution of generated images G(X) remains close to the input images X while becoming indistinguishable from the distribution of images in Y. To impose this constraint, we pose CI generation as an unpaired image-to-image translation problem and adopt CycleGAN: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks to learn the model. The trained model is then fed an input image xi to generate its CI yi. As a result, we obtain input-counterfactual image pairs (xi, yi) for the subsequent CX step. Following Visual Feature Attribution using Wasserstein GANs, we define CX as a map M(x) that, when added to the input image xi, produces the counterfactual image yi via:

yi = xi + M(xi)
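As an illustration of this formulation, the following sketch (our own, not the repository's code; it assumes Keras and a simple L1 reconstruction loss) trains a small convolutional network M on the input-counterfactual pairs (xi, yi) from Step-1 so that xi + M(xi) approximates yi; the paper's actual CX network follows VA-GAN and uses more elaborate losses.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_map_network(input_shape=(128, 128, 1)):
    """Small convolutional network predicting the additive map M(x)."""
    x_in = layers.Input(shape=input_shape)
    h = layers.Conv2D(32, 3, padding="same", activation="relu")(x_in)
    h = layers.Conv2D(32, 3, padding="same", activation="relu")(h)
    m_out = layers.Conv2D(input_shape[-1], 3, padding="same", activation="tanh")(h)
    return tf.keras.Model(x_in, m_out, name="M")

M = build_map_network()
x_in = layers.Input(shape=(128, 128, 1))
y_hat = layers.Add()([x_in, M(x_in)])          # y_i = x_i + M(x_i)
trainer = tf.keras.Model(x_in, y_hat)
trainer.compile(optimizer="adam", loss="mae")  # fit on (x_i, y_i) pairs from Step-1
```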

Cascaded Model Architecture:

Cascaded Model Implementation and Results:

Synthetic Data

All experiments and evaluations for the synthetic case were carried out on a synthetically generated dataset. The script to generate the synthetic data can be found at Synthetic Data Generate Script, originally inspired by Visual Feature Attribution using Wasserstein GANs. We use CycleGAN: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks as our Base-Network/Step-1 to translate the anomalous distribution into the normal distribution. Step-1 is used for generating pairs, i.e., for each anomalous image we obtain a corresponding normal image. The results of Step-1 are shown below.
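For reference, the CycleGAN objective behind Step-1 combines adversarial and cycle-consistency terms. The sketch below is our own illustration (assuming generators G: X→Y and F: Y→X and discriminators D_X, D_Y as Keras models), showing the generator-side loss rather than the repository's exact training code.

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
mae = tf.keras.losses.MeanAbsoluteError()
LAMBDA_CYC = 10.0  # weight of the cycle-consistency term, as in CycleGAN

def generator_loss(x, y, G, F, D_X, D_Y):
    """Adversarial + cycle-consistency loss for one batch of unpaired x, y."""
    fake_y = G(x, training=True)        # anomalous -> normal
    fake_x = F(y, training=True)        # normal -> anomalous
    cyc_x = F(fake_y, training=True)    # x -> G(x) -> F(G(x)), should recover x
    cyc_y = G(fake_x, training=True)    # y -> F(y) -> G(F(y)), should recover y

    adv_g = bce(tf.ones_like(D_Y(fake_y)), D_Y(fake_y))  # G tries to fool D_Y
    adv_f = bce(tf.ones_like(D_X(fake_x)), D_X(fake_x))  # F tries to fool D_X
    cyc = mae(x, cyc_x) + mae(y, cyc_y)                   # cycle-consistency
    return adv_g + adv_f + LAMBDA_CYC * cyc
```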

Finally, we introduce Step-2 to detect the infectious region, based on the notion of a mask M(x) that, when added to x, changes it to y, as stated in Visual Feature Attribution using Wasserstein GANs. In addition to masks, we also generate heatmaps for better visualization.
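A minimal sketch of this visualization step, assuming the mask is recovered as the difference between an input image and its generated counterfactual, with matplotlib used for the heatmap overlay (the repository's plotting code may differ):

```python
import numpy as np
import matplotlib.pyplot as plt

def show_mask_and_heatmap(x, y):
    """x, y: 2-D arrays holding an input image and its counterfactual."""
    mask = y - x                       # additive map M(x) such that y = x + M(x)
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    axes[0].imshow(x, cmap="gray")
    axes[0].set_title("input x")
    axes[1].imshow(mask, cmap="gray")
    axes[1].set_title("mask M(x)")
    axes[2].imshow(x, cmap="gray")
    axes[2].imshow(np.abs(mask), cmap="jet", alpha=0.5)   # heatmap overlay
    axes[2].set_title("heatmap")
    for ax in axes:
        ax.axis("off")
    plt.tight_layout()
    plt.show()
```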

BRATS Data

All experiments and evaluations for the BRATS case were carried out on the BRATS 2017 dataset. We use CycleGAN: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks as our Base-Network/Step-1 to translate the anomalous distribution into the normal distribution. Step-1 is used for generating pairs, i.e., for each anomalous image we obtain a corresponding normal image. The results of Step-1 are shown below.

Finally, we introduce Step-2 to detect the infectious region, based on the notion of a mask M(x) that, when added to x, changes it to y, as stated in Visual Feature Attribution using Wasserstein GANs. In addition to masks, we also generate heatmaps for better visualization.

Integrated End-to-End Model:

A disadvantage of the Cascaded Model is that separate networks are trained for CX and CI generation, and the performance of the CX network relies on the efficacy of the CI generation network. This section presents a method for joint learning of both CX and CI through an integrated model. We build on CycleGAN: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by empowering it to describe transformations while generating pairs across domains. In contrast to the standard CycleGAN, where G and F directly map samples from one domain to another, we want G and F to learn the changes that must be made to images of one domain in order to produce images of the other domain.
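To make the contrast concrete, the sketch below (our own illustration, not the paper's exact architecture) shows generators that output a change which is added back to the input, so each translation simultaneously yields the CX map and the CI:

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_generator(input_shape=(256, 256, 1), name="G"):
    """Generator that learns the change; the translated image is x + change."""
    x_in = layers.Input(shape=input_shape)
    h = layers.Conv2D(64, 7, padding="same", activation="relu")(x_in)
    h = layers.Conv2D(64, 3, padding="same", activation="relu")(h)
    change = layers.Conv2D(input_shape[-1], 7, padding="same",
                           activation="tanh", name=f"{name}_change")(h)
    y_out = layers.Add(name=f"{name}_translated")([x_in, change])
    return tf.keras.Model(x_in, [change, y_out], name=name)

G = residual_generator(name="G")  # X -> Y: returns (CX map, CI)
F = residual_generator(name="F")  # Y -> X
```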

Integrated Model Architecture:

Integrated Model Implementation and Results:

This code can be run directly on Colab. You just need to set the dataset and model-history paths.
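For example, a hypothetical path setup on Colab might look like the following (the actual variable names and directory layout used in the notebooks may differ):

```python
from google.colab import drive

drive.mount("/content/drive")  # make Google Drive visible to the notebook

# Assumed directory layout; adjust to wherever the data and outputs live.
dataset_path = "/content/drive/MyDrive/CX_GAN/datasets/BRATS2017"
model_history_path = "/content/drive/MyDrive/CX_GAN/model_history"
```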

Training History for Integrated CX-GAN on BRATS 2017 dataset

Training History for Integrated CX-GAN on Synthetic Dataset

Paper Results:

We perform experiments on a synthetic dataset and two publicly available medical imaging datasets: BraTS and tuberculosis (i.e., Shenzhen, Montgomery County and Korean Institute of Tuberculosis). We evaluate our proposed method against comparable visual explanation methods including CAM, Grad-CAM and VA-GAN, where CAM and Grad-CAM use classification networks, while VA-GAN and the proposed CX-GAN employ image translation networks.

Synthetic Data:

In addition to the real medical imaging datasets, we evaluate both the proposed and related methods on a synthetically generated dataset of 10,000 128×128 images divided into two classes. One half of the dataset represents a healthy control group (label 0) and the other half represents a patient group (label 1). The dataset is generated in close adherence to the data generation process set out in VA-GAN.
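As a rough illustration only, the toy snippet below mimics the two-class idea (a noise background with a localized additive effect for the patient class); the actual VA-GAN generation process followed by the repository's script is more elaborate.

```python
import numpy as np

def make_toy_synthetic(n=10_000, size=128, seed=0):
    """Toy approximation: label 0 = control, label 1 = patient with an added effect."""
    rng = np.random.default_rng(seed)
    images = rng.normal(0.0, 1.0, size=(n, size, size)).astype(np.float32)
    labels = rng.integers(0, 2, size=n)
    for i in np.where(labels == 1)[0]:
        r, c = rng.integers(16, size - 32, size=2)   # random effect location
        images[i, r:r + 16, c:c + 16] += 2.0         # localized "disease" effect
    return images, labels
```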

Tuberculosis chest X-ray Data:

This dataset contains de-identified chest X-rays (CXRs) from three different public resources: (1) the National Institutes of Health (NIH) Tuberculosis Chest X-ray database, (2) the Belarus Tuberculosis database, and (3) the Korean Institute of Tuberculosis (KIT) under the Korean National Tuberculosis Association, South Korea. The NIH data is further categorized into two separate datasets: (a) Montgomery County (MC) and (b) Shenzhen, which contain 138 and 662 patients respectively, with and without TB. The MC dataset consists of 138 CXRs, including 80 normal (i.e., without TB) and 58 anomalous (i.e., with TB) CXRs. The Shenzhen dataset comprises 662 CXRs, of which 326 are normal and 336 anomalous. The Belarus dataset has a total of 304 CXRs, all anomalous. The KIT dataset contains 10,848 DICOM images, with 7,020 normal and 3,828 anomalous CXRs. Following the experimental setup of [4], the experimental evaluation is performed on the Shenzhen and MC datasets, for which pixel-level labels were acquired from the authors of [4].

The input data is preprocessed with the following steps: (1) the borders of each CXR are cropped to remove noisy margins, and (2) each CXR is resized from 4K×4K pixels to 527×527 pixels and then randomly cropped within a 15-pixel offset to retain the shape of lesions in abnormal regions. No additional augmentations (beyond horizontal mirroring and flipping) are adopted, so as not to deform lesions. In the final step, each data sample is normalized with z-score normalization. We split the overall dataset 80:20 into training and validation/test sets.
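A minimal sketch of this pipeline (the border margin is an assumption for illustration; the repository's preprocessing script is the reference):

```python
import numpy as np
import tensorflow as tf

def preprocess_cxr(image, border=100, target=527, jitter=15, rng=None):
    """image: 2-D array of one CXR; returns a z-score-normalized random crop."""
    rng = rng or np.random.default_rng()
    img = image[border:-border, border:-border]            # (1) crop noisy borders
    img = tf.image.resize(img[..., None], (target, target)).numpy()[..., 0]  # (2) resize
    dr, dc = rng.integers(0, jitter + 1, size=2)           # (2) random crop offset
    img = img[dr:dr + target - jitter, dc:dc + target - jitter]
    return (img - img.mean()) / (img.std() + 1e-8)         # z-score normalization

def split_80_20(items, seed=0):
    """80:20 split into training and validation/test sets."""
    items = list(items)
    np.random.default_rng(seed).shuffle(items)
    cut = int(0.8 * len(items))
    return items[:cut], items[cut:]
```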

BraTS Data:

The dataset contains brain MRIs classified into normal and tumorous classes. We preprocess the data to retain only MRI slices that contain the full brain. The resulting dataset contains 3174 images, of which 2711 are tumorous and 463 non-tumorous. We split the data 80:20 into train/test sets, giving 2538 training images and 636 testing images. The filtered slices are resized to 256×256 and the data is normalized to the [0, 1] range. We further increase the effective data size by performing run-time augmentation on the training set through random jittering and mirroring. For augmentation, the images are scaled to 286×286 and then randomly cropped back to 256×256.
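A minimal sketch of this run-time augmentation, using tf.image ops on slices already resized to 256×256 and normalized to [0, 1]:

```python
import tensorflow as tf

def random_jitter(image, channels=1):
    """image: float32 tensor of shape (256, 256, channels) with values in [0, 1]."""
    image = tf.image.resize(image, (286, 286))                      # scale up
    image = tf.image.random_crop(image, size=(256, 256, channels))  # crop back
    image = tf.image.random_flip_left_right(image)                  # mirroring
    return image
```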
