A tested PyTorch framework for computational pathology research with working benchmarks on PatchCamelyon and CAMELYON16
Repository: matthewvaishnav/computational-pathology-research on GitHub
Detailed technical documentation of the multimodal fusion architecture.
┌─────────────────────────────────────────────────────────────────┐
│ Input Modalities │
├─────────────────────────────────────────────────────────────────┤
│ WSI Features │ Genomic Data │ Clinical Text │
│ [B, N, 1024] │ [B, 2000] │ [B, L] │
└────────┬─────────┴────────┬──────────┴────────┬─────────────────┘
│ │ │
▼ ▼ ▼
┌─────────────────────────────────────────────────────────────────┐
│ Modality-Specific Encoders │
├─────────────────────────────────────────────────────────────────┤
│ WSI Encoder │ Genomic Encoder │ Clinical Encoder │
│ (Attention) │ (MLP) │ (Transformer) │
│ 8.5M params │ 2.1M params │ 12.3M params │
└────────┬─────────┴────────┬──────────┴────────┬─────────────────┘
│ │ │
└──────────────────┼───────────────────┘
▼
┌─────────────────────────────────────────────────────────────────┐
│ Cross-Modal Attention Fusion │
├─────────────────────────────────────────────────────────────────┤
│ • Pairwise attention between all modalities │
│ • Learns cross-modal relationships │
│ • Handles missing modalities │
│ • 3.2M parameters │
└────────────────────────────┬────────────────────────────────────┘
│
▼
[B, embed_dim]
│
┌───────────────────┴───────────────────┐
│ │
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ Task-Specific │ │ Temporal Reasoning │
│ Heads │ │ (Optional) │
├─────────────────────┤ ├─────────────────────┤
│ • Classification │ │ • Temporal Attn │
│ • Survival │ │ • Progression │
│ • Segmentation │ │ • 467K params │
│ • 1.5M params │ └─────────────────────┘
└─────────────────────┘
│
▼
Predictions
WSI features format: [batch_size, num_patches, 1024]
Whole-Slide Image
│
▼
┌─────────────────┐
│ Patch Extraction│ (e.g., 256x256 patches)
└────────┬────────┘
│
▼
┌─────────────────┐
│ Feature Extract │ (e.g., pretrained ResNet-50 truncated after layer3, 1024-d)
└────────┬────────┘
│
▼
[N, 1024] features
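The patch-extraction step can be sketched as non-overlapping tiling of an in-memory RGB region with `Tensor.unfold`. This is a minimal sketch, assuming the region fits in memory. Real slides are usually read tile by tile with a WSI reader, and tissue/background filtering is omitted. `extract_patches` is a hypothetical helper, not part of the project.

```python
import torch


def extract_patches(image: torch.Tensor, patch_size: int = 256) -> torch.Tensor:
    """Tile a [C, H, W] image into non-overlapping [C, patch_size, patch_size] patches.

    Hypothetical helper: border pixels that do not fill a whole patch are dropped.
    """
    c = image.shape[0]
    # unfold height, then width -> [C, nH, nW, patch_size, patch_size]
    tiles = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # row-major patch order -> [nH * nW, C, patch_size, patch_size]
    return tiles.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_size, patch_size)


region = torch.rand(3, 1024, 768)
patches = extract_patches(region)  # 4 rows x 3 columns of 256x256 patches
```

Each patch would then go through the frozen feature extractor to produce one 1024-d row of the [N, 1024] feature matrix.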
Genomic data format: [batch_size, 2000]
Raw Genomic Data
│
▼
┌─────────────────┐
│ Gene Selection │ (e.g., top 2000 genes)
└────────┬────────┘
│
▼
┌─────────────────┐
│ Normalization │ (e.g., log-transform, z-score)
└────────┬────────┘
│
▼
[2000] features
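The gene-selection and normalization steps above can be sketched in NumPy. Several choices here are assumptions, not confirmed by the source:

- genes are ranked by variance of the log-transformed values;
- `log1p` is the log-transform;
- `preprocess_expression` is a hypothetical name.

In practice, the selected genes and the z-score statistics should be computed on the training split and reused on validation and test data.

```python
import numpy as np


def preprocess_expression(counts: np.ndarray, n_genes: int = 2000) -> np.ndarray:
    """counts: [num_samples, num_genes] non-negative expression values -> [num_samples, n_genes]."""
    logged = np.log1p(counts)                          # log-transform
    # keep the n_genes most variable genes (selection criterion is an assumption)
    top = np.argsort(logged.var(axis=0))[::-1][:n_genes]
    selected = logged[:, np.sort(top)]                 # keep original gene order
    mean = selected.mean(axis=0)
    std = selected.std(axis=0) + 1e-8                  # guard against zero variance
    return (selected - mean) / std                     # per-gene z-score
```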
Clinical text format: [batch_size, seq_len]
Clinical Notes
│
▼
┌─────────────────┐
│ Tokenization │ (e.g., WordPiece, BPE)
└────────┬────────┘
│
▼
┌─────────────────┐
│ Truncation/Pad │ (max length)
└────────┬────────┘
│
▼
[L] token IDs
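The truncation/padding step can be sketched as follows. This assumes tokenization has already been done by an external tokenizer and that `pad_id = 0`; `pad_or_truncate` is a hypothetical helper.

It also returns a boolean padding mask (True = padding), which is the convention `src_key_padding_mask` uses in `torch.nn.TransformerEncoder`.

```python
import torch


def pad_or_truncate(token_ids: list[list[int]], max_len: int, pad_id: int = 0):
    """Return [B, max_len] token IDs and a [B, max_len] padding mask (True = padding)."""
    batch = torch.full((len(token_ids), max_len), pad_id, dtype=torch.long)
    pad_mask = torch.ones(len(token_ids), max_len, dtype=torch.bool)
    for i, ids in enumerate(token_ids):
        ids = ids[:max_len]                                   # truncate long notes
        batch[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        pad_mask[i, : len(ids)] = False                       # real tokens are not padding
    return batch, pad_mask


ids, mask = pad_or_truncate([[5, 6, 7], [1, 2, 3, 4, 5, 6]], max_len=4)
```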
```python
import torch.nn as nn


class WSIEncoder(nn.Module):
    """
    Attention-based patch aggregation.
    Input: [B, N, 1024]
    Output: [B, embed_dim]
    """
```
Architecture:
Input: [B, N, 1024]
│
▼
┌─────────────────┐
│ Linear Proj │ 1024 → embed_dim
└────────┬────────┘
│
▼
┌─────────────────┐
│ Positional Enc │ Learnable
└────────┬────────┘
│
▼
┌─────────────────┐
│ Transformer │ num_layers layers, num_heads heads
│ Encoder │ Self-attention over patches
└────────┬────────┘
│
▼
┌─────────────────┐
│ Attention Pool │ Weighted average
└────────┬────────┘
│
▼
Output: [B, embed_dim]
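The WSI encoder pipeline in the diagram can be sketched in PyTorch. This is a minimal sketch, not the project's implementation; `max_patches` and the tanh scoring MLP used for attention pooling are assumptions. Learnable positions need an upper bound on N, so slides with more patches would have to be subsampled.

```python
import torch
import torch.nn as nn


class WSIEncoder(nn.Module):
    """Sketch: project patches, add learnable positions, self-attend, attention-pool."""

    def __init__(self, input_dim=1024, embed_dim=256, num_heads=8, num_layers=2,
                 max_patches=4096, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)                   # 1024 -> embed_dim
        # learnable positional embedding; max_patches is an assumed upper bound on N
        self.pos = nn.Parameter(torch.zeros(1, max_patches, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=4 * embed_dim,
                                           dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers)
        # small scoring MLP producing one attention logit per patch (an assumption)
        self.score = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Tanh(),
                                   nn.Linear(embed_dim, 1))

    def forward(self, x, pad_mask=None):
        # x: [B, N, input_dim]; pad_mask: [B, N], True where a patch slot is padding
        h = self.proj(x) + self.pos[:, : x.size(1)]
        h = self.transformer(h, src_key_padding_mask=pad_mask)        # self-attention over patches
        scores = self.score(h).squeeze(-1)                            # [B, N]
        if pad_mask is not None:
            scores = scores.masked_fill(pad_mask, float("-inf"))      # ignore padded patches
        weights = torch.softmax(scores, dim=1)                        # weights sum to 1 per slide
        return (weights.unsqueeze(-1) * h).sum(dim=1)                 # weighted average: [B, embed_dim]


slide_embedding = WSIEncoder()(torch.randn(2, 100, 1024))
```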
```python
import torch.nn as nn


class GenomicEncoder(nn.Module):
    """
    MLP with batch normalization.
    Input: [B, 2000]
    Output: [B, embed_dim]
    """
```
Architecture:
Input: [B, 2000]
│
▼
┌─────────────────┐
│ Linear │ 2000 → 1024
│ BatchNorm │
│ ReLU │
│ Dropout(0.1) │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ 1024 → 512
│ BatchNorm │
│ ReLU │
│ Dropout(0.1) │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ 512 → embed_dim
│ LayerNorm │
└────────┬────────┘
│
▼
Output: [B, embed_dim]
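The genomic MLP in the diagram translates directly into `nn.Sequential`. This is a sketch assuming `embed_dim = 256` from the hyperparameter table.

Counting this sketch's parameters gives 2,708,736. The 2.1M figure in the overview therefore presumably reflects a different configuration.

```python
import torch
import torch.nn as nn


class GenomicEncoder(nn.Module):
    """Sketch of the diagrammed MLP: 2000 -> 1024 -> 512 -> embed_dim."""

    def __init__(self, input_dim=2000, embed_dim=256, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, embed_dim), nn.LayerNorm(embed_dim),
        )

    def forward(self, x):
        # x: [B, input_dim] -> [B, embed_dim]; BatchNorm1d needs B > 1 in training mode
        return self.net(x)


genomic_embedding = GenomicEncoder()(torch.randn(4, 2000))
```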
```python
import torch.nn as nn


class ClinicalTextEncoder(nn.Module):
    """
    Transformer-based text encoder.
    Input: [B, L]
    Output: [B, embed_dim]
    """
```
Architecture:
Input: [B, L] token IDs
│
▼
┌─────────────────┐
│ Embedding │ vocab_size → embed_dim
└────────┬────────┘
│
▼
┌─────────────────┐
│ Positional Enc │ Sinusoidal
└────────┬────────┘
│
▼
┌─────────────────┐
│ Transformer │ num_layers layers, num_heads heads
│ Encoder │ Self-attention over tokens
└────────┬────────┘
│
▼
┌─────────────────┐
│ [CLS] Token │ First token embedding
└────────┬────────┘
│
▼
Output: [B, embed_dim]
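The clinical text encoder can be sketched as a token embedding plus a fixed sinusoidal position table, followed by a Transformer encoder read out at position 0. This is a minimal sketch under several assumptions, none of them confirmed by the source:

- the tokenizer places a [CLS] token at position 0;
- `pad_id = 0`;
- `vocab_size = 30522` is an assumed WordPiece vocabulary size;
- `max_len` is an assumed maximum note length.

```python
import math

import torch
import torch.nn as nn


class ClinicalTextEncoder(nn.Module):
    """Sketch: token embedding + sinusoidal positions + Transformer, read out at [CLS]."""

    def __init__(self, vocab_size=30522, embed_dim=256, num_heads=8, num_layers=2,
                 max_len=512, dropout=0.1, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        # fixed (non-trainable) sinusoidal table of shape [1, max_len, embed_dim]
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, 4 * embed_dim, dropout,
                                           batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, token_ids):
        # token_ids: [B, L]; padding tokens are excluded from attention
        pad_mask = token_ids == self.pad_id
        h = self.embed(token_ids) + self.pe[:, : token_ids.size(1)]
        h = self.transformer(h, src_key_padding_mask=pad_mask)
        return h[:, 0]  # [CLS] embedding: [B, embed_dim]


ids = torch.randint(1, 1000, (2, 16))
ids[1, 10:] = 0  # second note is padded after 10 tokens
text_embedding = ClinicalTextEncoder()(ids)
```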
```python
import torch.nn as nn


class CrossModalAttention(nn.Module):
    """
    Pairwise attention between modalities.
    Input: {modality: [B, embed_dim]}
    Output: [B, embed_dim] (concatenation [B, embed_dim * num_modalities], then projected)
    """
```
Architecture:
Modality Embeddings
WSI Genomic Clinical
│ │ │
└────────┼────────┘
│
▼
┌─────────────────────────────┐
│ Pairwise Attention │
│ │
│ WSI → Genomic │
│ WSI → Clinical │
│ Genomic → WSI │
│ Genomic → Clinical │
│ Clinical → WSI │
│ Clinical → Genomic │
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ Concatenate │
│ [WSI', Genomic', Clinical']│
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ Linear Projection │
│ 3*embed_dim → embed_dim │
└────────────┬────────────────┘
│
▼
Fused Embedding
[B, embed_dim]
Attention Mechanism:
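```python
import math

import torch

# For each ordered pair (modality_i, modality_j); x_i, x_j: [B, S, embed_dim]
# W_q, W_k, W_v: nn.Linear(embed_dim, d_k) projections
Q = W_q(x_i)  # queries from modality i
K = W_k(x_j)  # keys from modality j
V = W_v(x_j)  # values from modality j
attention = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
output = attention @ V  # modality i updated with information from modality j
```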
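A runnable single-head version of the formula above follows. `PairwiseCrossAttention` is a hypothetical name, and each modality is assumed to arrive as a token sequence [B, S, D].

The sketch also illustrates a caveat. When a modality is a single pooled [B, embed_dim] vector (S_j = 1), the softmax runs over one key and always equals 1. The output then reduces to the value projection of modality j. Attention weights only become informative when a modality contributes several tokens.

```python
import math

import torch
import torch.nn as nn


class PairwiseCrossAttention(nn.Module):
    """Single-head scaled dot-product attention in which modality i queries modality j."""

    def __init__(self, embed_dim=256):
        super().__init__()
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x_i, x_j):
        # x_i: [B, S_i, D] supplies queries; x_j: [B, S_j, D] supplies keys and values
        q, k, v = self.w_q(x_i), self.w_k(x_j), self.w_v(x_j)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # [B, S_i, S_j]
        attn = torch.softmax(scores, dim=-1)                        # each row sums to 1
        return attn @ v, attn                                       # [B, S_i, D], weights


m = PairwiseCrossAttention()
out, attn = m(torch.randn(2, 1, 256), torch.randn(2, 5, 256))
```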
Missing Modality Handling:
Available Modalities: {WSI, Genomic}
Missing: Clinical
┌─────────────────────────────┐
│ Compute Available Pairs │
│ • WSI → Genomic │
│ • Genomic → WSI │
│ (Skip Clinical pairs) │
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ Concatenate Available │
│ [WSI', Genomic', 0] │
│ (Zero for missing) │
└────────────┬────────────────┘
│
▼
Fused Embedding
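One way to assemble the fusion module, including the missing-modality path above, is sketched below. The following are assumptions, not taken from the source:

- one `nn.MultiheadAttention` block per ordered pair;
- attended outputs added residually to the querying modality;
- a missing modality signalled by omitting its key from the input dict;
- the modality names `wsi`, `genomic`, `clinical`.

```python
import torch
import torch.nn as nn

MODALITIES = ("wsi", "genomic", "clinical")  # fixed concatenation order (assumed names)


class CrossModalFusion(nn.Module):
    """Pairwise cross-attention over available modalities, zero-fill missing ones, project."""

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        # one attention block per ordered pair "i attends to j"
        self.attn = nn.ModuleDict({
            f"{i}_{j}": nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
            for i in MODALITIES for j in MODALITIES if i != j
        })
        self.proj = nn.Linear(len(MODALITIES) * embed_dim, embed_dim)  # 3*embed_dim -> embed_dim

    def forward(self, embeddings):
        # embeddings: {name: [B, embed_dim]}; at least one modality must be present
        ref = next(iter(embeddings.values()))
        fused = []
        for i in MODALITIES:
            if i not in embeddings:
                fused.append(torch.zeros_like(ref))        # zero for missing modality
                continue
            q = embeddings[i].unsqueeze(1)                  # [B, 1, D]
            out = embeddings[i]
            for j in MODALITIES:
                if j != i and j in embeddings:              # skip pairs involving a missing modality
                    kv = embeddings[j].unsqueeze(1)
                    attended, _ = self.attn[f"{i}_{j}"](q, kv, kv)
                    out = out + attended.squeeze(1)         # residual sum over partners
            fused.append(out)
        return self.proj(torch.cat(fused, dim=-1))          # [B, embed_dim]


fusion = CrossModalFusion()
fused = fusion({"wsi": torch.randn(2, 256), "genomic": torch.randn(2, 256)})  # clinical missing
```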
```python
import torch.nn as nn


class TemporalAttention(nn.Module):
    """
    Attention over slide sequence.
    Input: [B, T, embed_dim]
    Output: [B, T, embed_dim]
    """
```
Architecture:
Slide Sequence
[slide_0, slide_1, ..., slide_{T-1}]
│
▼
┌─────────────────────────────┐
│ Temporal Positional Enc │
│ Based on timestamps │
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ Transformer Encoder │
│ Self-attention over time │
└────────────┬────────────────┘
│
▼
Attended Slides
[B, T, embed_dim]
Temporal Encoding:
```python
# Convert timestamps to positional encoding:
# normalize times by the largest expected gap, then apply a sin/cos encoding
temporal_bins = timestamps / max_temporal_distance
temporal_enc = sinusoidal_encoding(temporal_bins)
```
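Here is a sketch of the two steps above. It applies the standard Transformer sin/cos formula to real-valued (normalized) times instead of integer positions. The signatures and the time unit (days) are assumptions for illustration.

```python
import math

import torch


def sinusoidal_encoding(positions: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Sin/cos encoding of real-valued positions: [B, T] -> [B, T, dim] (dim must be even)."""
    freqs = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    angles = positions.unsqueeze(-1).float() * freqs      # [B, T, dim // 2]
    enc = torch.zeros(*positions.shape, dim)
    enc[..., 0::2] = torch.sin(angles)                     # even channels: sine
    enc[..., 1::2] = torch.cos(angles)                     # odd channels: cosine
    return enc


def temporal_encoding(timestamps: torch.Tensor, max_temporal_distance: float, dim: int = 256):
    # timestamps: [B, T], e.g. days since the first biopsy (unit is an assumption)
    temporal_bins = timestamps / max_temporal_distance    # roughly in [0, 1]
    return sinusoidal_encoding(temporal_bins, dim)


enc = temporal_encoding(torch.tensor([[0.0, 30.0, 90.0, 365.0]]), max_temporal_distance=365.0)
```

Normalized values lie in [0, 1], and the fastest channel turns at 1 radian per unit. Most channels therefore change very little across a sequence. Multiplying the normalized times by a scale factor before encoding is one possible adjustment.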
Progression Modeling:
Attended Slides
[s_0, s_1, s_2, s_3]
│
▼
┌─────────────────────────────┐
│ Compute Differences │
│ Δ_1 = s_1 - s_0 │
│ Δ_2 = s_2 - s_1 │
│ Δ_3 = s_3 - s_2 │
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ Concatenate │
│ [s_i, Δ_i] for each pair │
└────────────┬────────────────┘
│
▼
┌─────────────────────────────┐
│ MLP Projection │
│ 2*embed_dim → embed_dim/2 │
└────────────┬────────────────┘
│
▼
Progression Features
[B, T-1, embed_dim/2]
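The progression diagram translates into a few tensor operations. In this sketch, each later slide s_i (i ≥ 1) is paired with Δ_i = s_i - s_{i-1} and the pair is projected to embed_dim // 2. The "MLP Projection" is assumed to be a single linear layer followed by ReLU. `ProgressionFeatures` is a hypothetical name.

```python
import torch
import torch.nn as nn


class ProgressionFeatures(nn.Module):
    """Sketch: concatenate [s_i, Δ_i] for i = 1..T-1 and project 2*embed_dim -> embed_dim // 2."""

    def __init__(self, embed_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim // 2), nn.ReLU())

    def forward(self, slides):
        # slides: [B, T, embed_dim] attended slide embeddings
        deltas = slides[:, 1:] - slides[:, :-1]                # Δ_i = s_i - s_{i-1}: [B, T-1, D]
        paired = torch.cat([slides[:, 1:], deltas], dim=-1)    # [s_i, Δ_i]: [B, T-1, 2D]
        return self.mlp(paired)                                # [B, T-1, D // 2]


progression = ProgressionFeatures()(torch.randn(2, 4, 256))
```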
Sequence Pooling:
Attended Slides + Progression
│
▼
┌─────────────────────────────┐
│ Pooling Strategy │
│ • Attention-weighted │
│ • Mean pooling │
│ • Max pooling │
│ • Last slide │
└────────────┬────────────────┘
│
▼
Sequence Representation
[B, embed_dim]
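The four pooling strategies can be sketched as one module. The source does not say how the [B, T-1, embed_dim/2] progression features are merged before pooling, so this sketch pools only the [B, T, embed_dim] slide embeddings. The single linear scoring layer for attention-weighted pooling is an assumption.

```python
import torch
import torch.nn as nn


class SequencePooling(nn.Module):
    """Collapse [B, T, D] slide embeddings to [B, D] with one of four strategies."""

    def __init__(self, embed_dim=256, strategy="attention"):
        super().__init__()
        if strategy not in {"attention", "mean", "max", "last"}:
            raise ValueError(f"unknown strategy: {strategy}")
        self.strategy = strategy
        self.score = nn.Linear(embed_dim, 1)  # used only by attention-weighted pooling

    def forward(self, x):
        if self.strategy == "mean":
            return x.mean(dim=1)
        if self.strategy == "max":
            return x.max(dim=1).values
        if self.strategy == "last":
            return x[:, -1]                                  # most recent slide
        weights = torch.softmax(self.score(x), dim=1)        # [B, T, 1], sums to 1 over T
        return (weights * x).sum(dim=1)


x = torch.randn(2, 5, 256)
pooled = SequencePooling(strategy="attention")(x)
```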
```python
import torch.nn as nn


class ClassificationHead(nn.Module):
    """
    Multi-class classification.
    Input: [B, embed_dim]
    Output: [B, num_classes]
    """
```
Architecture:
Input: [B, embed_dim]
│
▼
┌─────────────────┐
│ LayerNorm │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ embed_dim → embed_dim
│ ReLU │
│ Dropout(0.3) │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ embed_dim → num_classes
└────────┬────────┘
│
▼
Output: [B, num_classes] logits
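The classification head diagram maps one-to-one onto `nn.Sequential`. In this sketch, the default `num_classes = 2` (e.g., tumor vs. normal) is an assumption. The head returns raw logits, which is what `nn.CrossEntropyLoss` expects.

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """LayerNorm -> Linear -> ReLU -> Dropout(0.3) -> Linear, as in the diagram."""

    def __init__(self, embed_dim=256, num_classes=2, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)  # [B, num_classes] raw logits (no softmax)


logits = ClassificationHead()(torch.randn(4, 256))
```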
```python
import torch.nn as nn


class SurvivalPredictionHead(nn.Module):
    """
    Cox proportional hazards.
    Input: [B, embed_dim]
    Output: [B, 1] log hazard (risk score)
    """
```
Architecture:
Input: [B, embed_dim]
│
▼
┌─────────────────┐
│ LayerNorm │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ embed_dim → embed_dim//2
│ ReLU │
│ Dropout(0.3) │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Linear │ embed_dim//2 → 1
└────────┬────────┘
│
▼
Output: [B, 1] log hazard
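Here is a sketch of the survival head as diagrammed. The output is left unbounded because in a Cox model it is a log hazard: the model's relative risk is exp(h), so no output activation is applied.

```python
import torch
import torch.nn as nn


class SurvivalPredictionHead(nn.Module):
    """LayerNorm -> Linear(embed_dim, embed_dim // 2) -> ReLU -> Dropout(0.3) -> Linear(., 1)."""

    def __init__(self, embed_dim=256, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, 1),
        )

    def forward(self, x):
        return self.net(x)  # [B, 1] unbounded log hazard


log_hazard = SurvivalPredictionHead()(torch.randn(4, 256))
```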
Forward Pass:
Batch
│
├─> WSI Encoder ──────┐
│ │
├─> Genomic Encoder ──┼─> Cross-Modal ──> Fused
│ │ Attention Embedding
└─> Clinical Encoder ─┘ │
│
▼
Task Head
│
▼
Predictions
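Loss Functions:
```python
import torch.nn as nn

# Classification: predictions are raw logits [B, num_classes]
loss = nn.CrossEntropyLoss()(predictions, labels)
# Survival: CoxLoss is not part of torch.nn (Cox partial-likelihood loss)
loss = CoxLoss(hazards, survival_times, events)
# Multi-task: α and β weight the task losses
loss = α * classification_loss + β * survival_loss
```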
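Below is a sketch of a Cox loss, written as the negative Cox partial log-likelihood averaged over observed events. `cox_partial_likelihood_loss` is a hypothetical stand-in, not the project's `CoxLoss`.

Samples are sorted by descending time, so a running `logcumsumexp` yields log Σ exp(h_j) over each sample's risk set (all j with t_j ≥ t_k). Tied times are broken by sort order; no Breslow or Efron correction is applied.

```python
import torch


def cox_partial_likelihood_loss(log_hazards, times, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    log_hazards: [B] or [B, 1] model outputs; times: [B] follow-up times;
    events: [B] with 1 = event observed, 0 = censored.
    """
    log_hazards = log_hazards.reshape(-1)
    order = torch.argsort(times, descending=True)      # risk set of sample k = samples 0..k
    lh, ev = log_hazards[order], events[order].float()
    log_risk = torch.logcumsumexp(lh, dim=0)           # log sum over risk set of exp(h_j)
    num_events = ev.sum().clamp(min=1.0)               # avoid division by zero if all censored
    return -((lh - log_risk) * ev).sum() / num_events


loss = cox_partial_likelihood_loss(torch.zeros(2, 1), torch.tensor([2.0, 1.0]), torch.tensor([1, 1]))
```

In this two-patient example, the patient who fails at t = 1 contributes log(1/2) and the other patient contributes log(1). The loss is therefore log(2) / 2.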
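Optimization:
```python
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# T_max counts scheduler.step() calls, so the scheduler is stepped once per epoch
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    for batch in train_loader:
        # Training step (compute_loss: placeholder for one of the losses above)
        optimizer.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    scheduler.step()  # once per epoch, matching T_max=num_epochs
```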
| Parameter | Value | Notes |
|---|---|---|
| embed_dim | 256 | Embedding dimension |
| num_heads | 8 | Attention heads |
| num_layers | 2-4 | Transformer layers |
| dropout | 0.1-0.3 | Regularization |
| learning_rate | 1e-4 | AdamW |
| weight_decay | 0.01 | Decoupled weight decay (AdamW) |
| batch_size | 8-32 | Depends on memory |
| max_grad_norm | 1.0 | Gradient clipping |
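```python
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)  # Xavier/Glorot for linear layers
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)  # embeddings


model.apply(init_weights)
# Sinusoidal positional encodings are computed, not learned:
# pos_enc = sinusoidal_encoding(positions)
```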
| Component | Complexity | Notes |
|---|---|---|
| WSI Encoder | O(N² · d) | N patches, d embed_dim |
| Genomic Encoder | O(d²) | MLP layers |
| Clinical Encoder | O(L² · d) | L tokens |
| Cross-Modal Fusion | O(M² · d) | M modalities |
| Temporal Attention | O(T² · d) | T time points |
Total (dominant terms): O(N² · d + L² · d + T² · d)
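A quick calculation shows why the N² term dominates. The sketch below counts the bytes needed to materialize the FP32 attention-score matrices for one sample. It assumes `num_heads = 8` and `num_layers = 2` from the hyperparameter table. N = 10,000 patches and L = 512 tokens are illustrative values, not measurements from the project.

Memory-efficient attention kernels that avoid storing the full score matrix would change these numbers.

```python
def attention_scores_bytes(seq_len: int, num_heads: int = 8, num_layers: int = 2,
                           bytes_per_value: int = 4) -> int:
    """Bytes to store all [seq_len, seq_len] attention-score matrices for one sample."""
    return seq_len * seq_len * num_heads * num_layers * bytes_per_value


wsi = attention_scores_bytes(10_000)   # slide with 10,000 patches (illustrative)
text = attention_scores_bytes(512)     # clinical note with 512 tokens (illustrative)
print(f"{wsi / 1e9:.2f} GB vs {text / 1e6:.1f} MB")  # 6.40 GB vs 16.8 MB
```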
| Component | Memory | Notes |
|---|---|---|
| Model Parameters | 110MB | FP32 weights |
| Activations | ~2GB | Batch=16 |
| Gradients | 110MB | Same as params |
| Optimizer State | 220MB | AdamW (2x params) |
Total: ~2.5GB for training (batch=16)
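The table's totals can be reproduced approximately from the per-component parameter counts in the architecture overview. The arithmetic uses FP32 weights (4 bytes each), gradients of the same size, and two AdamW moment buffers. The 2 GB activation figure is taken from the table, not derived.

```python
# Per-component parameter counts from the architecture overview, in millions
PARAMS_M = {"wsi": 8.5, "genomic": 2.1, "clinical": 12.3,
            "fusion": 3.2, "heads": 1.5, "temporal": 0.467}


def training_memory_mb(num_params: float, activations_mb: float) -> dict:
    """FP32 weights + gradients + two AdamW moment buffers + activations, in MB."""
    weights = num_params * 4 / 1e6
    return {
        "weights": weights,
        "gradients": weights,            # one gradient per parameter
        "optimizer": 2 * weights,        # AdamW first and second moments
        "activations": activations_mb,
        "total": 4 * weights + activations_mb,
    }


mem = training_memory_mb(sum(PARAMS_M.values()) * 1e6, activations_mb=2000)
```

About 28.1M parameters give roughly 112 MB of weights and a total near 2.45 GB. This agrees with the table's ~110 MB and ~2.5 GB.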
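```python
import torch.nn as nn


# Add new modality: an encoder mapping it to [B, embed_dim]
class RadiomicsEncoder(nn.Module):
    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, x):
        return self.encoder(x)


# Integrate into fusion (the fusion projection then takes 4 * embed_dim inputs)
modalities['radiomics'] = radiomics_encoder(batch['radiomics'])
```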
```python
# Multiple heads
classification_head = ClassificationHead(embed_dim, num_classes)
survival_head = SurvivalPredictionHead(embed_dim)
segmentation_head = SegmentationHead(embed_dim)
# Combined loss
loss = (α * classification_loss +
        β * survival_loss +
        γ * segmentation_loss)
```
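```python
import torch.nn.functional as F

# Contrastive learning (SimCLR: placeholder for a SimCLR-style contrastive loss)
contrastive_loss = SimCLR(embeddings, augmented_embeddings)
# Masked reconstruction
masked_loss = F.mse_loss(reconstructed, original)
# Combined
pretrain_loss = contrastive_loss + masked_loss
```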
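A SimCLR-style contrastive term is usually the NT-Xent loss. This sketch assumes two views per sample and a temperature of 0.1; `nt_xent_loss` is a hypothetical name.

For row k, the positive is the other view of the same sample. The other 2B - 2 embeddings act as negatives, and self-similarity is masked out.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """NT-Xent loss; z1[k] and z2[k] are embeddings of two views of sample k, each [B, D]."""
    b = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)         # [2B, D], unit norm
    sim = z @ z.t() / temperature                               # scaled cosine similarities
    self_mask = torch.eye(2 * b, dtype=torch.bool)
    sim = sim.masked_fill(self_mask, float("-inf"))             # exclude self-pairs
    # the positive for row k is row k + B, and for row k + B it is row k
    targets = torch.cat([torch.arange(b, 2 * b), torch.arange(0, b)])
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
z1 = torch.randn(8, 32, requires_grad=True)
loss_same = nt_xent_loss(z1, z1)  # identical views: positives are perfectly aligned
```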
Last Updated: 2026-04-05
Version: 1.0.0
Status: Production-ready for research ✅