Upload VolumeMaker.py
Browse files- VolumeMaker.py +591 -0
VolumeMaker.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
#
|
| 4 |
+
# Copyright 2021 Gabriele Orlando
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import torch,math
|
| 19 |
+
from pyuul.sources.globalVariables import *
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import random
|
| 23 |
+
|
| 24 |
+
def setup_seed(seed):
    """Seed every RNG backend (python, numpy, torch CPU and CUDA) for reproducibility.

    Also forces deterministic cuDNN kernels, trading some speed for
    run-to-run stability.

    Parameters
    ----------
    seed : int
        the seed to propagate to all random number generators
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


# seed once at import time so module-level behaviour is reproducible
setup_seed(100)
|
| 31 |
+
|
| 32 |
+
class Voxels(torch.nn.Module):
    """Differentiable voxel-grid representation of macromolecules.

    Atoms are binned into a 3D grid; every voxel stores a per-channel
    occupancy value computed with either a sigmoid-shaped or a gaussian
    occupancy function.
    """

    def __init__(self, device=torch.device("cpu"), sparse=True):
        """
        Constructor for the Voxels class, which builds the main PyUUL object.

        Parameters
        ----------
        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
        sparse : bool
            Use sparse tensors calculation when possible
        """
        super(Voxels, self).__init__()

        self.sparse = sparse
        self.boxsize = None  # (x, y, z) size of the last generated volume, set by forward()
        self.dev = device

    def __transform_coordinates(self, coords, radius=None):
        """
        Map input coordinates (and optionally radii) into voxel space.

        Applies the affine transform (scale by 1/resolution, then translate so
        the box starts at the origin) defined by __define_spatial_conformation.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 )
        radius : torch.Tensor or None
            Radius of the atoms. Shape ( batch, numberOfAtoms )

        Returns
        -------
        coords : torch.Tensor
            transformed coordinates (and scaled radii, when radius is given)
        """
        coords = (coords * self.dilatation) - self.translation
        if radius is not None:
            # radii are lengths: they scale with resolution but do not translate
            radius = radius * self.dilatation
            return coords, radius
        else:
            return coords

    def __define_spatial_conformation(self, mincoords, cubes_around_atoms_dim, resolution):
        """
        Define the affine transform that maps input coordinates into the box.

        Parameters
        ----------
        mincoords : torch.Tensor
            minimum coordinates of each macromolecule of the batch. Shape ( batch, 3 )
        cubes_around_atoms_dim : int
            maximum distance in number of voxels to check for atom contribution to occupancy of a voxel
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
        """
        # shift so that every atom, plus its margin of voxels, gets non-negative indices
        self.translation = (mincoords - (cubes_around_atoms_dim)).unsqueeze(1)
        self.dilatation = 1.0 / resolution

    def forward(self, coords, radius, channels, numberchannels=None, resolution=1, cubes_around_atoms_dim=5, steepness=10, function="sigmoid"):
        """
        Voxels representation of the macromolecules

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        channels : torch.LongTensor
            channels of the atoms. Atoms of the same type should belong to the same channel. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToChannels
        numberchannels : int or None
            maximum number of channels. if None, max(channels) + 1 is used
        cubes_around_atoms_dim : int
            maximum distance in number of voxels for which the contribution to occupancy is taken into consideration. Every atom that is farther than cubes_around_atoms_dim voxels from the center of a voxel gives no contribution to the relative voxel occupancy
        resolution : float
            side in A of a voxel. The lower this value is the higher the resolution of the final representation will be
        steepness : float or int
            steepness of the sigmoid occupancy function.
        function : "sigmoid" or "gaussian"
            occupancy function to use. Can be sigmoid (every atom has a sigmoid shaped occupancy function) or gaussian (based on Li et al. 2014)

        Returns
        -------
        volume : torch.Tensor
            voxel representation of the macromolecules in the batch. Shape ( batch, channels, x, y, z ), where x, y, z are the size of the 3D volume in which the macromolecules have been represented

        Raises
        ------
        ValueError
            when ``function`` is neither "sigmoid" nor "gaussian"
        """
        padding_mask = ~channels.eq(PADDING_INDEX)
        if numberchannels is None:
            numberchannels = int(channels[padding_mask].max().cpu().data + 1)
        self.featureVectorSize = numberchannels
        self.function = function

        # Precompute the "standard cube": integer offsets of every voxel within
        # cubes_around_atoms_dim of an atom, final shape (1, 1, lato, lato, lato, 3).
        arange_type = torch.int16

        gx = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
        gy = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
        gz = torch.arange(-cubes_around_atoms_dim, cubes_around_atoms_dim + 1, device=self.dev, dtype=arange_type)
        self.lato = gx.shape[0]  # side of the cube, in voxels

        x1 = gx.unsqueeze(1).expand(self.lato, self.lato).unsqueeze(-1)
        x2 = gy.unsqueeze(0).expand(self.lato, self.lato).unsqueeze(-1)

        xy = torch.cat([x1, x2], dim=-1).unsqueeze(2).expand(self.lato, self.lato, self.lato, 2)
        x3 = gz.unsqueeze(0).unsqueeze(1).expand(self.lato, self.lato, self.lato).unsqueeze(-1)

        del gx, gy, gz, x1, x2

        self.standard_cube = torch.cat([xy, x3], dim=-1).unsqueeze(0).unsqueeze(0)

        ### definition of the box ###
        # Every macromolecule in the batch shares the same box: take the min and
        # max coordinate on each dimension, scale by resolution, and add the
        # gradient margin on both sides (hence 2 * cubes_around_atoms_dim).
        mincoords = torch.min(coords[:, :, :], dim=1)[0]
        mincoords = torch.trunc(mincoords / resolution)

        box_size_x = (math.ceil(torch.max(coords[padding_mask][:, 0]) / resolution) - mincoords[:, 0].min()) + (2 * cubes_around_atoms_dim + 1)
        box_size_y = (math.ceil(torch.max(coords[padding_mask][:, 1]) / resolution) - mincoords[:, 1].min()) + (2 * cubes_around_atoms_dim + 1)
        box_size_z = (math.ceil(torch.max(coords[padding_mask][:, 2]) / resolution) - mincoords[:, 2].min()) + (2 * cubes_around_atoms_dim + 1)
        #############################

        self.__define_spatial_conformation(mincoords, cubes_around_atoms_dim, resolution)  # define the spatial transforms to coordinates
        coords, radius = self.__transform_coordinates(coords, radius)

        boxsize = (int(box_size_x), int(box_size_y), int(box_size_z))
        self.boxsize = boxsize

        # smallest integer dtype able to hold the discrete voxel indices
        if max(boxsize) < 256:  # fits in a byte tensor
            self.dtype_indices = torch.uint8
        else:
            self.dtype_indices = torch.int16

        if self.function == "sigmoid":
            volume = self.__forward_actual_calculation(coords, boxsize, radius, channels, padding_mask, steepness, resolution)
        elif self.function == "gaussian":
            volume = self.__forward_actual_calculationGaussian(coords, boxsize, radius, channels, padding_mask, resolution)
        else:
            # previously an unknown function fell through to an opaque NameError
            raise ValueError('function must be "sigmoid" or "gaussian", got %r' % (function,))
        return volume

    def __forward_actual_calculationGaussian(self, coords_scaled, boxsize, radius, atNameHashing, padding_mask, resolution):
        """
        private function for the calculation of the gaussian voxel occupancy

        Parameters
        ----------
        coords_scaled : torch.Tensor
            Coordinates of the atoms, already mapped to voxel space. Shape ( batch, numberOfAtoms, 3 )
        boxsize : tuple
            The (x, y, z) size of the box in which the macromolecules are represented
        radius : torch.Tensor
            Radius of the atoms (scaled). Shape ( batch, numberOfAtoms )
        atNameHashing : torch.LongTensor
            channels of the atoms. Shape ( batch, numberOfAtoms )
        padding_mask : torch.BoolTensor
            tensor to mask the padding. Shape ( batch, numberOfAtoms )
        resolution : float
            side in A of a voxel

        Returns
        -------
        volume : torch.Tensor
            voxel representation with Gaussian occupancy function. Shape ( batch, channels, x, y, z ). Sparse when self.sparse is True.
        """
        batch = coords_scaled.shape[0]
        dev = self.dev
        L = coords_scaled.shape[1]

        discrete_coordinates = torch.trunc(coords_scaled.data).to(self.dtype_indices)

        #### broadcasting everything against the standard cube ####
        radius = radius.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        atNameHashing = atNameHashing.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        coords_scaled = coords_scaled.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        discrete_coordinates = discrete_coordinates.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        # distance of each atom from the center of every voxel of its cube
        distmat_standard_cube = torch.norm(
            coords_scaled - ((discrete_coordinates + self.standard_cube + 1) + 0.5 * resolution), dim=-1).to(
            coords_scaled.dtype)

        atNameHashing = atNameHashing.long()

        ### from doi: 10.1142/S0219633614400021 eq 1
        sigma = 0.93
        exponent = -distmat_standard_cube[padding_mask] ** 2 / (sigma ** 2 * radius[padding_mask] ** 2)
        # clip exponents to avoid overflow; the same mask also drops those
        # entries from the scatter below
        exp_mask = exponent.ge(10)
        exponent = torch.masked_fill(exponent, exp_mask, 10)
        volume_cubes = torch.exp(exponent)

        #### scatter every cube contribution into the volume ###
        batch_list = torch.arange(batch, device=dev).unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).expand(batch, L, self.lato, self.lato, self.lato)

        cubes_coords = (discrete_coordinates[padding_mask] + self.standard_cube.squeeze(0) + 1)[~exp_mask]
        atNameHashing = atNameHashing[padding_mask].expand(-1, self.lato, self.lato, self.lato)
        if self.sparse:
            index_tens = torch.cat(
                [batch_list[padding_mask][~exp_mask].view(-1).unsqueeze(0),
                 atNameHashing[~exp_mask].unsqueeze(0),
                 cubes_coords[:, 0].unsqueeze(0),
                 cubes_coords[:, 1].unsqueeze(0),
                 cubes_coords[:, 2].unsqueeze(0),
                 ])

            volume_cubes = volume_cubes[~exp_mask].view(-1)
            # occupancies combine as 1 - prod(1 - occ); work in log space because
            # pow/exp are not implemented for sparse tensors, and let coalesce()
            # sum the logs of overlapping contributions
            volume_cubes = torch.log(1 - volume_cubes.contiguous())
            volume = torch.sparse_coo_tensor(indices=index_tens, values=volume_cubes.exp(), size=[batch, self.featureVectorSize, boxsize[0], boxsize[1], boxsize[2]]).coalesce()
            volume = torch.sparse_coo_tensor(volume.indices(), 1 - volume.values(), volume.shape)
        else:
            volume = torch.zeros(batch, boxsize[0] + 1, boxsize[1] + 1, boxsize[2] + 1, self.featureVectorSize, device=dev, dtype=torch.float)
            index = (batch_list[padding_mask][~exp_mask].view(-1).long(),
                     cubes_coords[:, 0].long(),
                     cubes_coords[:, 1].long(),
                     cubes_coords[:, 2].long(),
                     atNameHashing[~exp_mask])
            volume_cubes = volume_cubes[~exp_mask].view(-1)

            # accumulate log(1 - occ) per voxel, then convert back: 1 - exp(sum)
            volume_cubes = torch.log(1 - volume_cubes.contiguous())
            volume = 1 - torch.exp(volume.index_put(index, volume_cubes, accumulate=True))
            # channel axis last -> second, to match the sparse layout (batch, channel, x, y, z)
            volume = volume.permute(0, 4, 1, 2, 3)

        return volume

    def __sparseClamp(self, volume, minv, maxv):
        """Clamp the values of a sparse COO tensor to [minv, maxv], keeping its indices."""
        vals = volume.values()
        ind = volume.indices()

        vals = vals.clamp(minv, maxv)
        volume = torch.sparse_coo_tensor(indices=ind, values=vals, size=volume.shape).coalesce()
        return volume

    def __forward_actual_calculation(self, coords_scaled, boxsize, radius, atNameHashing, padding_mask, steepness, resolution):
        """
        private function for the calculation of the sigmoid voxel occupancy

        Parameters
        ----------
        coords_scaled : torch.Tensor
            Coordinates of the atoms, already mapped to voxel space. Shape ( batch, numberOfAtoms, 3 )
        boxsize : tuple
            The (x, y, z) size of the box in which the macromolecules are represented
        radius : torch.Tensor
            Radius of the atoms (scaled). Shape ( batch, numberOfAtoms )
        atNameHashing : torch.LongTensor
            channels of the atoms. Shape ( batch, numberOfAtoms )
        padding_mask : torch.BoolTensor
            tensor to mask the padding. Shape ( batch, numberOfAtoms )
        steepness : float
            steepness of the sigmoid function (coefficient of the exponent)
        resolution : float
            side in A of a voxel

        Returns
        -------
        volume : torch.Tensor
            voxel representation with Sigmoid occupancy function. Shape ( batch, channels, x, y, z ). Sparse when self.sparse is True.
        """
        batch = coords_scaled.shape[0]
        dev = self.dev
        L = coords_scaled.shape[1]

        discrete_coordinates = torch.trunc(coords_scaled.data).to(self.dtype_indices)

        #### broadcasting everything against the standard cube ####
        radius = radius.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        atNameHashing = atNameHashing.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        coords_scaled = coords_scaled.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        discrete_coordinates = discrete_coordinates.unsqueeze(2).unsqueeze(3).unsqueeze(4)
        # distance of each atom from the center of every voxel of its cube
        distmat_standard_cube = torch.norm(coords_scaled - ((discrete_coordinates + self.standard_cube + 1) + 0.5 * resolution), dim=-1).to(coords_scaled.dtype)

        atNameHashing = atNameHashing.long()

        # sigmoid occupancy: 1 / (1 + exp(steepness * (dist - radius)))
        exponent = steepness * (distmat_standard_cube[padding_mask] - radius[padding_mask])
        del distmat_standard_cube
        # clip exponents to avoid overflow; the same mask drops those entries
        # (their occupancy is ~0) from the scatter below
        exp_mask = exponent.ge(10)
        exponent = torch.masked_fill(exponent, exp_mask, 10)

        volume_cubes = 1.0 / (1.0 + torch.exp(exponent))

        #### scatter every cube contribution into the volume ###
        batch_list = torch.arange(batch, device=dev).unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).expand(batch, L, self.lato, self.lato, self.lato)

        cubes_coords = (discrete_coordinates[padding_mask] + self.standard_cube.squeeze(0) + 1)[~exp_mask]
        atNameHashing = atNameHashing[padding_mask].expand(-1, self.lato, self.lato, self.lato)
        if self.sparse:
            index_tens = torch.cat(
                [batch_list[padding_mask][~exp_mask].view(-1).unsqueeze(0),
                 atNameHashing[~exp_mask].unsqueeze(0),
                 cubes_coords[:, 0].unsqueeze(0),
                 cubes_coords[:, 1].unsqueeze(0),
                 cubes_coords[:, 2].unsqueeze(0),
                 ])
            volume = torch.sparse_coo_tensor(indices=index_tens, values=volume_cubes[~exp_mask].view(-1), size=[batch, self.featureVectorSize, boxsize[0], boxsize[1], boxsize[2]]).coalesce()
            volume = self.__sparseClamp(volume, 0, 1)
        else:
            volume = torch.zeros(batch, boxsize[0] + 1, boxsize[1] + 1, boxsize[2] + 1, self.featureVectorSize, device=dev, dtype=torch.float)
            index = (batch_list[padding_mask][~exp_mask].view(-1).long(),
                     cubes_coords[:, 0].long(),
                     cubes_coords[:, 1].long(),
                     cubes_coords[:, 2].long(),
                     atNameHashing[~exp_mask])
            volume_cubes = volume_cubes[~exp_mask].view(-1)

            volume = volume.index_put(index, volume_cubes.view(-1), accumulate=True)

            # clamp accumulated occupancy to 1 (threshold trick keeps gradients)
            volume = -torch.nn.functional.threshold(-volume, -1, -1)
            volume = volume.permute(0, 4, 1, 2, 3)

        return volume

    # NOTE: a mesh() method (pytorch3d cubify of the summed volume) was sketched
    # here; it will be added as soon as pytorch3d becomes a little more stable.
|
| 418 |
+
|
| 419 |
+
class PointCloudSurface(torch.nn.Module):
    """Surface point-cloud representation of macromolecules.

    Points are sampled on enlarged spheres around every atom; points buried
    inside neighbouring atoms (gaussian occupancy > 0.5) are discarded, so the
    survivors trace the molecular surface.
    """

    def __init__(self, device="cpu"):
        """
        Constructor for the CloudPointSurface class, which builds the main PyUUL object for cloud surface.

        Parameters
        ----------
        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
        """
        super(PointCloudSurface, self).__init__()

        self.device = device

    def __buildStandardSphere(self, npoints=50):
        """Return npoints evenly spread on the unit sphere (Fibonacci lattice).

        Returns a tensor of shape (npoints, 3).
        """
        goldenRatio = (1 + 5 ** 0.5) / 2
        i = torch.arange(0, npoints, device=self.device)
        theta = 2 * math.pi * i / goldenRatio
        phi = torch.acos(1 - 2 * (i + 0.5) / npoints)

        x, y, z = torch.cos(theta) * torch.sin(phi), torch.sin(theta) * torch.sin(phi), torch.cos(phi)

        coords = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], dim=-1)

        return coords

    def forward(self, coords, radius, maxpoints=5000, external_radius_factor=1.4):
        """
        Function to calculate the surface cloud point representation of macromolecules

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        maxpoints : int
            number of points per macromolecule in the batch
        external_radius_factor : float
            multiplicative factor of the radius in order to define the place to sample the points around each atom. The higher this value is, the smoother the surface will be

        Returns
        -------
        surfacePointCloud : torch.Tensor
            surface point cloud of the macromolecules in the batch.
            NOTE(review): per-molecule samples are concatenated with torch.cat(dim=0),
            so the actual shape is ( batch * maxpoints, 3 ) — confirm whether a
            batched ( batch, maxpoints, 3 ) output was intended.
        """
        padding_mask = ~radius.eq(PADDING_INDEX)

        batch = coords.shape[0]
        # enough sphere points per atom that the smallest molecule still yields >= maxpoints candidates
        npoints = torch.div(maxpoints, (padding_mask.sum(-1).min() + 1), rounding_mode="floor") * 2

        sphere = self.__buildStandardSphere(npoints)
        # sample on a sphere slightly larger than each atom (loop-invariant, hoisted)
        external_radius = radius * external_radius_factor
        finalPoints = []

        for b in range(batch):

            distmat = torch.cdist(coords[b][padding_mask[b]].unsqueeze(0), coords[b][padding_mask[b]].unsqueeze(0))
            L = distmat.shape[1]
            AtomSelfContributionMask = torch.eye(L, dtype=torch.bool, device=self.device).unsqueeze(0)

            # only atom pairs closer than 5 A can bury each other's points;
            # an atom never buries its own sphere (self mask)
            todoMask = (distmat[0].le(5) & (~AtomSelfContributionMask)).squeeze(0)
            points = coords[b][padding_mask[b]].unsqueeze(0).unsqueeze(-2) - sphere.unsqueeze(0).unsqueeze(1) * external_radius[b][padding_mask[b]].unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # pairwise (point, burying-atom) views restricted to nearby pairs
            p = points.expand(L, L, npoints, 3)[todoMask]
            c = coords[b][padding_mask[b]].unsqueeze(1).unsqueeze(-2).expand(L, L, points.shape[2], 3)[todoMask]
            r = radius[b][padding_mask[b]].unsqueeze(1).unsqueeze(-2).expand(L, L, points.shape[2])[todoMask]
            occupancy = self.__occupancy(p, c, r)

            # accumulate log(1 - occ) per point over all burying atoms,
            # then convert back: total occupancy = 1 - exp(sum of logs)
            point_index = torch.arange(0, L * npoints, device=self.device).view(L, npoints).unsqueeze(0).expand(L, L, npoints)[todoMask]
            point_occupancy = torch.zeros((L * npoints), dtype=torch.float, device=self.device)
            point_occupancy = point_occupancy.index_put_([point_index.view(-1)], occupancy.view(-1), accumulate=True)
            point_occupancy = (1 - torch.exp(point_occupancy))

            # points at most half-buried are considered on the surface
            points_on_surfaceMask = point_occupancy.le(0.5)

            points = points.permute(0, 3, 1, 2).view(3, -1).transpose(0, 1)[points_on_surfaceMask]
            # resample with replacement to exactly maxpoints per molecule
            random_indices = torch.randint(0, points.shape[0], [maxpoints], device=self.device)
            sampled_points = points[random_indices, :]

            finalPoints += [sampled_points]

        return torch.cat(finalPoints, dim=0)

    def __occupancy(self, points, coords, radius):
        """Per-(point, atom) log(1 - gaussian occupancy).

        Gaussian occupancy from doi: 10.1142/S0219633614400021 eq 1; the log is
        returned so contributions from several atoms can be summed by the caller.
        """
        dist = torch.norm(points - coords, dim=-1)

        sigma = 0.93
        exponent = -dist ** 2 / (sigma ** 2 * radius ** 2)
        # clip to avoid overflow in exp (exponent is normally <= 0 here)
        exp_mask = exponent.ge(10)
        exponent = torch.masked_fill(exponent, exp_mask, 10)

        occupancy_on_points = torch.exp(exponent)
        return torch.log(1 - occupancy_on_points)
|
| 533 |
+
|
| 534 |
+
class PointCloudVolume(torch.nn.Module):
    """Volumetric point-cloud representation of macromolecules."""

    def __init__(self, device="cpu"):
        """
        Constructor for the PointCloudVolume class, which builds the main PyUUL
        object for volumetric point clouds.

        Parameters
        ----------
        device : torch.device
            The device on which the model should run. E.g. torch.device("cuda") or torch.device("cpu:0")
        """
        super(PointCloudVolume, self).__init__()

        self.device = device

    def forward(self, coords, radius, maxpoints=500):
        """
        Calculate the volumetric cloud point representation of macromolecules.

        One gaussian sample is drawn around every non-padding atom (spread
        proportional to sqrt(radius)), then each molecule's cloud is resampled
        with replacement to exactly ``maxpoints`` points.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates of the atoms. Shape ( batch, numberOfAtoms, 3 ). Can be calculated from a PDB file using utils.parsePDB
        radius : torch.Tensor
            Radius of the atoms. Shape ( batch, numberOfAtoms ). Can be calculated from a PDB file using utils.parsePDB and utils.atomlistToRadius
        maxpoints : int
            number of points per macromolecule in the batch

        Returns
        -------
        PointCloudVolume : torch.Tensor
            volume point cloud of the macromolecules in the batch. Shape ( batch, maxpoints, 3 )
        """
        padding_mask = ~radius.eq(PADDING_INDEX)

        n_mols = coords.shape[0]

        per_molecule = []
        for mol in range(n_mols):
            centers = coords[mol][padding_mask[mol]]

            # one gaussian draw per atom, centred on the atom
            spread = radius[mol][padding_mask[mol]].sqrt().unsqueeze(-1)
            cloud = spread * torch.randn((centers.size()), device=self.device) + centers
            cloud = cloud.view(-1, 3)

            # resample with replacement to exactly maxpoints points
            pick = torch.randint(0, cloud.shape[0], [maxpoints], device=self.device)
            per_molecule.append(cloud[pick])

        return torch.stack(per_molecule, dim=0)
|
| 591 |
+
|