UNI-based ABMIL models for metastasis detection
These are weakly supervised, attention-based multiple instance learning (ABMIL) models for binary, slide-level metastasis detection (normal versus metastasis). The models were trained on the CAMELYON16 dataset using patch embeddings from the UNI pathology foundation model.
Data
- Training set: 243 whole slide images (WSIs)
  - 143 negative
  - 100 positive
    - 52 macrometastases
    - 48 micrometastases
- Validation set: 27 WSIs
  - 16 negative
  - 11 positive
    - 6 macrometastases
    - 5 micrometastases
- Test set: 129 WSIs
  - 80 negative
  - 49 positive
    - 22 macrometastases
    - 27 micrometastases
Evaluation
The table below reports classification results on the test set (129 WSIs) for the model trained with each of five random seeds.
Seed | Sensitivity | Specificity | Balanced accuracy | Precision | F1 |
---|---|---|---|---|---|
0 | 0.959 | 1.000 | 0.980 | 1.000 | 0.979 |
1 | 0.959 | 0.988 | 0.973 | 0.979 | 0.969 |
2 | 1.000 | 1.000 | 1.000 | 1.000 | 1.000 |
3 | 0.980 | 0.950 | 0.965 | 0.923 | 0.950 |
4 | 0.980 | 1.000 | 0.990 | 1.000 | 0.990 |
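The columns follow the standard binary-classification definitions. As a worked check, the sketch below recomputes every row from confusion-matrix counts. The counts are not reported in this card; they are inferred from the rates and the 49 positive and 80 negative test slides (for example, seed 1's sensitivity of 0.959 corresponds to 47 of 49 metastasis slides detected). The function name `binary_metrics` is hypothetical.

```python
def binary_metrics(tp: int, fn: int, tn: int, fp: int) -> dict[str, float]:
    """Compute the table's metrics from confusion-matrix counts."""
    sensitivity = tp / (tp + fn)  # fraction of metastasis slides detected
    specificity = tn / (tn + fp)  # fraction of negative slides correctly rejected
    precision = tp / (tp + fp) if tp + fp else 0.0
    balanced_accuracy = (sensitivity + specificity) / 2
    f1 = (2 * precision * sensitivity / (precision + sensitivity)
          if precision + sensitivity else 0.0)
    # Keys are ordered like the table columns.
    return {
        "sensitivity": sensitivity,
        "specificity": specificity,
        "balanced_accuracy": balanced_accuracy,
        "precision": precision,
        "f1": f1,
    }


# (TP, FN, TN, FP) per seed: inferred from the reported rates, not reported directly.
inferred_counts = {
    0: (47, 2, 80, 0),
    1: (47, 2, 79, 1),
    2: (49, 0, 80, 0),
    3: (48, 1, 76, 4),
    4: (48, 1, 80, 0),
}

for seed, counts in inferred_counts.items():
    metrics = binary_metrics(*counts)
    print(seed, " ".join(f"{value:.3f}" for value in metrics.values()))
```

The printed values agree with the table up to rounding in the third decimal.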
How to reuse the model
The model expects a bag of UNI embeddings: one 1024-dimensional vector per 128 × 128 µm patch of the slide.
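Because the patch size is given in microns, its size in pixels depends on the slide's resolution in microns per pixel (mpp). A minimal sketch of the conversion follows; the mpp values are illustrative assumptions, not values read from CAMELYON16 slides, and the name `patch_size_pixels` is hypothetical. How the resulting patch is resized and normalized before UNI embedding is not covered here.

```python
PATCH_SIZE_UM = 128  # physical patch side length stated in this model card


def patch_size_pixels(mpp: float, patch_size_um: float = PATCH_SIZE_UM) -> int:
    """Side length in pixels of a patch spanning `patch_size_um` microns at `mpp` microns/pixel."""
    if mpp <= 0:
        raise ValueError("mpp must be positive")
    return round(patch_size_um / mpp)


for mpp in (0.25, 0.5):
    print(f"{mpp} um/px -> {patch_size_pixels(mpp)} px")
# 0.25 um/px -> 512 px
# 0.5 um/px -> 256 px
```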
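```python
import torch
from abmil import AttentionMILModel

# Build the architecture with the same hyperparameters used for training.
model = AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2, gated_attention=True)

# Load the weights of one trained model (here, seed 2) onto the CPU.
state_dict = torch.load("seed2/model_best.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Load a bag of features: one 1024-dimensional UNI embedding per patch.
# torch.ones is a placeholder; replace it with the slide's real embeddings.
bag = torch.ones(1000, 1024)

with torch.inference_mode():
    logits, attention = model(bag)
```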
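To illustrate what `logits` and `attention` typically represent, here is a self-contained sketch of gated attention pooling (Ilse et al., 2018) in plain PyTorch. It is not the `AttentionMILModel` from the `abmil` package: the class name `GatedAttentionPooling`, its layer layout, and its output shapes are assumptions, so it cannot load `model_best.pt`.

```python
import torch
from torch import nn


class GatedAttentionPooling(nn.Module):
    """Hypothetical gated-attention MIL head (after Ilse et al., 2018).

    Not the repository's AttentionMILModel: layer names and shapes are assumptions.
    """

    def __init__(self, in_features: int = 1024, L: int = 512, D: int = 384, num_classes: int = 2):
        super().__init__()
        self.embed = nn.Sequential(nn.Linear(in_features, L), nn.ReLU())  # per-patch projection
        self.attn_v = nn.Linear(L, D)  # tanh branch
        self.attn_u = nn.Linear(L, D)  # sigmoid gate
        self.attn_w = nn.Linear(D, 1)  # one attention score per patch
        self.classifier = nn.Linear(L, num_classes)  # slide-level classifier

    def forward(self, bag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = self.embed(bag)  # (N, L)
        gated = torch.tanh(self.attn_v(h)) * torch.sigmoid(self.attn_u(h))  # (N, D)
        scores = self.attn_w(gated).squeeze(-1)  # (N,)
        attention = torch.softmax(scores, dim=0)  # non-negative, sums to 1 over patches
        slide_embedding = attention @ h  # attention-weighted mean of patch embeddings, (L,)
        logits = self.classifier(slide_embedding)  # (num_classes,)
        return logits, attention


torch.manual_seed(0)
head = GatedAttentionPooling().eval()
bag = torch.randn(1000, 1024)  # stand-in for 1000 UNI patch embeddings
with torch.inference_mode():
    logits, attention = head(bag)
print(logits.shape, attention.shape)  # torch.Size([2]) torch.Size([1000])
```

In this design the attention weight of each patch sets its contribution to the slide embedding, which is why such weights are often visualized as heatmaps over the slide.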
How to train the model
Download the UNI embeddings for CAMELYON16 from https://huggingface.co/datasets/kaczmarj/camelyon16-uni, then run the commands below, one per seed. The commands differ only in `--seed` and `--output-dir`.
```bash
# Seed 0
python train_classification.py --model-name AttentionMILModel --features-dir path/to/features/ --output-dir outputs/abmil-uni-128um_seed0 --csv data.csv --label-col binary_label_int --num-classes 2 --embedding-size 1024 --split-json splits.json --fold 0 --num-epochs 20 --seed 0 -L 512 -D 384 --lr 1e-4
# Seed 1
python train_classification.py --model-name AttentionMILModel --features-dir path/to/features/ --output-dir outputs/abmil-uni-128um_seed1 --csv data.csv --label-col binary_label_int --num-classes 2 --embedding-size 1024 --split-json splits.json --fold 0 --num-epochs 20 --seed 1 -L 512 -D 384 --lr 1e-4
# Seed 2
python train_classification.py --model-name AttentionMILModel --features-dir path/to/features/ --output-dir outputs/abmil-uni-128um_seed2 --csv data.csv --label-col binary_label_int --num-classes 2 --embedding-size 1024 --split-json splits.json --fold 0 --num-epochs 20 --seed 2 -L 512 -D 384 --lr 1e-4
# Seed 3
python train_classification.py --model-name AttentionMILModel --features-dir path/to/features/ --output-dir outputs/abmil-uni-128um_seed3 --csv data.csv --label-col binary_label_int --num-classes 2 --embedding-size 1024 --split-json splits.json --fold 0 --num-epochs 20 --seed 3 -L 512 -D 384 --lr 1e-4
# Seed 4
python train_classification.py --model-name AttentionMILModel --features-dir path/to/features/ --output-dir outputs/abmil-uni-128um_seed4 --csv data.csv --label-col binary_label_int --num-classes 2 --embedding-size 1024 --split-json splits.json --fold 0 --num-epochs 20 --seed 4 -L 512 -D 384 --lr 1e-4
```
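The five training commands share every flag except the seed and the output directory, so a loop can generate them. A minimal sketch, assuming the same placeholder paths as the commands above; the helper names `training_command` and `run_all` are hypothetical, while the flags are copied verbatim from those commands. By default it only prints each command; pass `dry_run=False` to run them one after another with `subprocess`.

```python
import shlex
import subprocess

SEEDS = range(5)


def training_command(seed: int) -> list[str]:
    """Argument list for train_classification.py with the flags used for every seed."""
    return [
        "python", "train_classification.py",
        "--model-name", "AttentionMILModel",
        "--features-dir", "path/to/features/",
        "--output-dir", f"outputs/abmil-uni-128um_seed{seed}",
        "--csv", "data.csv",
        "--label-col", "binary_label_int",
        "--num-classes", "2",
        "--embedding-size", "1024",
        "--split-json", "splits.json",
        "--fold", "0",
        "--num-epochs", "20",
        "--seed", str(seed),
        "-L", "512",
        "-D", "384",
        "--lr", "1e-4",
    ]


def run_all(dry_run: bool = True) -> list[list[str]]:
    """Print (and, unless dry_run, execute) the training command for each seed."""
    commands = [training_command(seed) for seed in SEEDS]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands


if __name__ == "__main__":
    run_all(dry_run=True)
```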