
Transfer-Learning-across-domain-boundaries

Why is my medical AI looking at pictures of birds to learn about cancer?

Mostly because it works.

The aim of the attached paper (https://arxiv.org/abs/2306.17555) is to evaluate whether transfer learning suffers from crossing soft (e.g. CT scans to X-rays) or hard (e.g. natural images to medical images) domain boundaries when going from pretraining to finetuning. A naive expectation would be that the more boundaries you cross, and the harder they are, the less you gain from having pretrained. The reality is quite nuanced, and this code is meant to help systematically chart where transfer learning remains useful, and whether it really makes sense to first train a neural network to tell apart birds and planes before you make it tell apart tissue and cancer, and, if so, under what conditions.

In practice, this package was used to pretrain ResNet-50s using SimCLR and to finetune either these ResNet-50s or their corresponding U-Nets. In principle, these constraints are entirely artificial, and any other network could be used instead. However, anything that is not a ResNet will require a bit of fiddling because of the way the final layers of the networks are replaced. Additionally, the U-Net constructor is only equipped to handle ResNets and a couple of other architectures at this time.
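To illustrate the kind of fiddling involved, here is a minimal sketch (not this repository's actual code) of swapping out the final classification layer of a torchvision ResNet-50; the number of classes is a made-up placeholder:

```python
import torch.nn as nn
from torchvision.models import resnet50

# Minimal sketch: replace the final fully connected layer of a ResNet-50
# with a task-specific head. num_classes is a hypothetical placeholder.
model = resnet50(weights=None)
num_classes = 4
model.fc = nn.Linear(model.fc.in_features, num_classes)
```

Other architectures expose their final layers under different attribute names (or as multiple modules), which is why anything that is not a ResNet needs manual adjustment in the corresponding model-construction code.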

At this time, it is not precisely clear where the pretrained/finetuned models should be stored. This, and a few other things on the ToDo-list, will be taken care of soon!

How do I use this?

  1. First, clone this repository to your machine. Any location is fine.
  2. Second, download any dataset you want to work with. These should normally live in the datasets folder. For the public datasets we used, I have appended instructions on where to download them and how to hammer them into the dataset structure that I ended up using. Since the splits I made were (stratified) random splits, my code won't be of much help there, but you can order the data correctly yourself, wherever you get it.
  3. Runs are performed as follows: For pretraining, run `python pkgs/SimCLR/main.py -c pkgs/SimCLR/config/your_pretraining_config_file.yaml`. For finetuning, run `python pkgs/FTE/finetune.py -c pkgs/FTE/config/your_finetuning_config_file.yaml`. The config file options should be mostly self-explanatory, and I have added comments to most of them.

Note that while most of my config files have a w_m option specified, it is always null and no longer connects to anything under the hood.

The package also comes with a QueueManager.py file, which can be run and will execute any jobs listed in the Jobs.csv file. Jobs are lines in that .csv file and consist of the command to run and "open" or "completed" as their status. Do not edit Jobs.csv while the QueueManager is running - this will not crash anything, but your changes will be overwritten, because the QueueManager writes the current state of the jobs back into the file on exit. Also, be careful what you put into Jobs.csv - the QueueManager will run any valid bash command sequence.

The results for any run (all relevant train, val, and test metrics, as well as network weights) are saved in logs_and_checkpoints. Both our pretrained and finetuned model weights can be downloaded from https://cloud.uk-essen.de/d/86d2be31e9b84d21bd5c/.
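For instance, a new job could be appended to the queue programmatically along the following lines. This is a hedged sketch that assumes each row simply holds the command string followed by its status; double-check the column layout of the Jobs.csv actually shipped with the repository before relying on it, and only edit the file while the QueueManager is not running.

```python
import csv

# Hedged sketch: append a new job to the queue. Assumes each Jobs.csv row is
# "command, status" with status being "open" or "completed"; verify this
# against the file shipped with the repository before use.
new_job = [
    "python pkgs/FTE/finetune.py -c pkgs/FTE/config/my_finetuning_config.yaml",
    "open",
]

with open("Jobs.csv", "a", newline="") as f:
    csv.writer(f).writerow(new_job)
```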

If you just want to use the pretrained models for something, do `from utility.models import get_model` and you're good to go.
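get_model's exact signature lives in utility/models.py and is not reproduced here. If you prefer to bypass it, the published checkpoints are ResNet-50 weights, so loading one manually should look roughly like the following hedged sketch (the filename and the key layout inside the checkpoint are assumptions):

```python
import torch
from torchvision.models import resnet50

# Hedged sketch: manually load one of the published ResNet-50 checkpoints.
# The filename is a placeholder, and strict=False is used in case the
# checkpoint contains extra keys (e.g. from a SimCLR projection head).
state = torch.load("pretrained_resnet50.pt", map_location="cpu")
model = resnet50(weights=None)
model.load_state_dict(state, strict=False)
model.eval()
```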

If you simply want to reproduce the results from the paper, edit Jobs.csv and set every job you want to run from "completed" to "open". Then start a screen or tmux session, run `python QueueManager.py`, detach from the session, and take a month or two of vacation while your GPUs crunch numbers. You might want to check in occasionally, though, in case it randomly hangs. Note that exact results will differ slightly from those reported, because GPU calculations are not deterministic (for more on that, and why determinism is undesirable from a speed-centered point of view, look up PyTorch Reproducibility). I don't know where the non-deterministic decisions are made in practice, but I noticed that while runs were originally reproducible down to the last digit, they later started deviating between runs as early as Epoch 1, Step 1. I think this issue was introduced with the PyTorch 2.0 update; it is apparently not unheard of, at any rate. However, the results of two runs are still so close that all paper results remain valid and reproducible. With regard to the reproducibility of the paper results, note that the finetuning task for the CT images from our hospital will not be available due to data privacy concerns.
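For reference, PyTorch's standard switches for forcing deterministic behaviour look roughly like the following; they are not wired into this repository's configs, and they cost speed, which is why the reported runs do not use them:

```python
import torch

# Standard PyTorch reproducibility settings (see PyTorch's "Reproducibility"
# notes). Not used by this repository's runs, since determinism costs speed.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)  # raises an error on non-deterministic ops
# Some CUDA ops additionally require the CUBLAS_WORKSPACE_CONFIG environment
# variable (e.g. ":4096:8") to be set before the process starts.
```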

If you want to create a different run on a pre-existing task, simply copy its .yaml config file, edit its contents to your liking, and change the name. Be careful to also change the destination of the run's log files, as you will otherwise overwrite a previous run! The nomenclature is PT_{pretrained using what method}_FT_{finetuned on this task}_{additional details}.yaml. If you set debug: True in the config file, finetuning will be performed using anomaly detection, 0 workers (meaning it runs in the main process), and occasional debug printouts (if you make any changes, it makes sense to put them behind an `if args.debug is True:` condition). Additionally, the dataset will raise any errors immediately instead of trying to ignore and replace missing or broken data. If you set short_training: True in the config file, only one batch each of training and validation data will be evaluated, so you can debug more quickly. If test_nan_inf: True is set in the config file, the dataset will filter out any tensors containing nan or inf, guaranteeing that, if you still see such values, they were introduced by the model. By default, this option is enabled, even though it slows down training a little bit; if your dataset is big enough, something is going to break, and it's not always worth finding out what, if two out of twenty million images are lost. It is further recommended to set opmode: "disk" for debugging, as it eliminates the overhead of pushing data to shared memory (which really isn't needed in a 0-workers environment anyway).
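As a small, self-contained illustration of that debug convention (args here stands in for whatever parsed config object is actually in scope in the file you are editing):

```python
from types import SimpleNamespace
import torch

# Sketch of the suggested convention: keep new diagnostic code behind the
# debug flag so normal runs stay quiet. "args" stands in for whatever parsed
# config object is actually in scope in the file you are editing.
args = SimpleNamespace(debug=True)
batch = torch.randn(4, 3, 224, 224)  # stand-in for a real mini-batch

if args.debug is True:
    print(f"[debug] batch shape: {tuple(batch.shape)}, "
          f"min={batch.min().item():.3f}, max={batch.max().item():.3f}")
```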

If you want to create your own tasks to finetune on, you can do so in pkgs/FTE/finetuning_tasks.py. There, every accepted task name is assigned a set of augmentations, a model (a default classification or segmentation model plus a loss function), its dataset (constructed from utility.utils.Custom_Dataset), an optimizer, and a scheduler. Any new task will have to be integrated into these functions. Just like the other tasks, the dataset class will need a _get_target(self, idx: int, File: str = None) function assigned to it, which typically lives in the utility/utils.py file. For classification tasks, _get_target should return a LongTensor containing a single integer. The easiest way to create your own task is to copy a similar task and go from there; a sketch of a classification _get_target follows below. If something breaks down, there are always GitHub Issues, where I will happily answer your questions.
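Such a function could look roughly like this hedged sketch; the label CSV and its columns are made-up stand-ins for whatever annotation format your data actually uses:

```python
import csv
import torch

# Hedged sketch of a _get_target for a new classification task, following the
# signature used by the existing tasks. The label file and its columns
# ("filename", "label") are hypothetical stand-ins for your annotations.
_LABELS = {}

def load_labels(label_csv: str):
    with open(label_csv, newline="") as f:
        _LABELS.update({row["filename"]: int(row["label"]) for row in csv.DictReader(f)})

def _get_target_mytask(self, idx: int, File: str = None):
    # Classification targets are expected as a LongTensor holding a single
    # integer class index.
    return torch.LongTensor([_LABELS[File]])
```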

Can I have a different pretrained network that isn't a ResNet-50 or a U-Net based on a ResNet-50?

Maybe. If the time and resources are available, I may be able to do this, but I will have to get separate permission to create and publish something not strictly part of the paper. Open an issue, tell me what you'd like, and if it's doable and there aren't too many requests, I'll see if I can get permission. No promises, though. The current plan is to pretrain at least a regular ViT using RadNet-1.28M and RadNet-12M.

Additional notes

This project was originally compiled on Ubuntu 20.04 LTS with the requirements listed in the requirements.txt file. All runs in the paper were performed on an NVIDIA DGX with 8 A100 (80 GB VRAM) GPUs. If the package does not perform tasks as you expect it to, check whether you have a comparable environment, OS, and/or a reasonably powerful GPU setup, and if not, reduce the number of workers and/or the batch size.

The project was built with a DataParallel rather than a DistributedDataParallel PyTorch model, because for various reasons only a single GPU machine (albeit a very powerful one) was available. However, moving to more nodes would be straightforward.

In order to improve execution speed and reduce GPU VRAM use, the following steps have been taken, which you should take into account when modifying or extending the functionality of this package:

  1. Caching. The dataset gets fully cached using the live_cache option in the configs. This hands the datasets a pointer to a ProxyObject, behind which sits a multiprocessing Manager that holds the data. Any datapoint loaded during the first epoch is also cached in the Manager; during later epochs, every datapoint is loaded from the Manager. Any deterministic transforms are cached as well; internally, these are referred to as cpu_transforms, while non-deterministic transforms are referred to as gpu_transforms. You can switch away from this behaviour by specifying "pre_cache", which pre-caches the dataset using 0 workers, or "disk", which simply reads data from disk and performs no caching at all.
  2. Mixed precision. Mixed precision is used in PyTorch's automated fashion. It is easy to disable, but not from the config, and there is really no reason why you would want to, I think.
  3. GPU transforms. Any data transformation in a dataset's gpu_transforms is applied to the batch by the GPU, if the tf_device flag in the config is set to "gpu", which it typically is. During segmentation tasks, tf_device must be set to "cpu", so that even the non-deterministic transformations are performed on the CPU. This is done because albumentations, our transformation package of choice for co-transforming images and targets, works on opencv-style arrays, which cannot be pushed to the GPU.
  4. Local loss calculation on the GPU. The models in this package contain their loss criterion as a module and return both predictions and losses per mini-batch (meaning one loss value per individual GPU); a sketch of this pattern follows after this list. This cuts down on VRAM use, because the master GPU (typically cuda:0) no longer needs to temporarily store the entire batch, and it allows us to work with bigger batch sizes, if we so choose. The downside is that, e.g. for SimCLR, the effective batch size is smaller, because each GPU only sees its local negative examples, not those from the other GPUs. I found this to be no problem in practice, as every SimCLR run would easily converge on a minimum loss within the allotted 100 epochs, with or without local loss calculation.
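Here is a minimal sketch (not the repository's actual model classes) of what "loss criterion as a module" means: the wrapper computes its loss on the same device as its inputs, so under nn.DataParallel each replica returns its predictions plus a small per-GPU loss value instead of shipping everything back to cuda:0 for the loss computation.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Minimal sketch (not the repository's actual model class): bundle the
# backbone with its loss criterion so that, under nn.DataParallel, each
# replica computes its loss locally and returns predictions plus one small
# per-GPU loss value to the master device.
class ModelWithLoss(nn.Module):
    def __init__(self, num_classes: int = 4):  # num_classes is a placeholder
        super().__init__()
        self.backbone = resnet50(weights=None)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, targets):
        preds = self.backbone(x)
        loss = self.criterion(preds, targets)
        return preds, loss

# Usage under DataParallel: the gathered losses arrive as one value per GPU
# and can simply be averaged before calling backward().
# model = nn.DataParallel(ModelWithLoss()).cuda()
# preds, losses = model(images, targets)
# losses.mean().backward()
```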

Due to the complexity of the project and the one-size-fits-all approach, in which I tried to find reasonable parameters that would lead to a successful finetuning a) on every task, b) within at least vaguely competitive margins given the limitations of the architecture(s), and c) while staying blind to optimization for specific pretrainings (meaning: guesswork for runs that start from scratch), training stability has suffered to some degree. Occasionally, the optimizer will fling the model parameters out of the current local minimum for no apparent reason, and very rarely the results deviate strongly from the range they would normally converge to, because one such fling happened right at the end of optimization. The ways to avoid this behavior would be to pick the best-performing model from across all validation steps of one run (or even multiple separate runs), something I state in the paper we explicitly wanted to avoid, or to find and eliminate every source of instability during training (good luck with that one). I have therefore decided to accept this behavior and simply mention it here. You are now warned.

Where did you get the datasets?

This is where you can get started for each respective dataset:
