Spaces:
Paused
Paused
| import os | |
| import torch | |
| from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d | |
def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):
    """Select the 9 diffusion variables per pose and stack structure + objects.

    The last dimension of each input is expected to hold xyz at entries 0..2,
    followed by a 3x3 rotation flattened row-major at entries 3..11:

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]

    The 9 kept variables are xyz, then the FIRST rotation column (3, 6, 9),
    then the SECOND rotation column (4, 7, 10) -- columns, not rows.

    :param struct_xyztheta_inputs: B, 1, D structure pose inputs.
    :param obj_xyztheta_inputs: B, N, D object pose inputs.
    :return: B, 1 + N, 9 tensor; index 0 along dim 1 is the structure.
    """
    translation_idxs = [0, 1, 2]
    first_column_idxs = [3, 6, 9]
    second_column_idxs = [4, 7, 10]
    selected = translation_idxs + first_column_idxs + second_column_idxs

    struct_vars = struct_xyztheta_inputs[:, :, selected]  # B, 1, 9
    obj_vars = obj_xyztheta_inputs[:, :, selected]  # B, N, 9
    # structure first, then the objects, along the sequence dimension
    return torch.cat((struct_vars, obj_vars), dim=1)  # B, 1 + N, 9
def get_diffusion_variables_from_H(poses):
    """Extract the 9 diffusion variables from homogeneous 4x4 poses.

    Each 4x4 pose is read in row-major order, so entry (row r, col c) sits at
    flat index ``4 * r + c``:

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]

    The output per pose is the translation column (3, 7, 11), then the first
    rotation column (0, 4, 8), then the second rotation column (1, 5, 9).

    :param poses: B, N, 4, 4 homogeneous transforms.
    :return: B, N, 9 tensor.
    """
    batch_size, num_poses, _, _ = poses.shape
    # column 3 (translation), column 0, column 1 -- top three rows of each
    flat_idxs = [4 * row + col for col in (3, 0, 1) for row in range(3)]
    flat_poses = poses.reshape(batch_size, num_poses, 16)
    return flat_poses[:, :, flat_idxs]  # B, N, 9
def get_struct_objs_poses(x):
    """Decode diffusion variables into homogeneous structure and object poses.

    :param x: B, 1 + N, 9 tensor. Per row, entries 0..2 are a translation and
        entries 3..8 are a 6D rotation representation decoded by
        ``compute_rotation_matrix_from_ortho6d``. Row 0 is the structure and
        rows 1..N are the objects.
    :return: ``(struct_pose, pc_poses_in_struct)`` with shapes B, 1, 4, 4 and
        B, N, 4, 4. Both are views into one freshly built tensor whose device
        and dtype match ``x``.
    """
    # important: the noisy x can go out of bounds, so every entry is clamped
    # to [-1, 1] before decoding
    x = torch.clamp(x, min=-1, max=1)
    B = x.shape[0]
    N = x.shape[1] - 1

    # compute_rotation_matrix_from_ortho6d takes in [K, 6], outputs [K, 3, 3]
    x_6d = x[:, :, 3:].reshape(-1, 6)
    x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, N + 1, 3, 3)  # B, 1 + N, 3, 3
    x_trans = x[:, :, :3]  # B, 1 + N, 3

    # Build the identity directly on x's device/dtype: avoids a CPU allocation
    # plus host->device copy on every call, and avoids silently casting
    # non-float32 inputs to float32.
    x_full = torch.eye(4, device=x.device, dtype=x.dtype).repeat(B, N + 1, 1, 1)
    x_full[:, :, :3, :3] = x_rot
    x_full[:, :, :3, 3] = x_trans

    struct_pose = x_full[:, :1]  # B, 1, 4, 4
    pc_poses_in_struct = x_full[:, 1:]  # B, N, 4, 4
    return struct_pose, pc_poses_in_struct
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Compute current and goal 4x4 poses for each object point cloud.

    The current pose of an object has an identity rotation and a translation
    equal to the mean of its P points. The goal pose is the object's pose in
    the structure frame composed with the structure pose,
    ``struct_pose @ pc_pose_in_struct``.

    :param obj_xyzs: B, N, P, 3 object point clouds.
    :param struct_pose: B, 1, 4, 4 structure poses.
    :param pc_poses_in_struct: B, N, 4, 4 object poses in the structure frame.
    :return: ``(current_pc_poses, goal_pc_poses)``, each B, N, 4, 4.
        ``current_pc_poses`` is on ``obj_xyzs``'s device with its dtype.
    """
    B, N, _, _ = pc_poses_in_struct.shape

    # Identity rotation + centroid translation. The identity is allocated
    # directly on the input's device/dtype (no CPU allocation + copy, and no
    # silent downcast of float64 centroids to float32).
    current_pc_poses = torch.eye(4, device=obj_xyzs.device, dtype=obj_xyzs.dtype).repeat(B, N, 1, 1)
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2)  # B, N, 3

    # struct_pose (B, 1, 4, 4) broadcasts over the N object poses (B, N, 4, 4)
    # in the batched matmul, so no explicit repeat/reshape is needed.
    goal_pc_poses = struct_pose @ pc_poses_in_struct  # B, N, 4, 4
    return current_pc_poses, goal_pc_poses