Self-Supervised Learning for Medical Imaging

Recent progress in large-scale pre-training has enabled great advances in computer vision tasks such as image classification, segmentation and object detection. However, these models suffer from sub-optimal out-of-distribution performance when evaluated on real-world data, and the gap widens further on clinical datasets. The most common strategy to tackle this has been fine-tuning on a small subset of high-quality data samples. This approach, however, doesn't scale in the long run: if a pre-trained model is fine-tuned on the images observed at a particular clinic, even a change in the imaging equipment can trigger the need to fine-tune yet another model.

Self-supervised learning (SSL) emerges as a natural solution to these issues. SSL methods use pretext tasks to train models to generate high-quality representations without labelled data. They have been shown to perform better than traditional supervised learning methods while also providing improved robustness and out-of-distribution performance.

Why should we have a different approach for Medical Images?

While one can adopt preexisting SSL methods for medical images, this doesn't realise the full potential of these algorithms, since natural images differ immensely from medical images. For example, rotation prediction was used as a pretext task in some earlier methods (Gidaris et al., 2018), yet it offers little in certain medical domains, such as whole-slide pathology images (WSIs). Natural images usually have a canonical orientation: in autonomous driving, cars sit on top of the road and trees point upward, whereas in medical images there is no such sense of position or orientation. Various methods have therefore been proposed to exploit the intrinsic characteristics of medical images while still maintaining transferability to different downstream tasks.

As Sowrirajan et al. (MoCo-CXR) point out in their paper, Chest X-rays differ fundamentally from natural images:

  1. Tasks such as classification depend on abnormalities confined to a few pixels, i.e. regions of interest (ROIs). Examples include small red dots in retinal fundus images or white patches in chest X-rays. These are often identified by local texture variations rather than by the global structure of the image.
  2. The underlying data itself differs: X-rays are larger, grayscale, and have similar spatial structures across samples (chest X-rays typically come in anterior-posterior, posterior-anterior or lateral views).

The standard set of augmentations used in SSL, such as random cropping or blurring, is therefore not ideal for chest X-rays, as it might remove the meaningful parts of the image. Moreover, random grayscale conversion or colour jittering adds no value since the images are already grayscale. Unlike WSIs, however, chest X-rays do tolerate small random rotations (about 10-15 degrees) and horizontal flipping.
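As a rough illustration, a domain-aware augmentation pipeline along these lines could look like the following torchvision sketch; the transform choices and magnitudes are assumptions for illustration, not values prescribed by any specific paper.

```python
# Illustrative augmentation pipeline for grayscale chest X-rays (assumed values).
# Small rotations and horizontal flips are kept, while aggressive cropping and
# colour jittering are avoided so that small regions of interest are preserved.
import torchvision.transforms as T

cxr_augmentations = T.Compose([
    T.RandomRotation(degrees=10),                 # small rotations only
    T.RandomHorizontalFlip(p=0.5),
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),   # conservative crop keeps ROIs
    T.ToTensor(),
])
```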

Similarly, Whole Slide Images (WSIs) come with their own set of problems. These images are scanned copies of tissue slides obtained via a biopsy or surgery at a medical facility. The tissue on a slide spans centimetres while individual pixels capture micron-scale detail, so a WSI lands in the gigapixel range and is divided into smaller patches for analysis. Moreover, as tissue slides from a single patient can differ (because of variations in staining), inter-WSI domain shift is common. WSIs also differ immensely from natural images: most natural images have the main object in the centre, whereas in a WSI the abnormality might appear in multiple smaller patches anywhere in the image. This is also why transfer learning from natural to medical datasets doesn't perform well.

Clever tricks

Figure: Illustration of the proposed contrastive objective in MICLe. Source: Azizi et al. (2021)

Multi-Instance Contrastive Learning (MICLe) by Azizi et al. (2021) proposes using contrastive learning across multiple images of the same underlying pathology per patient case. Given multiple images from a particular patient, positive pairs for contrastive learning are formed from crops of two different images, e.g. taken from different viewing angles. This encourages representations that are robust to a change of viewpoint.
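A minimal sketch of this pair-construction idea is below, assuming a list of images per patient and an `augment` transform; the sampling logic is illustrative rather than the authors' implementation.

```python
import random

def micle_positive_pair(patient_images, augment):
    """MICLe-style positive pair: when a patient has more than one image,
    the two views come from two *different* images of that patient;
    otherwise they are two augmentations of the same image."""
    if len(patient_images) >= 2:
        img_a, img_b = random.sample(patient_images, 2)
    else:
        img_a = img_b = patient_images[0]
    return augment(img_a), augment(img_b)
```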

Figure: Overview of the proposed SSLP method. Source: Li et al. (MICCAI 2021)

Spatial Guided Self-Supervised Learning on Pathological Images (SSLP) by Li et al. (MICCAI 2021) proposes exploiting the semantic invariance in pathological images since patch-wise spatial proximity is a significant characteristic of WSIs. Apart from enforcing similarity between positive samples, SSLP aims to enforce spatial proximity based on the assumption that adjacent patches in an image are more likely to share the same semantic label.

Equation: Loss function used in SSLP, where S represents the spatial neighbourhood. Source: Li et al. (MICCAI 2021)

In practice, the authors additionally use heuristically chosen hard negative samples to deal with tumour heterogeneity.
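The exact loss is given in the paper; as a rough sketch of the idea, an InfoNCE-style objective can treat spatially adjacent patches as extra positives, as below. The neighbourhood radius, temperature and implementation details are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def spatial_nce_loss(z, coords, temperature=0.1, radius=1):
    """Illustrative contrastive loss with a spatial-neighbourhood term.
    z: (N, D) patch embeddings; coords: (N, 2) integer grid positions of the patches.
    Each patch is pulled towards patches within `radius` grid steps (its neighbourhood S)
    and pushed away from the remaining patches in the batch."""
    z = F.normalize(z, dim=1)
    sim = z @ z.t() / temperature                                  # (N, N) similarities
    self_mask = torch.eye(len(z), dtype=torch.bool)
    # Spatial neighbourhood S: patches within `radius` grid steps, excluding the patch itself.
    dist = (coords[:, None, :] - coords[None, :, :]).abs().max(dim=-1).values
    pos_mask = (dist <= radius) & ~self_mask
    # Log-softmax over all other patches, then average over the spatial positives.
    denom = torch.logsumexp(sim.masked_fill(self_mask, float('-inf')), dim=1, keepdim=True)
    log_prob = sim - denom
    loss = -(log_prob * pos_mask).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
    return loss.mean()
```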

Figure: Proposed MoCo-CXR training pipeline. Source: Sowrirajan et al. (MIDL 2021)

In their paper, Sowrirajan et al. (MIDL 2021) introduce MoCo-CXR, a variant of the well-known MoCo method aimed at producing models with better-quality representations and initialisations, specifically targeted at detecting various pathologies in chest X-rays. They choose MoCo over other contrastive methods, such as SimCLR, because its momentum encoder and queue of negatives let it perform well even at smaller batch sizes. After pre-training, they fine-tune the models on varying fractions of labelled training data from standard medical datasets such as CheXpert and Shenzhen. This acts as a proxy for the real world, where large amounts of data remain unlabelled and only a small portion of well-labelled data is available for supervised fine-tuning. They show that models pre-trained with their MoCo-CXR strategy learn better representations and transfer across other chest X-ray datasets.
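For context, the small-batch friendliness comes from MoCo's momentum-updated key encoder: keys for the negative queue are produced by a slowly moving copy of the query encoder. A generic sketch of that update (not MoCo-CXR's actual code) might look like:

```python
import torch

@torch.no_grad()
def momentum_update(query_encoder, key_encoder, m=0.999):
    """MoCo-style momentum update: the key encoder trails the query encoder as an
    exponential moving average, so negatives can be drawn from a large queue of
    past keys instead of requiring very large batches (as SimCLR does)."""
    for q_param, k_param in zip(query_encoder.parameters(), key_encoder.parameters()):
        k_param.data.mul_(m).add_(q_param.data, alpha=1 - m)
```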

Figure: Illustration of the Latent Augmentation methodology. Source: Yang et al. (ICLR 2022)

Yang et al. (ICLR 2022), in their paper "Towards Better Understanding and Better Generalisation of Low-shot Classification in Histology Images with Contrastive Learning", introduce Latent Augmentation to better aid few-shot classification in WSIs. They use contrastive learning methods to learn a meaningful encoder in the pre-training stage and their Latent Augmentation strategy to "inherit knowledge" from the training dataset by "transferring semantic variants" in the latent space. After pre-training, they perform K-means on the generated representations to obtain clusters. They then create a dictionary that maps clusters to their mean representations. These are then used to augment the representations during training.
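A rough sketch of the dictionary-building and augmentation steps described above is given below; the number of clusters, the mixing rule, and the `strength` parameter are illustrative assumptions, and the paper's exact augmentation rule may differ.

```python
import numpy as np
from sklearn.cluster import KMeans

def build_latent_dictionary(representations, n_clusters=16, seed=0):
    """Cluster the pre-trained representations and keep each cluster's mean
    as a prototype, forming the dictionary described above."""
    km = KMeans(n_clusters=n_clusters, random_state=seed).fit(representations)
    return km.cluster_centers_                               # (n_clusters, D) prototypes

def latent_augment(z, prototypes, strength=0.5, rng=None):
    """Illustrative latent augmentation: nudge a representation towards the mean
    of a randomly chosen cluster, creating a semantic variant in latent space
    without changing the input image."""
    rng = rng or np.random.default_rng()
    target = prototypes[rng.integers(len(prototypes))]
    return z + strength * (target - z)
```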

SSL excels in Zero-Shot Generalisation

Most image segmentation models suffer in terms of generalisation: a model trained to segment skin lesions, for instance, fails to produce accurate segmentation maps for organs. Most prior approaches have tackled this with interactive segmentation, i.e. the user clicks on a particular object and the model tries to produce a segmentation map without any prior fine-tuning. Moreover, most models are pre-trained on 2D images, whereas medical data often comes in higher dimensions, such as 3D MRI volumes or ultrasound video. However, foundation models have recently shown significant potential for robust zero-shot performance across image modalities.

Figure: Ability of MedSAM-2 to track objects based on a given prompt. Source: Zhu et al. (2024)

SAM 2 introduced a new SOTA model capable of "real-time" object segmentation for images and video streams (for an overview of SAM, refer to our blog post). With its ability to deal with object motion and occlusion, SAM 2 is a turning point for object segmentation. The authors of MedSAM-2 (Zhu et al. 2024) take this general visual model and adapt it to treat medical images as videos. The user only has to prompt the model (as in the SAM 2 demo); the model can then "track" the object in later images even though there may be no temporal relationship between them, significantly reducing a clinician's effort.

Figure: Overview of the MedSAM-2 framework. Source: Zhu et al. (2024)

NOTE: It's not as simple as treating each slice as a separate 2D image for segmentation, as that exposes the model to the numerous issues associated with 2D images.

SSL Aids in Multimodal Scenarios

Figure: Various multimodal self-supervised learning pre-training strategies. Source: Huang et al. (2024)

Lately, multimodal learning has emerged as a viable way to learn better-quality features and enable broader use of trained models. However, learning with more than one modality brings its own problems, such as fusion and alignment of the modalities. Researchers often add a contrastive term to the learning objective to further aid in learning high-quality representations. Another approach is self-prediction, where parts of the data (e.g., masked image regions) are predicted from complementary modalities, which encourages a better understanding of the interrelation between image and text. Generative models, such as those based on generative adversarial networks or diffusion models, learn to synthesise one modality from another, for example generating medical reports from images. In some advanced systems, vision-language models (VLMs), such as GPT-4 Vision, are used to generate or interpret image data through natural language prompts, offering intuitive interaction with medical imagery.

Figure: Illustration of the ContIG framework. Source: Taleb et al. (CVPR 2022)

Taleb et al. (CVPR 2022), in "ContIG: Self-supervised Multimodal Contrastive Learning for Medical Imaging with Genetics", propose a multimodal contrastive objective to tackle the non-trivial and challenging task of incorporating genetic data during pre-training. Their method aligns images and several genetic modalities in the feature space. For a given image, they use three genetic modalities from the same individual and generate features from each modality with dedicated encoders (CNNs for images and MLPs for genomic data). A projection head per modality then produces equal-sized embeddings in the latent space, which are trained with a contrastive objective. In particular, embeddings belonging to the same individual are pulled closer together in the feature space and pushed away from those of other individuals.
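A minimal sketch of such a cross-modal alignment objective, for a single image-genetics pair, is shown below; the projection-head architecture, embedding size and temperature are assumptions for illustration, not the paper's exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModalityProjector(nn.Module):
    """Projection head mapping one modality's features into a shared latent space."""
    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=1)           # L2-normalised embeddings

def cross_modal_contrastive_loss(z_img, z_gen, temperature=0.1):
    """CLIP/ContIG-style objective for one image-genetics modality pair:
    embeddings of the same individual (matching row index) are positives,
    while the other individuals in the batch serve as negatives."""
    logits = z_img @ z_gen.t() / temperature             # (B, B) cross-modal similarities
    targets = torch.arange(len(z_img), device=z_img.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```

With three genetic modalities, a loss of this form could simply be summed over each image-genetics pairing so that the image embedding is aligned with every genetic view of the same individual.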

Conclusion

Self-supervised learning (SSL) has emerged as a powerful paradigm for medical image analysis, offering solutions to unique challenges in this domain. Unlike natural images, medical data presents distinct characteristics, from localised abnormalities in X-rays to gigapixel-scale whole slide images, that require specialised approaches. Recent innovations like MICLe, SSLP, and MoCo-CXR demonstrate how domain-specific adaptations of SSL can improve performance on tasks like classification and segmentation. The field has also seen significant progress in zero-shot generalisation, exemplified by models like MedSAM-2, which can track objects across slices of different medical imaging modalities without additional training. Furthermore, multimodal learning approaches that combine imaging with other data types (such as genetic information) are showing promise in creating more robust and comprehensive medical AI systems.

As the field evolves, self-supervised learning methods will likely play an increasingly crucial role in developing more adaptable and generalisable medical imaging models, ultimately supporting better clinical decision-making.