A tested PyTorch framework for computational pathology research with working benchmarks on PatchCamelyon and CAMELYON16
View on GitHub matthewvaishnav/computational-pathology-research
Complete API documentation for the Computational Pathology Research Framework.
Loads 96x96 pixel patches from the PatchCamelyon dataset.
from src.data import PatchCamelyonDataset
dataset = PatchCamelyonDataset(
root_dir="data/pcam",
split="train", # "train", "val", or "test"
transform=None
)
Parameters:
root_dir (str): Path to PCam data directory
split (str): Dataset split to load
transform (callable, optional): Transform to apply to images
Returns:
image (torch.Tensor): Image tensor of shape (3, 96, 96)
label (int): Binary label (0 or 1)
Loads pre-extracted features for slide-level classification.
from src.data import CAMELYONSlideDataset
dataset = CAMELYONSlideDataset(
root_dir="data/camelyon/features",
split="train",
max_patches=None
)
Parameters:
root_dir (str): Path to HDF5 feature directory
split (str): Dataset split to load
max_patches (int, optional): Maximum patches per slide
Returns:
features (torch.Tensor): Feature tensor of shape (num_patches, feature_dim)
label (int): Binary label (0 or 1)
slide_id (str): Slide identifier
Collates variable-length slides into batched tensors with masking.
from src.data import collate_slide_bags
batch = collate_slide_bags(samples)
Parameters:
samples (list): List of (features, label, slide_id) tuples
Returns:
features (torch.Tensor): Padded features of shape (batch, max_patches, feature_dim)
labels (torch.Tensor): Labels of shape (batch,)
num_patches (torch.Tensor): Number of patches per slide of shape (batch,)
slide_ids (list): List of slide identifiers
Basic CNN classifier for patch-level classification.
from src.models import SimpleClassifier
model = SimpleClassifier(
num_classes=2,
dropout=0.5
)
Parameters:
num_classes (int): Number of output classes
dropout (float): Dropout probability
Forward:
output = model(images) # images: (batch, 3, 96, 96)
# output: (batch, num_classes)
Slide-level classifier with attention-based aggregation.
from src.models import SimpleSlideClassifier
model = SimpleSlideClassifier(
feature_dim=2048,
hidden_dim=256,
num_classes=2,
pooling="attention", # "mean", "max", or "attention"
dropout=0.5
)
Parameters:
feature_dim (int): Input feature dimension
hidden_dim (int): Hidden layer dimension
num_classes (int): Number of output classes
pooling (str): Aggregation method
dropout (float): Dropout probability
Forward:
output = model(features, num_patches)
# features: (batch, max_patches, feature_dim)
# num_patches: (batch,)
# output: (batch, num_classes)
Loads pretrained encoders from torchvision or timm.
from src.models.pretrained import load_pretrained_encoder
encoder = load_pretrained_encoder(
model_name="resnet50",
source="torchvision", # "torchvision" or "timm"
pretrained=True,
num_classes=2
)
# Access feature dimension
feature_dim = encoder.feature_dim
Parameters:
model_name (str): Model architecture name
source (str): Model source ("torchvision" or "timm")
pretrained (bool): Load pretrained weights
num_classes (int): Number of output classes
Supported Models:
torchvision:
timm:
Trains model for one epoch.
from src.training import train_epoch
metrics = train_epoch(
model=model,
train_loader=train_loader,
criterion=criterion,
optimizer=optimizer,
device=device,
epoch=epoch
)
Parameters:
model (nn.Module): Model to train
train_loader (DataLoader): Training data loader
criterion (nn.Module): Loss function
optimizer (Optimizer): Optimizer
device (torch.device): Device to train on
epoch (int): Current epoch number
Returns:
metrics (dict): Training metrics
loss (float): Average training loss
accuracy (float): Training accuracy
time (float): Epoch duration in seconds
Evaluates model on validation/test set.
from src.training import evaluate
metrics = evaluate(
model=model,
val_loader=val_loader,
criterion=criterion,
device=device
)
Parameters:
model (nn.Module): Model to evaluate
val_loader (DataLoader): Validation data loader
criterion (nn.Module): Loss function
device (torch.device): Device to evaluate on
Returns:
metrics (dict): Evaluation metrics
loss (float): Average validation loss
accuracy (float): Validation accuracy
auc (float): Area under ROC curve
predictions (np.ndarray): Model predictions
labels (np.ndarray): Ground truth labels
Sets random seed for reproducibility.
from src.utils import set_seed
set_seed(42)
Parameters:
seed (int): Random seed value
Saves model checkpoint.
from src.utils import save_checkpoint
save_checkpoint(
model=model,
optimizer=optimizer,
epoch=epoch,
metrics=metrics,
path="checkpoints/model.pth"
)
Parameters:
model (nn.Module): Model to save
optimizer (Optimizer): Optimizer state
epoch (int): Current epoch
metrics (dict): Training metrics
path (str): Save path
Loads model checkpoint.
from src.utils import load_checkpoint
checkpoint = load_checkpoint(
path="checkpoints/model.pth",
model=model,
optimizer=optimizer
)
Parameters:
path (str): Checkpoint path
model (nn.Module, optional): Model to load weights into
optimizer (Optimizer, optional): Optimizer to load state into
Returns:
checkpoint (dict): Checkpoint dictionary
epoch (int): Saved epoch
metrics (dict): Saved metrics
model_state_dict (dict): Model weights
optimizer_state_dict (dict): Optimizer state
Example YAML configuration for training:
# experiments/configs/pcam.yaml
data:
root_dir: "data/pcam"
batch_size: 32
num_workers: 4
model:
architecture: "resnet18"
num_classes: 2
dropout: 0.5
training:
epochs: 10
learning_rate: 0.001
weight_decay: 0.0001
optimizer: "adam"
scheduler: "step"
step_size: 5
gamma: 0.1
logging:
log_interval: 10
checkpoint_dir: "checkpoints/pcam"
save_best: true
import torch
from torch.utils.data import DataLoader
from src.data import PatchCamelyonDataset
from src.models import SimpleClassifier
from src.training import train_epoch, evaluate
from src.utils import set_seed
# Set seed for reproducibility
set_seed(42)
# Create datasets
train_dataset = PatchCamelyonDataset("data/pcam", "train")
val_dataset = PatchCamelyonDataset("data/pcam", "val")
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Create model
model = SimpleClassifier(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Training setup
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(10):
train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
val_metrics = evaluate(model, val_loader, criterion, device)
print(f"Epoch {epoch+1}/10")
print(f"Train Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
print(f"Val Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
from src.models.pretrained import load_pretrained_encoder
import torch.nn as nn
# Load pretrained ResNet50
encoder = load_pretrained_encoder(
model_name="resnet50",
source="torchvision",
pretrained=True,
num_classes=2
)
# Get feature dimension
print(f"Feature dimension: {encoder.feature_dim}")
# Use in training
model = encoder.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)