import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from StructDiffusion.models.encoders import EncoderMLP, DropoutSampler, SinusoidalPositionEmbeddings
from StructDiffusion.models.point_transformer import PointTransformerEncoderSmall
from StructDiffusion.models.point_transformer_large import PointTransformerCls
class TransformerDiffusionModel(torch.nn.Module):

    def __init__(self, vocab_size,
                 # transformer params
                 encoder_input_dim=256,
                 num_attention_heads=8, encoder_hidden_dim=16, encoder_dropout=0.1, encoder_activation="relu",
                 encoder_num_layers=8,
                 # output head params
                 structure_dropout=0.0, object_dropout=0.0,
                 # pc encoder params
                 ignore_rgb=False, pc_emb_dim=256, posed_pc_emb_dim=80,
                 pose_emb_dim=80,
                 max_seq_size=7, max_token_type_size=4,
                 seq_pos_emb_dim=8, seq_type_emb_dim=8,
                 word_emb_dim=160,
                 time_emb_dim=80,
                 use_virtual_structure_frame=True,
                 use_sentence_embedding=False,
                 sentence_embedding_dim=None,
                 ):
        super(TransformerDiffusionModel, self).__init__()

        assert posed_pc_emb_dim + pose_emb_dim == word_emb_dim
        assert encoder_input_dim == word_emb_dim + time_emb_dim + seq_pos_emb_dim + seq_type_emb_dim

        # 3D translation + 6D rotation
        action_dim = 3 + 6

        # default:
        # 256 = 80 (point cloud) + 80 (position) + 80 (time) + 8 (position idx) + 8 (token idx)
        # 256 = 160 (word embedding) + 80 (time) + 8 (position idx) + 8 (token idx)

        # PC
        self.ignore_rgb = ignore_rgb
        if ignore_rgb:
            self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=3, mean_center=True)
        else:
            self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=6, mean_center=True)
        self.posed_pc_encoder = EncoderMLP(pc_emb_dim, posed_pc_emb_dim, uses_pt=True)

        # for virtual structure frame
        self.use_virtual_structure_frame = use_virtual_structure_frame
        if use_virtual_structure_frame:
            self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim))  # B, 1, posed_pc_emb_dim

        # for language
        self.sentence_embedding_dim = sentence_embedding_dim
        self.use_sentence_embedding = use_sentence_embedding
        if use_sentence_embedding:
            self.sentence_embedding_down_sample = torch.nn.Linear(sentence_embedding_dim, word_emb_dim)
        else:
            self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)

        # for diffusion
        self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
        self.time_embeddings = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # for transformer
        self.position_embeddings = torch.nn.Embedding(max_seq_size, seq_pos_emb_dim)
        self.type_embeddings = torch.nn.Embedding(max_token_type_size, seq_type_emb_dim)

        encoder_layers = TransformerEncoderLayer(encoder_input_dim, num_attention_heads,
                                                 encoder_hidden_dim, encoder_dropout, encoder_activation)
        self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers)

        self.struct_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=structure_dropout)
        self.obj_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=object_dropout)
    def encode_posed_pc(self, pcs, batch_size, num_objects):
        if self.ignore_rgb:
            center_xyz, x = self.pc_encoder(pcs[:, :, :3], None)
        else:
            center_xyz, x = self.pc_encoder(pcs[:, :, :3], pcs[:, :, 3:])
        posed_pc_embed = self.posed_pc_encoder(x, center_xyz)
        posed_pc_embed = posed_pc_embed.reshape(batch_size, num_objects, -1)
        return posed_pc_embed
    def forward(self, t, pcs, sentence, poses, type_index, position_index, pad_mask):

        batch_size, num_objects, num_pts, _ = pcs.shape
        _, num_poses, _ = poses.shape
        if self.use_sentence_embedding:
            assert sentence.shape == (batch_size, self.sentence_embedding_dim), sentence.shape
        else:
            _, sentence_len = sentence.shape
        _, total_len = type_index.shape

        pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
        posed_pc_embed = self.encode_posed_pc(pcs, batch_size, num_objects)

        pose_embed = self.pose_encoder(poses)

        if self.use_virtual_structure_frame:
            virtual_frame_embed = self.virtual_frame_embed.repeat(batch_size, 1, 1)
            posed_pc_embed = torch.cat([virtual_frame_embed, posed_pc_embed], dim=1)
        tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)

        #########################
        if self.use_sentence_embedding:
            # sentence: B, sentence_embedding_dim
            sentence_embed = self.sentence_embedding_down_sample(sentence).unsqueeze(1)  # B, 1, word_emb_dim
        else:
            sentence_embed = self.word_embeddings(sentence)

        #########################
        # transformer time dim: sentence, struct, obj
        # transformer feat dim: obj pc + pose / word, time, token type, position
        time_embed = self.time_embeddings(t)  # B, dim
        time_embed = time_embed.unsqueeze(1).repeat(1, total_len, 1)  # B, L, dim
        position_embed = self.position_embeddings(position_index)
        type_embed = self.type_embeddings(type_index)

        tgt_sequence_encode = torch.cat([sentence_embed, tgt_obj_embed], dim=1)
        tgt_sequence_encode = torch.cat([tgt_sequence_encode, time_embed, position_embed, type_embed], dim=-1)
        tgt_pad_mask = pad_mask

        #########################
        # sequence_encode: [batch size, sequence_length, encoder input dimension]
        # input to the transformer needs to have dimension [sequence_length, batch size, encoder input dimension]
        tgt_sequence_encode = tgt_sequence_encode.transpose(1, 0)
        # convert to bool
        tgt_pad_mask = (tgt_pad_mask == 1)
        # encode: [sequence_length, batch_size, embedding_size]
        encode = self.encoder(tgt_sequence_encode, src_key_padding_mask=tgt_pad_mask)
        encode = encode.transpose(1, 0)

        #########################
        target_encodes = encode[:, -num_poses:, :]
        if self.use_virtual_structure_frame:
            obj_encodes = target_encodes[:, 1:, :]
            pred_obj_poses = self.obj_head(obj_encodes)  # B, N, 3 + 6
            struct_encode = encode[:, 0, :].unsqueeze(1)
            # use a different sampler for struct prediction since it should have larger variance than object predictions
            pred_struct_pose = self.struct_head(struct_encode)  # B, 1, 3 + 6
            pred_poses = torch.cat([pred_struct_pose, pred_obj_poses], dim=1)
        else:
            pred_poses = self.obj_head(target_encodes)  # B, N, 3 + 6
        assert pred_poses.shape == poses.shape

        return pred_poses
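# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (added for illustration, not part of the original
# model code). It pushes a random batch through TransformerDiffusionModel with
# the default dimensions. The batch size, number of objects, point count, and
# sequence layout below are assumptions chosen only to satisfy the shape
# checks in forward(), not values taken from the StructDiffusion configs.
# ---------------------------------------------------------------------------
def _demo_transformer_diffusion_forward():
    batch_size, num_objects, num_pts = 2, 3, 1024
    model = TransformerDiffusionModel(
        vocab_size=40,                      # unused when use_sentence_embedding=True
        use_virtual_structure_frame=True,
        use_sentence_embedding=True,
        sentence_embedding_dim=512,         # assumed size of a precomputed sentence embedding
    )

    num_poses = num_objects + 1             # +1 for the virtual structure frame
    total_len = 1 + num_poses               # the sentence embedding contributes a single token

    t = torch.randint(0, 200, (batch_size,))                   # diffusion timesteps
    pcs = torch.randn(batch_size, num_objects, num_pts, 6)     # xyz + rgb per object
    sentence = torch.randn(batch_size, 512)                    # precomputed sentence embedding
    poses = torch.randn(batch_size, num_poses, 9)              # 3D translation + 6D rotation
    type_index = torch.zeros(batch_size, total_len, dtype=torch.long)
    position_index = torch.arange(total_len).unsqueeze(0).repeat(batch_size, 1)
    pad_mask = torch.zeros(batch_size, total_len, dtype=torch.bool)

    pred_poses = model(t, pcs, sentence, poses, type_index, position_index, pad_mask)
    return pred_poses                       # (batch_size, num_poses, 3 + 6)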
class FocalLoss(nn.Module):

    def __init__(self, gamma=2, alpha=.25):
        super(FocalLoss, self).__init__()
        # alpha is currently unused; the commented-out lines below show the alpha-weighted variant
        # self.alpha = torch.tensor([alpha, 1-alpha])
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        # targets = targets.type(torch.long)
        # at = self.alpha.gather(0, targets.data.view(-1))
        # F_loss = at*(1-pt)**self.gamma * BCE_loss
        F_loss = (1 - pt)**self.gamma * BCE_loss
        return F_loss.mean()
class PCTDiscriminator(torch.nn.Module):

    def __init__(self, max_num_objects, include_env_pc=False, pct_random_sampling=False):
        super(PCTDiscriminator, self).__init__()

        # input_dim: xyz + one hot for each object
        if include_env_pc:
            self.classifier = PointTransformerCls(input_dim=max_num_objects + 1 + 3, output_dim=1, use_random_sampling=pct_random_sampling)
        else:
            self.classifier = PointTransformerCls(input_dim=max_num_objects + 3, output_dim=1, use_random_sampling=pct_random_sampling)

    def forward(self, scene_xyz):
        label = self.classifier(scene_xyz)
        return label

    def convert_logits(self, logits):
        return torch.sigmoid(logits)
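# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module): a
# quick check of FocalLoss on random binary logits/targets and a dummy forward
# pass through the diffusion model via the sketch defined above. All tensor
# shapes here are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    focal_loss = FocalLoss(gamma=2)
    logits = torch.randn(4, 1)
    targets = torch.randint(0, 2, (4, 1)).float()
    print("focal loss:", focal_loss(logits, targets).item())

    pred_poses = _demo_transformer_diffusion_forward()
    print("predicted poses:", tuple(pred_poses.shape))  # expected: (2, 4, 9)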