# CXR-Findings-AI / utils/complete_model.py
import os
import re
import logging

import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, AutoImageProcessor, GPT2Tokenizer

from utils.modifiedGPT2 import create_decoder
from utils.layer_mask import gaussian_layer_stack_pipeline
class DINOEncoder(nn.Module):
def __init__(self, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", freeze=True):
super().__init__()
self.model = AutoModel.from_pretrained(model_id)
if freeze:
for p in self.model.parameters():
p.requires_grad = False
    @torch.no_grad()  # NOTE: blocks gradients even when freeze=False; remove to fine-tune
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
pixel_values: [B, C, H, W]
returns patches: [B, Np, Cenc]
"""
out = self.model(pixel_values=pixel_values)
        tokens = out.last_hidden_state  # [B, 5 + Np, Cenc] for DINOv3 ViT (CLS + registers + patches)
        # DINOv3 ViT backbones prepend 1 CLS token and 4 register tokens;
        # adjust this offset if your backbone uses different special tokens.
        patches = tokens[:, 5:, :]  # [B, Np, Cenc]
return patches
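# Minimal shape-check sketch for DINOEncoder. The numbers below are
# assumptions for the default DINOv3 ViT-S/16 checkpoint (hidden size 384,
# patch size 16, 1 CLS + 4 register tokens) with 224x224 inputs; verify them
# before relying on this with another backbone.
def _demo_dino_encoder() -> None:
    encoder = DINOEncoder()
    pixel_values = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
    patches = encoder(pixel_values)
    # (224 / 16)^2 = 196 patch tokens remain after dropping 5 special tokens
    assert patches.shape == (2, 196, 384), patches.shape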
class DinoUNet(nn.Module):
def __init__(self, model_name="facebook/dinov3-convnext-small-pretrain-lvd1689m", freeze=True):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
# NOTE: confirm channels of the chosen hidden state; 768 is common for small convnext/dinov3
self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
self.decoder = nn.Sequential(
nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 1, 1)
)
if freeze:
for m in (self.encoder, self.channel_adapter, self.decoder):
for p in m.parameters():
p.requires_grad = False
    @torch.no_grad()  # NOTE: blocks gradients even when freeze=False; remove to fine-tune
def forward(self, x: torch.Tensor, num_layers: int) -> torch.Tensor:
"""
x: [B, C, H, W]; returns mask: [B, 1, H', W'] (your upsampling stack defines H',W')
"""
enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
# take the last 4D feature map from hidden_states
feats = next(h for h in reversed(enc_feats.hidden_states) if isinstance(h, torch.Tensor) and h.ndim == 4)
feats = self.channel_adapter(feats)
pred = self.decoder(feats) # (B,1,h,w)
        _, _, segmentation_mask = gaussian_layer_stack_pipeline(pred, n_layers=num_layers)
return segmentation_mask # [B, num_layers, h, w]
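# Minimal sketch of the DinoUNet decoder's spatial arithmetic, using a
# stand-in feature map instead of the real backbone. The 7x7 input assumes a
# ConvNeXt that downsamples 224x224 inputs by 32x; the two stride-2 transposed
# convolutions then upsample 4x, so the pre-mask prediction is [B, 1, 28, 28].
def _demo_dino_unet_decoder() -> None:
    decoder = nn.Sequential(
        nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
        nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
        nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
        nn.Conv2d(64, 1, 1),
    )
    feats = torch.randn(2, 512, 7, 7)  # stand-in for channel-adapted features
    pred = decoder(feats)
    assert pred.shape == (2, 1, 28, 28), pred.shape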
class LinearProjection(nn.Module):
def __init__(self, input_dim=384, output_dim=768, freeze=False):
super().__init__()
self.proj = nn.Linear(input_dim, output_dim)
if freeze:
for p in self.proj.parameters():
p.requires_grad = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, Np, input_dim] -> [B, Np, output_dim]
return self.proj(x)
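# Minimal sketch: LinearProjection maps DINO patch embeddings (384-d for
# ViT-S) into the GPT-2 embedding space (768-d for the base model); both
# dimensions are the module defaults above.
def _demo_linear_projection() -> None:
    proj = LinearProjection(input_dim=384, output_dim=768)
    out = proj(torch.randn(2, 196, 384))  # [B, Np, 384] -> [B, Np, 768]
    assert out.shape == (2, 196, 768), out.shape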
class CustomModel(nn.Module):
def __init__(
self,
device: str = "cuda",
ENCODER_MODEL_PATH: str | None = "dino_encoder.pth",
SEGMENTER_MODEL_PATH: str | None = "dino_segmenter.pth",
DECODER_MODEL_PATH: str | None = "dino_decoder.pth",
LINEAR_PROJECTION_PATH: str | None = "linear_projection.pth",
freeze_encoder: bool = True,
freeze_segmenter: bool = True,
freeze_linear_projection: bool = False,
freeze_decoder: bool = False,
attention_implementation: str = "sdpa",
):
super().__init__()
self.device = torch.device(device)
# Encoder
self.encoder = DINOEncoder()
if ENCODER_MODEL_PATH and os.path.exists(ENCODER_MODEL_PATH):
self.encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location="cpu"), strict=False)
print("Loaded encoder weights from", ENCODER_MODEL_PATH)
if freeze_encoder:
self.encoder.eval()
# Segmenter
self.segmenter = DinoUNet()
if SEGMENTER_MODEL_PATH and os.path.exists(SEGMENTER_MODEL_PATH):
self.segmenter.load_state_dict(torch.load(SEGMENTER_MODEL_PATH, map_location="cpu"), strict=False)
print("Loaded segmenter weights from", SEGMENTER_MODEL_PATH)
if freeze_segmenter:
self.segmenter.eval()
# Decoder (modified GPT-2)
self.decoder = create_decoder(attention=attention_implementation) # must expose .config.hidden_size & .config.num_hidden_layers
if DECODER_MODEL_PATH and os.path.exists(DECODER_MODEL_PATH):
self.decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location="cpu"), strict=False)
print("Loaded decoder weights from", DECODER_MODEL_PATH)
if freeze_decoder:
self.decoder.eval()
# Linear projection: DINO hidden -> GPT2 hidden
enc_h = self.encoder.model.config.hidden_size
dec_h = self.decoder.config.hidden_size
self.linear_projection = LinearProjection(input_dim=enc_h, output_dim=dec_h)
if LINEAR_PROJECTION_PATH and os.path.exists(LINEAR_PROJECTION_PATH):
self.linear_projection.load_state_dict(torch.load(LINEAR_PROJECTION_PATH, map_location="cpu"), strict=False)
print("Loaded linear projection weights from", LINEAR_PROJECTION_PATH)
if freeze_linear_projection:
self.linear_projection.eval()
# Tokenizer (pad token for GPT-2)
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.pad_token_id = self.tokenizer.pad_token_id # ✅ use ID, not string
self.num_layers = self.decoder.config.num_hidden_layers
# move everything once
self.to(self.device)
def forward(self, pixel_values: torch.Tensor, tgt_ids: torch.Tensor | None = None, **kwargs) -> dict:
"""
pixel_values: [B,C,H,W], float
tgt_ids: [B,T], long (token IDs), padded with pad_token_id if any padding is present
"""
pixel_values = pixel_values.to(self.device, non_blocking=True)
# Visual path
patches = self.encoder(pixel_values) # [B,Np,Cenc]
projected_patches = self.linear_projection(patches) # [B,Np,n_embd]
# Segmentation path per layer
segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder)
# Text path (optional teacher-forced training)
labels = None
if tgt_ids is not None:
if tgt_ids.dtype != torch.long:
tgt_ids = tgt_ids.long()
tgt_ids = tgt_ids.to(self.device, non_blocking=True) # [B,T]
text_embeds = self.decoder.transformer.wte(tgt_ids) # [B,T,n_embd]
inputs_embeds = torch.cat([projected_patches, text_embeds], dim=1) # [B,Np+T,n_embd]
# Labels: ignore prefix tokens (vision) and PADs in text
B, Np, _ = projected_patches.shape
labels_prefix = torch.full((B, Np), -100, device=self.device, dtype=torch.long)
text_labels = tgt_ids.clone()
text_labels[text_labels == self.pad_token_id] = -100 # ✅ compare to ID
labels = torch.cat([labels_prefix, text_labels], dim=1) # [B,Np+T]
else:
inputs_embeds = projected_patches
# Decoder forward
out = self.decoder(inputs_embeds=inputs_embeds, segmentation_mask=segmented_layers, labels=labels, **kwargs)
return out
@torch.inference_mode()
    def generate(
        self,
        pixel_values: torch.Tensor,
        max_new_tokens: int = 100,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, list[str], tuple | None]:
        """
        pixel_values: [B,C,H,W], float
        returns (generated_ids [B,T], generated_text, attentions or None)
        """
pixel_values = pixel_values.to(self.device, non_blocking=True)
# Visual path
patches = self.encoder(pixel_values) # [B,Np,Cenc]
projected_patches = self.linear_projection(patches) # [B,Np,n_embd]
# Segmentation path per layer
segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder)
# Generate
output = self.decoder.generate(
inputs_embeds=projected_patches,
max_new_tokens=max_new_tokens,
do_sample=False,
repetition_penalty=1.2,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.pad_token_id,
use_cache=True,
segmentation_mask=segmented_layers,
prefix_allowed_length=0,
plot_attention_mask=False,
plot_attention_mask_layer=[],
plot_attention_map=False,
plot_attention_map_layer=[],
plot_attention_map_generation=0,
output_attentions=output_attentions,
return_dict_in_generate=True,
)
        # When generating from inputs_embeds without input_ids, HF generate
        # returns only the newly generated tokens, so there is no vision
        # prefix to strip from output.sequences.
        generated_ids = output.sequences  # [B, T]
        generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_ids, generated_text, (output.attentions if output_attentions else None)
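# Minimal sketch of the teacher-forcing label layout built in forward():
# vision-prefix positions and PAD tokens are set to -100 so the cross-entropy
# loss ignores them. The ids below (pad_id=0, Np=3) are illustrative only.
def _demo_prefix_label_masking() -> None:
    pad_id = 0
    tgt_ids = torch.tensor([[5, 6, 7, pad_id]])          # [B=1, T=4]
    prefix = torch.full((1, 3), -100, dtype=torch.long)  # Np=3 vision slots
    text_labels = tgt_ids.clone()
    text_labels[text_labels == pad_id] = -100            # mask padding
    labels = torch.cat([prefix, text_labels], dim=1)     # [1, Np + T]
    assert labels.tolist() == [[-100, -100, -100, 5, 6, 7, -100]]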
def create_complete_model(device: str = "cuda", **kwargs) -> CustomModel:
model = CustomModel(device=device, **kwargs)
return model
def save_complete_model(model: CustomModel, save_path: str, device: str | None = None) -> None:
    # Ensure folder exists
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    # Save on CPU to keep the checkpoint portable
    orig_device = next(model.parameters()).device
    model.to("cpu")
    torch.save(model.state_dict(), save_path)
    print(f"Saved complete model weights to {save_path}")
    # Restore the model to the requested device, or to where it came from
    model.to(device if device is not None else orig_device)
def save_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, save_path: str) -> None:
# Ensure folder exists
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, save_path)
print(f"Saved checkpoint to {save_path}")
def load_complete_model(model: CustomModel, load_path: str, device: str = "cpu", strict: bool = True) -> CustomModel:
if not os.path.exists(load_path):
print(f"No weights found at {load_path}")
model.to(device)
return model
# Load to CPU first, then move to target device
state = torch.load(load_path, map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=strict)
if not strict:
if missing:
print(f"[load warning] Missing keys: {missing}")
if unexpected:
print(f"[load warning] Unexpected keys: {unexpected}")
model.to(device)
print(f"Loaded complete model weights from {load_path}")
return model
def load_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, load_path: str, device: str = "cpu") -> tuple[CustomModel, torch.optim.Optimizer]:
if not os.path.exists(load_path):
print(f"No checkpoint found at {load_path}")
model.to(device)
return model, optimizer
# Load to CPU first, then move to target device
checkpoint = torch.load(load_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.to(device)
print(f"Loaded checkpoint from {load_path}")
return model, optimizer
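# Hedged usage sketch for the save/load helpers above; paths, learning rate,
# and optimizer choice are illustrative, not prescribed by this module.
#
#   model = create_complete_model(device="cuda")
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   save_checkpoint(model, optimizer, "checkpoints/step_1000.pth")
#   model, optimizer = load_checkpoint(
#       model, optimizer, "checkpoints/step_1000.pth", device="cuda"
#   )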
# Configure basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ==============================================================================
# 1. Architecture Definition (MLP)
# ==============================================================================
class EmbeddingClassifier(nn.Module):
"""
Flexible MLP Classifier: Input Embeddings -> Hidden Layers -> Logits.
"""
    def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
                 activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
        super().__init__()
        layers = []
        dims = (embedding_dim,) + tuple(custom_dims)
        # Hidden blocks: Linear -> (LayerNorm | BatchNorm) -> activation -> dropout
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            if use_layernorm:
                layers.append(nn.LayerNorm(out_dim))
            elif bn:
                layers.append(nn.BatchNorm1d(out_dim))
            layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        # Final layer: last hidden dim -> num_classes (logits)
        layers.append(nn.Linear(custom_dims[-1], num_classes))
        self.classifier = nn.Sequential(*layers)
def forward(self, embeddings):
return self.classifier(embeddings)
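# Minimal shape-check sketch for EmbeddingClassifier; 384 matches the DINOv3
# ViT-S embedding size and 6 the default label set, but both are assumptions
# to adapt to your checkpoint.
def _demo_embedding_classifier() -> None:
    clf = EmbeddingClassifier(embedding_dim=384, num_classes=6)
    logits = clf(torch.randn(2, 384))  # [B, embedding_dim] -> [B, num_classes]
    probs = torch.sigmoid(logits)      # independent multi-label probabilities
    assert logits.shape == probs.shape == (2, 6)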
# ==============================================================================
# 2. Prediction Wrapper Class
# ==============================================================================
class ChestXrayPredictor:
"""
Wrapper class responsible for receiving an image, processing it,
and returning class probabilities.
"""
def __init__(self, base_model, classifier, processor, label_cols, device):
self.base_model = base_model
self.classifier = classifier
self.processor = processor
self.label_cols = label_cols
self.device = device
# Ensure models are in eval mode
self.base_model.eval()
self.classifier.eval()
def predict(self, image_source):
"""
Runs inference on a single image.
Args:
image_source: File path (str) or PIL.Image object.
Returns:
dict: { "Class_Name": probability (0.0 - 1.0) }
"""
try:
# 1. Flexible Input Handling (Path or Object)
if isinstance(image_source, str):
image = Image.open(image_source).convert('RGB')
else:
image = image_source.convert('RGB')
# 2. Preprocessing
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs['pixel_values'].to(self.device)
# 3. Inference
with torch.no_grad():
# A. Get Embeddings from DINO
outputs = self.base_model(pixel_values=pixel_values)
# Handle different transformer output formats
if hasattr(outputs, 'last_hidden_state'):
embeddings = outputs.last_hidden_state.mean(dim=1)
else:
embeddings = outputs[0].mean(dim=1)
# B. Classify Embeddings
logits = self.classifier(embeddings)
# Convert to standard Python float list for JSON serialization
probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
# 4. Format Output
return {
label: round(prob, 4)
for label, prob in zip(self.label_cols, probs)
}
except Exception as e:
logger.error(f"Error predicting image: {e}")
return {"error": str(e)}
# ==============================================================================
# 3. Factory Function (The "Builder")
# ==============================================================================
def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
"""
Loads the checkpoint, reconstructs the specific architecture,
and returns a ready-to-use ChestXrayPredictor instance.
Args:
checkpoint_path (str): Path to the .pth file.
model_id (str): HuggingFace model ID for DINO.
device (str, optional): 'cuda' or 'cpu'. Auto-detects if None.
Returns:
ChestXrayPredictor: Initialized object ready for prediction.
"""
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"🔄 Starting model initialization on: {device}")
try:
# A. Load Checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
label_cols = checkpoint.get('label_cols', [
"Cardiomegaly", "Consolidation", "Edema",
"Atelectasis", "Pleural Effusion", "No Findings"
])
# B. Load Base Model (DINO)
logger.info("🤖 Loading DINO backbone...")
base_model = AutoModel.from_pretrained(model_id).to(device)
# Load fine-tuned DINO weights if they exist in checkpoint
if 'base_model_state_dict' in checkpoint:
base_model.load_state_dict(checkpoint['base_model_state_dict'])
logger.info(" - Fine-tuned DINO weights loaded from checkpoint.")
else:
logger.info(" - Using default pre-trained DINO weights.")
processor = AutoImageProcessor.from_pretrained(model_id)
# C. Detect Embedding Dimension
if hasattr(base_model.config, 'hidden_size'):
embedding_dim = base_model.config.hidden_size
else:
# Dummy inference to detect output size
with torch.no_grad():
dummy = torch.randn(1, 3, 224, 224).to(device)
out = base_model(pixel_values=dummy)
embedding_dim = out.last_hidden_state.shape[-1]
# D. Reconstruct Classifier Architecture
logger.info("🏗️ Reconstructing classifier architecture...")
model_state = checkpoint['model_state_dict']
classifier = _build_mlp_from_state(model_state, embedding_dim)
# Load classifier weights
classifier.load_state_dict(model_state)
classifier.to(device)
logger.info("✅ Model created successfully.")
# E. Return the Wrapper Instance
return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
    except Exception as e:
        logger.error(f"❌ Fatal error creating the classifier: {e}")
        raise
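# Hedged usage sketch for the predictor factory; the checkpoint and image
# paths are illustrative placeholders.
#
#   predictor = create_classifier("checkpoints/classifier.pth")
#   probs = predictor.predict("chest_xray.png")
#   # -> {"Cardiomegaly": 0.12, "Consolidation": 0.03, ...}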
def _build_mlp_from_state(model_state, embedding_dim):
"""
Private helper function to inspect state_dict and rebuild the MLP architecture.
"""
linear_layers = []
for key, val in model_state.items():
# Look for 2D weights (Linear layers) inside the classifier
if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2:
match = re.search(r'classifier\.(\d+)\.weight', key)
if match:
layer_idx = int(match.group(1))
linear_layers.append((layer_idx, val.shape[1], val.shape[0])) # idx, in_features, out_features
if not linear_layers:
raise ValueError("No linear layers found in checkpoint. Check architecture.")
# Sort by layer index to ensure correct order
linear_layers.sort(key=lambda x: x[0])
num_classes = linear_layers[-1][2]
hidden_dims = tuple([x[2] for x in linear_layers[:-1]])
    # Detect normalization type: BatchNorm tracks running stats, LayerNorm does not
    uses_bn = any('running_mean' in k for k in model_state)
    has_norm = any(
        k.endswith('.weight') and model_state[k].ndim == 1
        for k in model_state if 'classifier' in k
    )
    uses_layernorm = has_norm and not uses_bn
return EmbeddingClassifier(
embedding_dim=embedding_dim,
num_classes=num_classes,
custom_dims=hidden_dims,
bn=uses_bn,
use_layernorm=uses_layernorm
)
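# Minimal sketch of the architecture round-trip that create_classifier relies
# on: an MLP is rebuilt from its state_dict alone. The dimensions here are
# illustrative.
def _demo_build_mlp_from_state() -> None:
    src = EmbeddingClassifier(embedding_dim=384, num_classes=6, custom_dims=(128, 64))
    state = src.state_dict()
    rebuilt = _build_mlp_from_state(state, embedding_dim=384)
    rebuilt.load_state_dict(state)  # succeeds only if the shapes match exactly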