Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, GPT2Tokenizer | |
| from utils.modifiedGPT2 import create_decoder | |
| from utils.layer_mask import gaussian_layer_stack_pipeline | |
class DINOEncoder(nn.Module):
    """Wrap a pretrained DINO backbone and expose only its patch tokens.

    Args:
        model_id: Hugging Face identifier passed to ``AutoModel.from_pretrained``.
        freeze: When true, ``requires_grad`` is switched off on every backbone
            parameter. The module's train/eval mode is not changed here.
    """

    # Register-token count assumed when the backbone config does not declare
    # one. 1 CLS + 4 registers reproduces the previously hard-coded offset of 5.
    _DEFAULT_REGISTER_TOKENS = 4

    def __init__(self, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", freeze=True):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_id)
        if freeze:
            for p in self.model.parameters():
                p.requires_grad = False

    def _num_special_tokens(self) -> int:
        """Return how many leading non-patch tokens (CLS + registers) to drop."""
        config = getattr(self.model, "config", None)
        registers = getattr(config, "num_register_tokens", None)
        if registers is None:
            registers = self._DEFAULT_REGISTER_TOKENS
        return 1 + int(registers)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        pixel_values: [B, C, H, W]
        returns patches: [B, Np, Cenc]
        """
        out = self.model(pixel_values=pixel_values)
        # [B, n_special + Np, Cenc]; assumes the backbone places CLS and
        # register tokens before the patch tokens -- TODO confirm for
        # backbones other than the default DINOv3 ViT.
        tokens = out.last_hidden_state
        patches = tokens[:, self._num_special_tokens():, :]  # [B, Np, Cenc]
        return patches
class DinoUNet(nn.Module):
    """Segmentation head on top of a pretrained DINO ConvNeXt backbone.

    The last 4-D hidden state of the backbone is reduced to 512 channels,
    passed through a small conv / transposed-conv decoder producing a
    single-channel map, and then handed to ``gaussian_layer_stack_pipeline``.

    Args:
        model_name: Hugging Face identifier passed to ``AutoModel.from_pretrained``.
        freeze: When true, ``requires_grad`` is switched off on the backbone,
            the channel adapter and the decoder.
    """

    def __init__(self, model_name="facebook/dinov3-convnext-small-pretrain-lvd1689m", freeze=True):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # NOTE: confirm channels of the chosen hidden state; 768 is common for small convnext/dinov3
        self.channel_adapter = nn.Conv2d(768, 512, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 2, stride=2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 1),
        )
        if freeze:
            for m in (self.encoder, self.channel_adapter, self.decoder):
                for p in m.parameters():
                    p.requires_grad = False

    def forward(self, x: torch.Tensor, num_layers: int) -> torch.Tensor:
        """Compute the per-layer segmentation mask for a batch of images.

        Args:
            x: Image batch, [B, C, H, W].
            num_layers: Forwarded as ``n_layers`` to ``gaussian_layer_stack_pipeline``.

        Returns:
            The third value returned by ``gaussian_layer_stack_pipeline``;
            per the original inline note its shape is [B, num_layers, h, w].

        Raises:
            ValueError: If the backbone returns no 4-D hidden state.
        """
        enc_feats = self.encoder(x, output_hidden_states=True, return_dict=True)
        # Take the last 4-D feature map from hidden_states. A default is given
        # to next() so a missing map surfaces as a clear error, not StopIteration.
        feats = next(
            (
                h
                for h in reversed(enc_feats.hidden_states)
                if isinstance(h, torch.Tensor) and h.ndim == 4
            ),
            None,
        )
        if feats is None:
            raise ValueError(
                "Encoder returned no 4D hidden state; cannot build the segmentation mask."
            )
        feats = self.channel_adapter(feats)
        pred = self.decoder(feats)  # (B,1,h,w)
        _, _, segmentation_mask = gaussian_layer_stack_pipeline(pred, n_layers=num_layers)
        return segmentation_mask  # [B, num_layers, h, w]
class LinearProjection(nn.Module):
    """Single affine layer mapping encoder features to the decoder width.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        output_dim: Size of the last dimension of the projected features.
        freeze: When true, the layer's weight and bias stop requiring gradients.
    """

    def __init__(self, input_dim=384, output_dim=768, freeze=False):
        super().__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)
        if not freeze:
            return
        # Same effect as clearing requires_grad on each parameter in turn.
        self.proj.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from [B, Np, input_dim] to [B, Np, output_dim]."""
        projected = self.proj(x)
        return projected
| class CustomModel(nn.Module): | |
| def __init__( | |
| self, | |
| device: str = "cuda", | |
| ENCODER_MODEL_PATH: str | None = "dino_encoder.pth", | |
| SEGMENTER_MODEL_PATH: str | None = "dino_segmenter.pth", | |
| DECODER_MODEL_PATH: str | None = "dino_decoder.pth", | |
| LINEAR_PROJECTION_PATH: str | None = "linear_projection.pth", | |
| freeze_encoder: bool = True, | |
| freeze_segmenter: bool = True, | |
| freeze_linear_projection: bool = False, | |
| freeze_decoder: bool = False, | |
| attention_implementation: str = "sdpa", | |
| ): | |
| super().__init__() | |
| self.device = torch.device(device) | |
| # Encoder | |
| self.encoder = DINOEncoder() | |
| if ENCODER_MODEL_PATH and os.path.exists(ENCODER_MODEL_PATH): | |
| self.encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location="cpu"), strict=False) | |
| print("Loaded encoder weights from", ENCODER_MODEL_PATH) | |
| if freeze_encoder: | |
| self.encoder.eval() | |
| # Segmenter | |
| self.segmenter = DinoUNet() | |
| if SEGMENTER_MODEL_PATH and os.path.exists(SEGMENTER_MODEL_PATH): | |
| self.segmenter.load_state_dict(torch.load(SEGMENTER_MODEL_PATH, map_location="cpu"), strict=False) | |
| print("Loaded segmenter weights from", SEGMENTER_MODEL_PATH) | |
| if freeze_segmenter: | |
| self.segmenter.eval() | |
| # Decoder (modified GPT-2) | |
| self.decoder = create_decoder(attention=attention_implementation) # must expose .config.hidden_size & .config.num_hidden_layers | |
| if DECODER_MODEL_PATH and os.path.exists(DECODER_MODEL_PATH): | |
| self.decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location="cpu"), strict=False) | |
| print("Loaded decoder weights from", DECODER_MODEL_PATH) | |
| if freeze_decoder: | |
| self.decoder.eval() | |
| # Linear projection: DINO hidden -> GPT2 hidden | |
| enc_h = self.encoder.model.config.hidden_size | |
| dec_h = self.decoder.config.hidden_size | |
| self.linear_projection = LinearProjection(input_dim=enc_h, output_dim=dec_h) | |
| if LINEAR_PROJECTION_PATH and os.path.exists(LINEAR_PROJECTION_PATH): | |
| self.linear_projection.load_state_dict(torch.load(LINEAR_PROJECTION_PATH, map_location="cpu"), strict=False) | |
| print("Loaded linear projection weights from", LINEAR_PROJECTION_PATH) | |
| if freeze_linear_projection: | |
| self.linear_projection.eval() | |
| # Tokenizer (pad token for GPT-2) | |
| self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
| if self.tokenizer.pad_token_id is None: | |
| self.tokenizer.pad_token = self.tokenizer.eos_token | |
| self.pad_token_id = self.tokenizer.pad_token_id # ✅ use ID, not string | |
| self.num_layers = self.decoder.config.num_hidden_layers | |
| # move everything once | |
| self.to(self.device) | |
| def forward(self, pixel_values: torch.Tensor, tgt_ids: torch.Tensor | None = None, **kwargs) -> dict: | |
| """ | |
| pixel_values: [B,C,H,W], float | |
| tgt_ids: [B,T], long (token IDs), padded with pad_token_id if any padding is present | |
| """ | |
| pixel_values = pixel_values.to(self.device, non_blocking=True) | |
| # Visual path | |
| patches = self.encoder(pixel_values) # [B,Np,Cenc] | |
| projected_patches = self.linear_projection(patches) # [B,Np,n_embd] | |
| # Segmentation path per layer | |
| segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder) | |
| # Text path (optional teacher-forced training) | |
| labels = None | |
| if tgt_ids is not None: | |
| if tgt_ids.dtype != torch.long: | |
| tgt_ids = tgt_ids.long() | |
| tgt_ids = tgt_ids.to(self.device, non_blocking=True) # [B,T] | |
| text_embeds = self.decoder.transformer.wte(tgt_ids) # [B,T,n_embd] | |
| inputs_embeds = torch.cat([projected_patches, text_embeds], dim=1) # [B,Np+T,n_embd] | |
| # Labels: ignore prefix tokens (vision) and PADs in text | |
| B, Np, _ = projected_patches.shape | |
| labels_prefix = torch.full((B, Np), -100, device=self.device, dtype=torch.long) | |
| text_labels = tgt_ids.clone() | |
| text_labels[text_labels == self.pad_token_id] = -100 # ✅ compare to ID | |
| labels = torch.cat([labels_prefix, text_labels], dim=1) # [B,Np+T] | |
| else: | |
| inputs_embeds = projected_patches | |
| # Decoder forward | |
| out = self.decoder(inputs_embeds=inputs_embeds, segmentation_mask=segmented_layers, labels=labels, **kwargs) | |
| return out | |
| def generate( | |
| self, | |
| pixel_values: torch.Tensor, | |
| max_new_tokens: int = 100, | |
| output_attentions: bool = False, | |
| ) -> torch.Tensor: | |
| """ | |
| pixel_values: [B,C,H,W], float | |
| returns generated_ids: [B, T] | |
| """ | |
| pixel_values = pixel_values.to(self.device, non_blocking=True) | |
| # Visual path | |
| patches = self.encoder(pixel_values) # [B,Np,Cenc] | |
| projected_patches = self.linear_projection(patches) # [B,Np,n_embd] | |
| # Segmentation path per layer | |
| segmented_layers = self.segmenter(pixel_values, self.num_layers) # [B,n_layers,H,W] (per current decoder) | |
| # Generate | |
| output = self.decoder.generate( | |
| inputs_embeds=projected_patches, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=False, | |
| repetition_penalty=1.2, | |
| eos_token_id=self.tokenizer.eos_token_id, | |
| pad_token_id=self.pad_token_id, | |
| use_cache=True, | |
| segmentation_mask=segmented_layers, | |
| prefix_allowed_length=0, | |
| plot_attention_mask=False, | |
| plot_attention_mask_layer=[], | |
| plot_attention_map=False, | |
| plot_attention_map_layer=[], | |
| plot_attention_map_generation=0, | |
| output_attentions=output_attentions, | |
| return_dict_in_generate=True, | |
| ) | |
| # Remove prefix tokens (vision) | |
| generated_ids = output.sequences#[:, projected_patches.shape[1]:] # [B,T] | |
| generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | |
| return generated_ids, generated_text, output.attentions if output_attentions else None | |
def create_complete_model(device: str = "cuda", **kwargs) -> CustomModel:
    """Build a ``CustomModel`` on ``device``, forwarding any extra keyword arguments."""
    return CustomModel(device=device, **kwargs)
def save_complete_model(model: "CustomModel", save_path: str, device: str = "cuda") -> None:
    """Write ``model.state_dict()`` to ``save_path`` as a CPU checkpoint.

    The model is moved to CPU for the save so the checkpoint stays portable,
    then moved back to the device its parameters were on before the call,
    even if the save fails.

    Args:
        model: Module whose weights are saved.
        save_path: Destination file; missing parent folders are created.
        device: Kept for backward compatibility and no longer used. The model
            always returns to its original device, which avoids moving a CPU
            model to ``"cuda"`` through the default value.
    """
    # Ensure folder exists
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    orig_device = next(model.parameters()).device
    model.to("cpu")
    try:
        torch.save(model.state_dict(), save_path)
        print(f"Saved complete model weights to {save_path}")
    finally:
        # Restore model device
        model.to(orig_device)
def save_checkpoint(model: CustomModel, optimizer: torch.optim.Optimizer, save_path: str) -> None:
    """Save the model and optimizer state dicts together in one file.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"model_state_dict"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``"optimizer_state_dict"``.
        save_path: Destination file; missing parent folders are created.
    """
    parent_dir = os.path.dirname(save_path) or "."
    os.makedirs(parent_dir, exist_ok=True)

    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    torch.save(
        {"model_state_dict": model_state, "optimizer_state_dict": optimizer_state},
        save_path,
    )
    print(f"Saved checkpoint to {save_path}")
def load_complete_model(model: CustomModel, load_path: str, device: str = "cpu", strict: bool = True) -> CustomModel:
    """Load a full state dict into ``model`` and move the model to ``device``.

    A missing file is not an error: a message is printed and the model is
    still moved to ``device`` and returned.

    Args:
        model: Module to load the weights into.
        load_path: Checkpoint file written with ``torch.save``.
        device: Device the model is moved to before returning.
        strict: Forwarded to ``load_state_dict``; when false, missing and
            unexpected keys are printed as warnings.

    Returns:
        The same ``model`` object.
    """
    if os.path.exists(load_path):
        # Deserialize on CPU first; the move to the target device happens below.
        weights = torch.load(load_path, map_location="cpu")
        result = model.load_state_dict(weights, strict=strict)
        missing, unexpected = result
        if not strict and missing:
            print(f"[load warning] Missing keys: {missing}")
        if not strict and unexpected:
            print(f"[load warning] Unexpected keys: {unexpected}")
        model.to(device)
        print(f"Loaded complete model weights from {load_path}")
        return model

    print(f"No weights found at {load_path}")
    model.to(device)
    return model
def load_checkpoint(model: "CustomModel", optimizer: torch.optim.Optimizer, load_path: str, device: str = "cpu") -> tuple["CustomModel", torch.optim.Optimizer]:
    """Restore model and optimizer state from a checkpoint file.

    A missing file is not an error: a message is printed, the model is moved
    to ``device`` and both objects are returned unchanged.

    Args:
        model: Module that receives ``checkpoint["model_state_dict"]``.
        optimizer: Optimizer that receives ``checkpoint["optimizer_state_dict"]``.
        load_path: Checkpoint file written by ``save_checkpoint``.
        device: Device the model is moved to.

    Returns:
        ``(model, optimizer)``, the same objects that were passed in.
    """
    if not os.path.exists(load_path):
        print(f"No checkpoint found at {load_path}")
        model.to(device)
        return model, optimizer

    # Deserialize on CPU first, then move to the target device.
    checkpoint = torch.load(load_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    # Move the model BEFORE restoring the optimizer. The optimizer casts its
    # state to the device of the parameters it tracks at load time, so doing
    # this afterwards would leave that state on the old device.
    model.to(device)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    print(f"Loaded checkpoint from {load_path}")
    return model, optimizer
| from transformers import AutoImageProcessor | |
| from PIL import Image | |
| import logging | |
| import re | |
# Configure basic logging.
# NOTE: basicConfig runs at import time, so importing this module sets up the
# root logger (INFO level, timestamped format) unless it already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the classifier code in this file.
logger = logging.getLogger(__name__)
| # ============================================================================== | |
| # 1. Architecture Definition (MLP) | |
| # ============================================================================== | |
class EmbeddingClassifier(nn.Module):
    """
    Flexible MLP Classifier: Input Embeddings -> Hidden Layers -> Logits.

    Each hidden block is ``Linear -> [LayerNorm | BatchNorm1d] -> activation
    -> [Dropout]``, followed by a final ``Linear`` producing the logits. The
    order of modules inside ``self.classifier`` determines the state-dict
    keys, so it must not change.

    Args:
        embedding_dim: Size of the input embeddings.
        num_classes: Number of output logits.
        custom_dims: Hidden layer widths. An empty sequence yields a single
            ``Linear(embedding_dim, num_classes)``.
        activation: ``"gelu"`` (case-insensitive) selects GELU; any other
            value selects ReLU.
        dropout: Dropout probability; no Dropout modules are added when ``<= 0``.
        bn: Use BatchNorm1d after each hidden Linear. Ignored when
            ``use_layernorm`` is true.
        use_layernorm: Use LayerNorm after each hidden Linear.
    """

    def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
                 activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
        super().__init__()
        use_gelu = activation.lower() == "gelu"

        layers = []
        in_features = embedding_dim
        # One loop covers the first and the intermediate layers; each block
        # feeds its width into the next one.
        for out_features in custom_dims:
            layers.append(nn.Linear(in_features, out_features))
            if use_layernorm:
                layers.append(nn.LayerNorm(out_features))
            elif bn:
                layers.append(nn.BatchNorm1d(out_features))
            layers.append(nn.GELU() if use_gelu else nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_features = out_features

        # Final layer: last hidden dim (or the embedding itself) -> logits
        layers.append(nn.Linear(in_features, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, embeddings):
        """Return the logits computed by the MLP for ``embeddings``."""
        return self.classifier(embeddings)
| # ============================================================================== | |
| # 2. Prediction Wrapper Class | |
| # ============================================================================== | |
class ChestXrayPredictor:
    """
    Wrapper class responsible for receiving an image, processing it,
    and returning class probabilities.

    Args:
        base_model: Backbone called as ``base_model(pixel_values=...)``.
        classifier: Module mapping pooled embeddings to logits.
        processor: Callable turning an image into a dict holding ``'pixel_values'``.
        label_cols: Class names, in the same order as the classifier's logits.
        device: Device the pixel tensor is moved to before inference.
    """

    def __init__(self, base_model, classifier, processor, label_cols, device):
        self.base_model = base_model
        self.classifier = classifier
        self.processor = processor
        self.label_cols = label_cols
        self.device = device
        # Ensure models are in eval mode
        self.base_model.eval()
        self.classifier.eval()

    def predict(self, image_source):
        """
        Runs inference on a single image.

        Args:
            image_source: File path (``str`` or ``os.PathLike``) or an object
                with a PIL-style ``convert`` method.

        Returns:
            dict: ``{"Class_Name": probability}`` rounded to 4 decimals, or
            ``{"error": message}`` when anything fails. Failures are logged
            rather than raised.
        """
        try:
            # 1. Flexible input handling (path or image object)
            if isinstance(image_source, (str, os.PathLike)):
                # convert() returns a new image, so the file can be closed
                # as soon as the context manager exits.
                with Image.open(image_source) as opened:
                    image = opened.convert('RGB')
            else:
                image = image_source.convert('RGB')

            # 2. Preprocessing
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(self.device)

            # 3. Inference
            with torch.no_grad():
                outputs = self.base_model(pixel_values=pixel_values)
                # Handle both attribute-style and tuple-style model outputs
                if hasattr(outputs, 'last_hidden_state'):
                    hidden = outputs.last_hidden_state
                else:
                    hidden = outputs[0]
                # Mean over dim 1 (presumably the token axis)
                embeddings = hidden.mean(dim=1)
                logits = self.classifier(embeddings)
                # First (only) sample -> plain Python floats for JSON serialization
                probs = torch.sigmoid(logits)[0].cpu().tolist()

            # 4. Format output
            return {
                label: round(prob, 4)
                for label, prob in zip(self.label_cols, probs)
            }
        except Exception as e:
            # Deliberate best-effort boundary: report the failure to the caller.
            logger.error("Error predicting image: %s", e)
            return {"error": str(e)}
| # ============================================================================== | |
| # 3. Factory Function (The "Builder") | |
| # ============================================================================== | |
def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
    """
    Loads the checkpoint, reconstructs the specific architecture,
    and returns a ready-to-use ChestXrayPredictor instance.

    Args:
        checkpoint_path (str): Path to the .pth file. It must contain
            ``'model_state_dict'``; ``'label_cols'`` and
            ``'base_model_state_dict'`` are optional.
        model_id (str): HuggingFace model ID for DINO.
        device (str, optional): 'cuda' or 'cpu'. Auto-detects if None.

    Returns:
        ChestXrayPredictor: Initialized object ready for prediction.

    Raises:
        Exception: Any failure during loading is logged and re-raised unchanged.
    """
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info("🔄 Starting model initialization on: %s", device)
    try:
        # A. Load checkpoint
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        label_cols = checkpoint.get('label_cols', [
            "Cardiomegaly", "Consolidation", "Edema",
            "Atelectasis", "Pleural Effusion", "No Findings"
        ])

        # B. Load base model (DINO)
        logger.info("🤖 Loading DINO backbone...")
        base_model = AutoModel.from_pretrained(model_id).to(device)
        # Load fine-tuned DINO weights if they exist in checkpoint
        if 'base_model_state_dict' in checkpoint:
            base_model.load_state_dict(checkpoint['base_model_state_dict'])
            logger.info(" - Fine-tuned DINO weights loaded from checkpoint.")
        else:
            logger.info(" - Using default pre-trained DINO weights.")
        processor = AutoImageProcessor.from_pretrained(model_id)

        # C. Detect embedding dimension
        if hasattr(base_model.config, 'hidden_size'):
            embedding_dim = base_model.config.hidden_size
        else:
            # Dummy inference to detect output size
            with torch.no_grad():
                dummy = torch.randn(1, 3, 224, 224).to(device)
                out = base_model(pixel_values=dummy)
            embedding_dim = out.last_hidden_state.shape[-1]

        # D. Reconstruct classifier architecture and load its weights
        logger.info("🏗️ Reconstructing classifier architecture...")
        model_state = checkpoint['model_state_dict']
        classifier = _build_mlp_from_state(model_state, embedding_dim)
        classifier.load_state_dict(model_state)
        classifier.to(device)
        logger.info("✅ Model created successfully.")

        # E. Return the wrapper instance
        return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
    except Exception as e:
        logger.error("❌ Fatal error creating the classifier: %s", e)
        # Bare raise keeps the original traceback intact.
        raise
def _build_mlp_from_state(model_state, embedding_dim):
    """
    Private helper function to inspect state_dict and rebuild the MLP architecture.

    Hidden widths and the class count are read from the 2-D ``classifier.<i>.weight``
    entries. The normalization type is inferred from the presence of
    ``running_mean`` keys (BatchNorm) or 1-D ``.weight`` entries (LayerNorm).
    Whether Dropout modules were present is inferred from the spacing between
    consecutive Linear indices, because Dropout has no parameters but still
    shifts the ``nn.Sequential`` indices the state-dict keys depend on.

    Args:
        model_state: State dict of the classifier.
        embedding_dim: Input size passed to ``EmbeddingClassifier``.

    Returns:
        EmbeddingClassifier: Freshly built model with no weights loaded yet.

    Raises:
        ValueError: If no 2-D ``classifier.<i>.weight`` entry is found.
    """
    linear_layers = []
    for key, val in model_state.items():
        # Look for 2D weights (Linear layers) inside the classifier
        if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2:
            match = re.search(r'classifier\.(\d+)\.weight', key)
            if match:
                layer_idx = int(match.group(1))
                # idx, in_features, out_features
                linear_layers.append((layer_idx, val.shape[1], val.shape[0]))
    if not linear_layers:
        raise ValueError("No linear layers found in checkpoint. Check architecture.")

    # Sort by layer index to ensure correct order
    linear_layers.sort(key=lambda x: x[0])
    num_classes = linear_layers[-1][2]
    hidden_dims = tuple(x[2] for x in linear_layers[:-1])

    # Detect normalization types
    uses_bn = any('running_mean' in k for k in model_state)
    has_norm = any(
        k.endswith('.weight') and len(model_state[k].shape) == 1
        for k in model_state
        if 'classifier' in k
    )
    uses_layernorm = has_norm and not uses_bn

    # Detect Dropout. A hidden block is Linear + [norm] + activation + [Dropout],
    # so without Dropout the Linear indices are 2 (+1 with a norm layer) apart.
    extra_kwargs = {}
    if len(linear_layers) >= 2:
        stride = linear_layers[1][0] - linear_layers[0][0]
        stride_without_dropout = 2 + int(uses_bn or has_norm)
        if stride == stride_without_dropout:
            extra_kwargs["dropout"] = 0.0

    return EmbeddingClassifier(
        embedding_dim=embedding_dim,
        num_classes=num_classes,
        custom_dims=hidden_dims,
        bn=uses_bn,
        use_layernorm=uses_layernorm,
        **extra_kwargs,
    )