File size: 1,093 Bytes
2b8e195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f81a237
2b8e195
 
 
 
 
 
 
 
f81a237
2b8e195
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

# Image -> normalised float tensor pipeline applied before the DINOv2 encoder:
#   1. Resize to exactly 224x224 (a (h, w) tuple, so aspect ratio is NOT kept).
#   2. ToTensor: (H, W, C) image -> (C, H, W) float tensor (uint8 scaled to [0, 1]).
#   3. Normalize each R/G/B channel with the mean/std below; these values are
#      the widely used ImageNet channel statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


class DinoV2(nn.Module):
    """Wrapper around a DINOv2 backbone loaded from ``torch.hub``.

    ``forward`` returns the backbone's normalised patch tokens
    (the ``'x_norm_patchtokens'`` entry of ``forward_features``).
    """

    def __init__(self, model_name):
        """Load ``model_name`` from the ``facebookresearch/dinov2`` hub repo.

        Args:
            model_name: Name of a hub entry point in that repo (presumably
                something like ``"dinov2_vits14"`` -- confirm against callers).
        """
        super().__init__()
        # torch.hub.load may fetch the repo over the network on first use.
        # ``eval()`` returns the module itself, so it can be chained. It only
        # sets the initial mode: a later ``.train()`` on this wrapper will
        # propagate to the encoder as well.
        self.vision_encoder = torch.hub.load('facebookresearch/dinov2', model_name).eval()

    def forward(self, x):
        """Return the normalised patch tokens for the image batch ``x``.

        Args:
            x: Image batch, presumably ``(batch, 3, H, W)`` as produced by the
                module-level ``preprocess`` plus batching (TODO confirm).

        Returns:
            The ``'x_norm_patchtokens'`` tensor from the encoder's
            ``forward_features``. Its last dimension is what
            :meth:`get_output_dim` reports.
        """
        return self.vision_encoder.forward_features(x)['x_norm_patchtokens']

    def get_output_dim(self):
        """Return the size of the last dimension of :meth:`forward`'s output.

        Runs one blank (all-zero) image through :meth:`forward` without
        gradient tracking, so it costs a full forward pass on the module's
        device.
        """
        param = next(self.parameters())
        blank = Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))
        # Match the parameters' device AND dtype so the probe also works after
        # ``.half()`` / ``.to(dtype)``. ``preprocess`` always yields float32.
        dummy_input = preprocess(blank).unsqueeze(0).to(device=param.device, dtype=param.dtype)
        with torch.no_grad():
            # Probe ``self`` rather than the raw encoder. That way the reported
            # size describes exactly what ``forward`` returns (the patch tokens),
            # not the encoder's own forward/head output.
            output = self(dummy_input)
        return output.shape[-1]