Multimodal Fusion Architecture
This document describes the multimodal fusion architecture for computational pathology, including cross-modal attention, temporal reasoning, and task-specific prediction heads.
Architecture Overview
The complete architecture consists of four main components:
- Modality-Specific Encoders (`src/models/encoders.py`)
- Cross-Modal Attention Fusion (`src/models/fusion.py`)
- Temporal Reasoning (`src/models/temporal.py`)
- Task-Specific Prediction Heads (`src/models/heads.py`)
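To make the cross-modal attention fusion step concrete, here is a minimal PyTorch sketch. It assumes each modality has already been encoded to a `[batch_size, embed_dim]` vector. The class name `CrossModalFusionSketch` is hypothetical and this is not the code in `src/models/fusion.py`. Each modality embedding acts as a query over the embeddings of the *other* present modalities (a diagonal attention mask blocks self-attention), and missing modalities passed as `None` are simply left out before attention.

```python
import torch
import torch.nn as nn


class CrossModalFusionSketch(nn.Module):
    """Hypothetical sketch of cross-modal attention fusion over modality embeddings."""

    def __init__(self, embed_dim: int = 256, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, embeddings: dict) -> torch.Tensor:
        # Missing modalities are passed as None and simply left out of fusion.
        present = [e for e in embeddings.values() if e is not None]
        if not present:
            raise ValueError("at least one modality embedding is required")
        if len(present) == 1:
            return present[0]  # nothing to attend to
        tokens = torch.stack(present, dim=1)  # [B, M, D]: one token per modality
        m = tokens.size(1)
        # True entries are blocked: masking the diagonal makes every modality
        # attend only to the other modalities (cross-modal, not self-attention).
        mask = torch.eye(m, dtype=torch.bool, device=tokens.device)
        attended, _ = self.attn(tokens, tokens, tokens, attn_mask=mask)
        fused_tokens = self.norm(tokens + attended)  # residual connection + LayerNorm
        return fused_tokens.mean(dim=1)  # [B, D]: average over modalities


# Usage: three modalities, then the same call with genomic data missing
fusion = CrossModalFusionSketch(embed_dim=256, num_heads=8)
embeddings = {
    "wsi": torch.randn(4, 256),
    "genomic": torch.randn(4, 256),
    "clinical": torch.randn(4, 256),
}
print(fusion(embeddings).shape)  # torch.Size([4, 256])
embeddings["genomic"] = None
print(fusion(embeddings).shape)  # torch.Size([4, 256])
```

Masking the diagonal is one design choice. Strictly pairwise attention (one attention call per ordered pair of modalities) is an alternative, but with single-vector embeddings each such call attends to exactly one key, so its softmax weight is always 1 and the attention reduces to a linear projection of the other modality.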
Components
1. Multimodal Fusion Model
`MultimodalFusionModel` wraps the modality-specific encoders and the cross-modal fusion layer, mapping a batch of modality inputs to one fused embedding per sample:
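```python
import torch
from src.models import MultimodalFusionModel

# Example sizes; num_genes and vocab_size must match the genomic and text
# encoder configurations (these values mirror the end-to-end example below)
batch_size, num_patches, num_genes = 8, 100, 2000
vocab_size, seq_len = 30000, 128

# Initialize model
model = MultimodalFusionModel(embed_dim=256)

# Prepare batch
batch = {
    'wsi_features': torch.randn(batch_size, num_patches, 1024),  # pre-extracted patch features
    'genomic': torch.randn(batch_size, num_genes),
    'clinical_text': torch.randint(0, vocab_size, (batch_size, seq_len))  # token IDs
}

# Get fused embedding
fused_embedding = model(batch)  # [batch_size, 256]
```
Handling Missing Modalities: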
To mark a modality as missing, set its entry in the batch to `None`; the model still returns a fused embedding from the remaining modalities:
```python
batch = {
    'wsi_features': torch.randn(batch_size, num_patches, 1024),
    'genomic': None,  # Missing genomic data
    'clinical_text': torch.randint(0, vocab_size, (batch_size, seq_len))
}
fused_embedding = model(batch)  # Still works!
```
2. Cross-Modal Attention Fusion
`MultiModalFusionLayer` applies pairwise cross-modal attention to a dictionary of per-modality embeddings and returns a single fused embedding:
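```python
import torch
from src.models import MultiModalFusionLayer

fusion = MultiModalFusionLayer(embed_dim=256, num_heads=8)

# Per-modality embeddings, already projected to the shared embed_dim
batch_size = 8
wsi_embeddings = torch.randn(batch_size, 256)
genomic_embeddings = torch.randn(batch_size, 256)
clinical_embeddings = torch.randn(batch_size, 256)

embeddings = {
    'wsi': wsi_embeddings,  # [batch_size, 256]
    'genomic': genomic_embeddings,  # [batch_size, 256]
    'clinical': clinical_embeddings  # [batch_size, 256]
}
fused = fusion(embeddings)  # [batch_size, 256]
```
3. Temporal Reasoning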
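To model disease progression across multiple slides from the same patient, `CrossSlideTemporalReasoner` attends over a time-stamped sequence of slide embeddings: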
```python
import torch
from src.models import CrossSlideTemporalReasoner

temporal_reasoner = CrossSlideTemporalReasoner(embed_dim=256)

# Slide sequence from the same patient
batch_size, num_slides = 8, 5
slide_embeddings = torch.randn(batch_size, num_slides, 256)
# Time of each slide in days, one row per sequence: [batch_size, num_slides]
timestamps = torch.tensor([[0, 30, 90, 180, 365]]).float().repeat(batch_size, 1)

# Get sequence-level embedding and progression features
sequence_emb, progression_features = temporal_reasoner(
    slide_embeddings,
    timestamps
)
```
Features:
- Temporal attention with positional encoding
- Progression feature extraction (differences between consecutive slides)
- Multiple pooling strategies: attention, mean, max, last
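The features listed above can be illustrated with a minimal PyTorch sketch. It is not the code in `src/models/temporal.py`: `TemporalReasonerSketch` and `time_encoding` are hypothetical names, and encoding irregular slide times (in days) with sinusoids of the raw timestamp is an assumption about how a temporal positional encoding could work. Progression features are the differences between consecutive slide embeddings, and pooling is selectable among attention, mean, max, and last.

```python
import math

import torch
import torch.nn as nn


def time_encoding(timestamps: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal encoding of slide times in days: [B, T] -> [B, T, dim] (dim even)."""
    half = dim // 2
    idx = torch.arange(half, dtype=torch.float32, device=timestamps.device)
    freqs = torch.exp(-math.log(10000.0) * idx / half)  # geometric frequency ladder
    angles = timestamps.float().unsqueeze(-1) * freqs  # [B, T, half]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class TemporalReasonerSketch(nn.Module):
    """Hypothetical: time-aware attention over a patient's slide sequence."""

    def __init__(self, embed_dim: int = 256, num_heads: int = 8, pooling: str = "attention"):
        super().__init__()
        if pooling not in {"attention", "mean", "max", "last"}:
            raise ValueError(f"unknown pooling: {pooling}")
        self.pooling = pooling
        self.encoder = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.score = nn.Linear(embed_dim, 1)  # only used by attention pooling

    def forward(self, slides: torch.Tensor, timestamps: torch.Tensor):
        # Progression features: differences between consecutive slide embeddings.
        progression = slides[:, 1:] - slides[:, :-1]  # [B, T-1, D]
        # Add the time encoding, then let every slide attend to the others.
        x = self.encoder(slides + time_encoding(timestamps, slides.size(-1)))  # [B, T, D]
        if self.pooling == "attention":
            weights = torch.softmax(self.score(x), dim=1)  # [B, T, 1], sums to 1 over T
            pooled = (weights * x).sum(dim=1)
        elif self.pooling == "mean":
            pooled = x.mean(dim=1)
        elif self.pooling == "max":
            pooled = x.max(dim=1).values
        else:  # "last": embedding of the most recent slide
            pooled = x[:, -1]
        return pooled, progression


reasoner = TemporalReasonerSketch(embed_dim=256, pooling="attention")
slides = torch.randn(8, 5, 256)
timestamps = torch.tensor([[0, 30, 90, 180, 365]]).float().repeat(8, 1)
sequence_emb, progression = reasoner(slides, timestamps)
print(sequence_emb.shape, progression.shape)  # torch.Size([8, 256]) torch.Size([8, 4, 256])
```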
4. Task-Specific Prediction Heads
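Each head operates on the fused embedding. As a hedged sketch of what the classification and discrete-time survival heads might compute (hypothetical class names, not the code in `src/models/heads.py`), a classification head can be dropout followed by a linear layer, and a discrete-time head can predict one hazard h_k per time bin through a sigmoid, which is an assumed parameterization. The survival curve then follows as the cumulative product S_k = ∏_{j≤k} (1 − h_j):

```python
import torch
import torch.nn as nn


class ClassificationHeadSketch(nn.Module):
    """Hypothetical: dropout followed by a linear layer that outputs class logits."""

    def __init__(self, input_dim: int = 256, num_classes: int = 5, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(nn.Dropout(dropout), nn.Linear(input_dim, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # [B, num_classes], unnormalized logits


class DiscreteHazardHeadSketch(nn.Module):
    """Hypothetical: one hazard per time bin, h_k = sigmoid(logit_k)."""

    def __init__(self, input_dim: int = 256, num_time_bins: int = 12):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_time_bins)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h_k: probability of the event in bin k given survival up to bin k
        return torch.sigmoid(self.linear(x))  # [B, K]

    def survival_curve(self, x: torch.Tensor) -> torch.Tensor:
        # S_k = prod_{j<=k} (1 - h_j): probability of surviving past bin k
        return torch.cumprod(1.0 - self(x), dim=1)  # [B, K], non-increasing in k


embeddings = torch.randn(8, 256)
logits = ClassificationHeadSketch()(embeddings)
hazard_head = DiscreteHazardHeadSketch()
print(logits.shape, hazard_head.survival_curve(embeddings).shape)
# torch.Size([8, 5]) torch.Size([8, 12])
```

Because every factor 1 − h_j lies in (0, 1), the cumulative product is guaranteed to be a valid, non-increasing survival curve without any extra constraint on the network.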
Classification
```python
import torch
from src.models import ClassificationHead

embeddings = torch.randn(8, 256)  # e.g. fused embeddings, [batch_size, 256]

classifier = ClassificationHead(
    input_dim=256,
    num_classes=5,
    dropout=0.3
)
logits = classifier(embeddings)  # [batch_size, 5]
```
Survival Prediction
```python
import torch
from src.models import SurvivalPredictionHead

embeddings = torch.randn(8, 256)  # [batch_size, 256]

# Risk score prediction
risk_head = SurvivalPredictionHead(input_dim=256)
risk_scores = risk_head(embeddings)  # [batch_size, 1]

# Discrete-time hazard prediction
hazard_head = SurvivalPredictionHead(
    input_dim=256,
    num_time_bins=12
)
hazards = hazard_head(embeddings, return_hazards=True)  # [batch_size, 12]
survival_curve = hazard_head.compute_survival_curve(embeddings)  # [batch_size, 12]
```
Multi-Task Learning
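```python
import torch
from src.models import MultiTaskHead

embeddings = torch.randn(8, 256)  # [batch_size, 256]

multi_head = MultiTaskHead(
    input_dim=256,
    classification_config={'num_classes': 5},
    survival_config={'num_time_bins': 12}
)
# One forward pass returns the outputs of both tasks
class_logits, survival_output = multi_head(embeddings)
```
Complete End-to-End Example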
```python
import torch
from src.models import (
    MultimodalFusionModel,
    CrossSlideTemporalReasoner,
    ClassificationHead,
    SurvivalPredictionHead
)

# 1. Multimodal fusion
fusion_model = MultimodalFusionModel(embed_dim=256)
batch = {
    'wsi_features': torch.randn(8, 100, 1024),
    'genomic': torch.randn(8, 2000),
    'clinical_text': torch.randint(0, 30000, (8, 128))
}
fused_embeddings = fusion_model(batch)

# 2. Temporal reasoning (optional, for longitudinal data)
temporal_reasoner = CrossSlideTemporalReasoner(embed_dim=256)
slide_sequence = torch.randn(8, 5, 256)
timestamps = torch.tensor([[0, 30, 90, 180, 365]] * 8).float()
sequence_emb, progression_features = temporal_reasoner(
    slide_sequence,
    timestamps
)

# 3. Task-specific predictions
classifier = ClassificationHead(input_dim=256, num_classes=4)
class_logits = classifier(fused_embeddings)
survival_head = SurvivalPredictionHead(input_dim=256, num_time_bins=12)
survival_hazards = survival_head(fused_embeddings, return_hazards=True)
```
Configuration
Encoder hyperparameters can be customized by passing configuration dictionaries to `MultimodalFusionModel`:
```python
from src.models import MultimodalFusionModel

# Custom encoder configurations
wsi_config = {
    'input_dim': 2048,  # must match the last dimension of 'wsi_features'
    'hidden_dim': 1024,
    'output_dim': 512,
    'num_heads': 16,
    'num_layers': 4,
    'dropout': 0.2,
    'pooling': 'attention'
}
genomic_config = {
    'input_dim': 5000,  # must match the number of genes in 'genomic'
    'hidden_dims': [2048, 1024, 512],
    'output_dim': 512,
    'dropout': 0.4,
    'use_batch_norm': True
}
model = MultimodalFusionModel(
    wsi_config=wsi_config,
    genomic_config=genomic_config,
    embed_dim=512
)
```
Key Features
- Modular Design: Each component can be used independently
- Missing Modality Handling: Any modality can be set to `None` in the input batch
- Flexible Architecture: Configurable dimensions, layers, and pooling strategies
- Temporal Reasoning: Built-in support for longitudinal analysis
- Multi-Task Learning: Joint classification and survival prediction with `MultiTaskHead`
- Research-Ready: Designed for experimentation and ablation studies
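The configurable dimensions and layers mentioned above can be illustrated with the `genomic_config` dictionary from the Configuration section. The following is a minimal sketch, assuming a plain feed-forward encoder; the function name `build_genomic_encoder` and the Linear, BatchNorm, ReLU, Dropout ordering are assumptions, not taken from `src/models/encoders.py`.

```python
import torch
import torch.nn as nn


def build_genomic_encoder(config: dict) -> nn.Sequential:
    """Hypothetical builder: maps a genomic_config dict onto a feed-forward encoder."""
    layers = []
    in_dim = config["input_dim"]
    for hidden in config["hidden_dims"]:
        layers.append(nn.Linear(in_dim, hidden))
        if config.get("use_batch_norm", False):
            layers.append(nn.BatchNorm1d(hidden))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(config.get("dropout", 0.0)))
        in_dim = hidden
    # Final projection to the encoder's output embedding size.
    layers.append(nn.Linear(in_dim, config["output_dim"]))
    return nn.Sequential(*layers)


genomic_config = {
    "input_dim": 5000,
    "hidden_dims": [2048, 1024, 512],
    "output_dim": 512,
    "dropout": 0.4,
    "use_batch_norm": True,
}
encoder = build_genomic_encoder(genomic_config)
print(encoder(torch.randn(4, 5000)).shape)  # torch.Size([4, 512])
```

Driving the layer stack from a list such as `hidden_dims` is what makes depth and width changes a configuration edit rather than a code change, which is convenient for ablation studies.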
References
- See `src/models/encoders.py` for modality-specific encoder implementations
- See `src/models/fusion.py` for cross-modal attention mechanisms
- See `src/models/multimodal.py` for the complete fusion model
- See `src/models/temporal.py` for temporal reasoning components
- See `src/models/heads.py` for task-specific prediction heads