serdaryildiz's picture
Update Model/dino/dino.py
f81a237 verified
import numpy
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
# Image -> tensor pipeline applied before images are fed to the encoder.
# The mean/std values look like the usual ImageNet channel statistics
# (NOTE(review): confirm they match what the chosen backbone was trained with).
preprocess = transforms.Compose(
    [
        # Fixed square input; 224 is what the rest of this module assumes.
        transforms.Resize((224, 224)),
        # PIL image / uint8 array -> float tensor.
        transforms.ToTensor(),
        # Per-channel standardisation.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class DinoV2(nn.Module):
    """Wrap a DINOv2 backbone loaded through ``torch.hub``.

    The backbone is stored on ``self.vision_encoder``. ``forward`` returns the
    backbone's normalized patch tokens, and ``get_output_dim`` reports the size
    of the feature (last) dimension.
    """

    def __init__(self, model_name: str) -> None:
        """Load the backbone called *model_name* and put it in eval mode.

        Args:
            model_name: Entry-point name handed to ``torch.hub.load`` for the
                ``facebookresearch/dinov2`` hub repository.
        """
        super().__init__()
        # torch.hub.load may fetch the repository and weights on first use.
        encoder = torch.hub.load('facebookresearch/dinov2', model_name)
        # NOTE: eval() only sets the mode once, here; a later .train() call on
        # this module (or a parent) switches the encoder back to train mode.
        self.vision_encoder = encoder.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoder's normalized patch tokens for *x*.

        Args:
            x: Input batch accepted by the backbone's ``forward_features``
                (assumed to be a preprocessed ``(batch, 3, H, W)`` image
                tensor -- TODO confirm against callers).

        Returns:
            The ``'x_norm_patchtokens'`` entry of the dict produced by
            ``forward_features``.
        """
        features = self.vision_encoder.forward_features(x)
        return features['x_norm_patchtokens']

    def get_output_dim(self) -> int:
        """Return the size of the encoder's feature (last) dimension.

        DINOv2 backbones expose their width as ``embed_dim``; when that
        attribute is present it is returned directly, avoiding a forward pass.
        Otherwise a blank image is pushed through the encoder and the last
        dimension of the result is measured.

        Returns:
            The feature dimension as an ``int``.
        """
        embed_dim = getattr(self.vision_encoder, 'embed_dim', None)
        if isinstance(embed_dim, int):
            return embed_dim

        # Fallback: probe the encoder with an all-black image.
        with torch.no_grad():
            blank = Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))
            probe = preprocess(blank).unsqueeze(0)
            reference = next(self.parameters(), None)
            if reference is not None:
                # Match device *and* floating dtype so the probe also works
                # after the model has been cast with .half()/.bfloat16().
                target_dtype = (
                    reference.dtype if reference.is_floating_point() else probe.dtype
                )
                probe = probe.to(device=reference.device, dtype=target_dtype)
            return self.vision_encoder(probe).shape[-1]