"""utils module""" import numpy as np from scipy.special import softmax from mindspore.ops import operations as P import mindspore.numpy as mnp import mindspore.nn as nn from mindspore.common.tensor import Tensor from commons import residue_constants import commons.r3 as r3 QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr QUAT_TO_ROT = Tensor(QUAT_TO_ROT) def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): """Create pseudo beta features.""" is_gly = mnp.equal(aatype, residue_constants.restype_order['G']) ca_idx = residue_constants.atom_order['CA'] cb_idx = residue_constants.atom_order['CB'] pseudo_beta = mnp.where( mnp.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"), all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) if all_atom_masks is not None: pseudo_beta_mask = mnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) pseudo_beta_mask = pseudo_beta_mask.astype(mnp.float32) return pseudo_beta, pseudo_beta_mask return pseudo_beta def dgram_from_positions(positions, num_bins, min_bin, max_bin): """Compute distogram from amino acid positions. Arguments: positions: [N_res, 3] Position coordinates. num_bins: The number of bins in the distogram. min_bin: The left edge of the first bin. max_bin: The left edge of the final bin. The final bin catches everything larger than `max_bin`. 
Returns: Distogram with the specified number of bins. """ def squared_difference(x, y): return mnp.square(x - y) lower_breaks = mnp.linspace(min_bin, max_bin, num_bins) lower_breaks = mnp.square(lower_breaks) upper_breaks = mnp.concatenate([lower_breaks[1:], mnp.array([1e8], dtype=mnp.float32)], axis=-1) dist2 = mnp.sum(squared_difference(mnp.expand_dims(positions, axis=-2), mnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True) dgram = ((dist2 > lower_breaks).astype(mnp.float32) * (dist2 < upper_breaks).astype(mnp.float32)) return dgram def _multiply(a, b): return mnp.stack([mnp.concatenate([(a[0][0] * b[0][0] + a[0][1] * b[1][0] + a[0][2] * b[2][0])[None, ...], (a[0][0] * b[0][1] + a[0][1] * b[1][1] + a[0][2] * b[2][1])[None, ...], (a[0][0] * b[0][2] + a[0][1] * b[1][2] + a[0][2] * b[2][2])[None, ...]], axis=0), mnp.concatenate([(a[1][0] * b[0][0] + a[1][1] * b[1][0] + a[1][2] * b[2][0])[None, ...], (a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1])[None, ...], (a[1][0] * b[0][2] + a[1][1] * b[1][2] + a[1][2] * b[2][2])[None, ...]], axis=0), mnp.concatenate([(a[2][0] * b[0][0] + a[2][1] * b[1][0] + a[2][2] * b[2][0])[None, ...], (a[2][0] * b[0][1] + a[2][1] * b[1][1] + a[2][2] * b[2][1])[None, ...], (a[2][0] * b[0][2] + a[2][1] * b[1][2] + a[2][2] * b[2][2])[None, ...]], axis=0)]) def apply_rot_to_vec(rot, vec, unstack=False): """Multiply rotation matrix by a vector.""" if unstack: x, y, z = vec[:, 0], vec[:, 1], vec[:, 2] else: x, y, z = vec return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z, rot[1][0] * x + rot[1][1] * y + rot[1][2] * z, rot[2][0] * x + rot[2][1] * y + rot[2][2] * z] def make_canonical_transform(n_xyz, ca_xyz, c_xyz): """Returns translation and rotation matrices to canonicalize residue atoms. Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. 
You need to take care of such cases in your code. Args: n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. Returns: A tuple (translation, rotation) where: translation is an array of shape [batch, 3] defining the translation. rotation is an array of shape [batch, 3, 3] defining the rotation. After applying the translation and rotation to all atoms in a residue: * All atoms will be shifted so that CA is at the origin, * All atoms will be rotated so that C is at the x-axis, * All atoms will be shifted so that N is in the xy plane. """ # Place CA at the origin. translation = -ca_xyz n_xyz = n_xyz + translation c_xyz = c_xyz + translation # Place C on the x-axis. c_x, c_y, c_z = c_xyz[:, 0], c_xyz[:, 1], c_xyz[:, 2] # Rotate by angle c1 in the x-y plane (around the z-axis). sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2) cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2) zeros = mnp.zeros_like(sin_c1).astype("float32") ones = mnp.ones_like(sin_c1).astype("float32") # # pylint: disable=bad-whitespace c1_rot_matrix = mnp.stack([mnp.concatenate((cos_c1[None, ...], (-sin_c1)[None, ...], zeros[None, ...]), axis=0), mnp.concatenate((sin_c1[None, ...], cos_c1[None, ...], zeros[None, ...]), axis=0), mnp.concatenate((zeros[None, ...], zeros[None, ...], ones[None, ...]), axis=0)]) # # Rotate by angle c2 in the x-z plane (around the y-axis). 
sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2) cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2) c2_rot_matrix = mnp.stack([mnp.concatenate((cos_c2[None, ...], zeros[None, ...], sin_c2[None, ...]), axis=0), mnp.concatenate((zeros[None, ...], ones[None, ...], zeros[None, ...]), axis=0), mnp.concatenate(((-sin_c2)[None, ...], zeros[None, ...], cos_c2[None, ...]), axis=0)]) c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix) n_xyz = mnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T # Place N in the x-y plane. _, n_y, n_z = n_xyz[:, 0], n_xyz[:, 1], n_xyz[:, 2] # Rotate by angle alpha in the y-z plane (around the x-axis). sin_n = -n_z / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2) cos_n = n_y / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2) n_rot_matrix = mnp.stack([mnp.concatenate([ones[None, ...], zeros[None, ...], zeros[None, ...]], axis=0), mnp.concatenate([zeros[None, ...], cos_n[None, ...], (-sin_n)[None, ...]], axis=0), mnp.concatenate([zeros[None, ...], sin_n[None, ...], cos_n[None, ...]], axis=0)]) return translation, mnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]) def make_transform_from_reference(n_xyz, ca_xyz, c_xyz): """Returns rotation and translation matrices to convert from reference. Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You need to take care of such cases in your code. Args: n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. Returns: A tuple (rotation, translation) where: rotation is an array of shape [batch, 3, 3] defining the rotation. translation is an array of shape [batch, 3] defining the translation. 
After applying the translation and rotation to the reference backbone, the coordinates will approximately equal to the input coordinates. The order of translation and rotation differs from make_canonical_transform because the rotation from this function should be applied before the translation, unlike make_canonical_transform. """ translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz) return mnp.transpose(rotation, (0, 2, 1)), -translation def rot_to_quat(rot, unstack_inputs=False): """Convert rotation matrix to quaternion. Note that this function calls self_adjoint_eig which is extremely expensive on the GPU. If at all possible, this function should run on the CPU. Args: rot: rotation matrix (see below for format). unstack_inputs: If true, rotation matrix should be shape (..., 3, 3) otherwise the rotation matrix should be a list of lists of tensors. Returns: Quaternion as (..., 4) tensor. """ if unstack_inputs: rot = mnp.transpose(rot, [2, 1, 0]) xx, xy, xz = rot[0][0], rot[0][1], rot[0][2] yx, yy, yz = rot[1][0], rot[1][1], rot[1][2] zx, zy, zz = rot[2][0], rot[2][1], rot[2][2] k = mnp.stack((mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1), mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1), mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1), mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)), axis=-2) k = (1. / 3.) * k k = k[:, :, 0] return k def quat_to_rot(normalized_quat): """Convert a normalized quaternion to a rotation matrix.""" rot_tensor = mnp.sum(mnp.reshape(QUAT_TO_ROT, (4, 4, 9)) * normalized_quat[..., :, None, None] * normalized_quat[..., None, :, None], axis=(-3, -2)) rot = mnp.moveaxis(rot_tensor, -1, 0) # Unstack. 
return [[rot[0], rot[1], rot[2]], [rot[3], rot[4], rot[5]], [rot[6], rot[7], rot[8]]] def quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False): """create quat affine representations""" if unstack_inputs and rotation is not None: rotation = mnp.transpose(rotation, [2, 1, 0]) translation = mnp.moveaxis(translation, -1, 0) # Unstack. if normalize and quaternion is not None: quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True) if rotation is None: rotation = quat_to_rot(quaternion) return quaternion, rotation, translation def apply_inverse_rot_to_vec(rot, vec): """Multiply the inverse of a rotation matrix by a vector.""" # Inverse rotation is just transpose return mnp.concatenate(((rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2])[None, ...], (rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2])[None, ...], (rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2])[None, ...]), axis=0) def invert_point(transformed_point, rotation, translation, extra_dims=0): """Apply inverse of transformation to a point. Args: transformed_point: List of 3 tensors to apply affine extra_dims: Number of dimensions at the end of the transformed_point shape that are not present in the rotation and translation. The most common use is rotation N points at once with extra_dims=1 for use in a network. Returns: Transformed point after applying affine. """ for _ in range(extra_dims): rotation = mnp.expand_dims(rotation, axis=-1) translation = mnp.expand_dims(translation, axis=-1) rot_point = transformed_point - translation return apply_inverse_rot_to_vec(rotation, rot_point) def _invert_point(transformed_point, rotation, translation): """Apply inverse of transformation to a point. Args: transformed_point: List of 3 tensors to apply affine extra_dims: Number of dimensions at the end of the transformed_point shape that are not present in the rotation and translation. 
The most common use is rotation N points at once with extra_dims=1 for use in a network. Returns: Transformed point after applying affine. """ r00 = mnp.expand_dims(rotation[0][0], axis=-1) r01 = mnp.expand_dims(rotation[0][1], axis=-1) r02 = mnp.expand_dims(rotation[0][2], axis=-1) r10 = mnp.expand_dims(rotation[1][0], axis=-1) r11 = mnp.expand_dims(rotation[1][1], axis=-1) r12 = mnp.expand_dims(rotation[1][2], axis=-1) r20 = mnp.expand_dims(rotation[2][0], axis=-1) r21 = mnp.expand_dims(rotation[2][1], axis=-1) r22 = mnp.expand_dims(rotation[2][2], axis=-1) t0 = mnp.expand_dims(translation[0], axis=-1) t1 = mnp.expand_dims(translation[1], axis=-1) t2 = mnp.expand_dims(translation[2], axis=-1) rot_point = [transformed_point[0] - t0, transformed_point[1] - t1, transformed_point[2] - t2] result = [r00 * rot_point[0] + r10 * rot_point[1] + r20 * rot_point[2], r01 * rot_point[0] + r11 * rot_point[1] + r21 * rot_point[2], r02 * rot_point[0] + r12 * rot_point[1] + r22 * rot_point[2]] return result def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): """Masked mean.""" if drop_mask_channel: mask = mask[..., 0] mask_shape = mask.shape value_shape = value.shape broadcast_factor = 1. value_size = value_shape[axis] mask_size = mask_shape[axis] if mask_size == 1: broadcast_factor *= value_size return mnp.sum(mask * value, axis=axis) / (mnp.sum(mask, axis=axis) * broadcast_factor + eps) def atom37_to_torsion_angles( aatype, # (B, N) all_atom_pos, # (B, N, 37, 3) all_atom_mask, # (B, N, 37) chi_atom_indices, chi_angles_mask, mirror_psi_mask, chi_pi_periodic, indices0, indices1 ): """Computes the 7 torsion angles (in sin, cos encoding) for each residue. The 7 torsion angles are in the order '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]', here pre_omega denotes the omega torsion angle between the given amino acid and the previous amino acid. Args: aatype: Amino acid type, given as array with integers. 
all_atom_pos: atom37 representation of all atom coordinates. all_atom_mask: atom37 representation of mask on all atom coordinates. placeholder_for_undefined: flag denoting whether to set masked torsion angles to zero. Returns: Dict containing: * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final 2 dimensions denote sin and cos respectively * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but with the angle shifted by pi for all chi angles affected by the naming ambiguities. * 'torsion_angles_mask': Mask for which chi angles are present. """ # Map aatype > 20 to 'Unknown' (20). aatype = mnp.minimum(aatype, 20) # Compute the backbone angles. num_batch, num_res = aatype.shape pad = mnp.zeros([num_batch, 1, 37, 3], mnp.float32) prev_all_atom_pos = mnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1) pad = mnp.zeros([num_batch, 1, 37], mnp.float32) prev_all_atom_mask = mnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1) # For each torsion angle collect the 4 atom positions that define this angle. # shape (B, N, atoms=4, xyz=3) pre_omega_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2) phi_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2) psi_atom_pos = mnp.concatenate([all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2) # # Collect the masks from these atoms. # # Shape [batch, num_res] # ERROR NO PROD pre_omega_mask = (P.ReduceProd()(prev_all_atom_mask[:, :, 1:3], -1) # prev CA, C * P.ReduceProd()(all_atom_mask[:, :, 0:2], -1)) # this N, CA phi_mask = (prev_all_atom_mask[:, :, 2] # prev C * P.ReduceProd()(all_atom_mask[:, :, 0:3], -1)) # this N, CA, C psi_mask = (P.ReduceProd()(all_atom_mask[:, :, 0:3], -1) * # this N, CA, C all_atom_mask[:, :, 4]) # this O # Collect the atoms for the chi-angles. # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. # Select atoms to compute chis. 
Shape: [batch, num_res, chis=4, atoms=4]. atom_indices = mnp.take(chi_atom_indices, aatype, axis=0) # # Gather atom positions Batch Gather. Shape: [batch, num_res, chis=4, atoms=4, xyz=3]. # 4 seq_length 4 4 batch, sequence length, chis, atoms seq_length = all_atom_pos.shape[1] atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int64") new_indices = P.Concat(4)((indices0, indices1, atom_indices)) # 4, seq_length, 4, 4, 3 chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices) chis_mask = mnp.take(chi_angles_mask, aatype, axis=0) chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices) # chis_atom_pos = P.GatherBatch(axis=0, batch=2)(all_atom_pos, atom_indices) # chis_mask = mnp.take(chi_angles_mask, aatype, axis=0) # chi_angle_atoms_mask = P.GatherBatch(axis=0, batch=2)(all_atom_mask, atom_indices) # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4]. chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1) chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32) # Stack all torsion angle atom positions. # Shape (B, N, torsions=7, atoms=4, xyz=3)ls torsions_atom_pos = mnp.concatenate([pre_omega_atom_pos[:, :, None, :, :], phi_atom_pos[:, :, None, :, :], psi_atom_pos[:, :, None, :, :], chis_atom_pos], axis=2) # Stack up masks for all torsion angles. 
# shape (B, N, torsions=7) torsion_angles_mask = mnp.concatenate([pre_omega_mask[:, :, None], phi_mask[:, :, None], psi_mask[:, :, None], chis_mask], axis=2) torsion_frames_rots, torsion_frames_trans = r3.rigids_from_3_points( torsions_atom_pos[:, :, :, 1, :], torsions_atom_pos[:, :, :, 2, :], torsions_atom_pos[:, :, :, 0, :]) inv_torsion_rots, inv_torsion_trans = r3.invert_rigids(torsion_frames_rots, torsion_frames_trans) forth_atom_rel_pos = r3.rigids_mul_vecs(inv_torsion_rots, inv_torsion_trans, torsions_atom_pos[:, :, :, 3, :]) # Compute the position of the forth atom in this frame (y and z coordinate torsion_angles_sin_cos = mnp.stack([forth_atom_rel_pos[..., 2], forth_atom_rel_pos[..., 1]], axis=-1) torsion_angles_sin_cos /= mnp.sqrt(mnp.sum(mnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8) # Mirror psi, because we computed it from the Oxygen-atom. torsion_angles_sin_cos *= mirror_psi_mask chi_is_ambiguous = mnp.take(chi_pi_periodic, aatype, axis=0) mirror_torsion_angles = mnp.concatenate([mnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1) alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]) return torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask def get_chi_atom_indices(): """Returns atom indices needed to compute chi angles for all residue types. Returns: A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are in the order specified in residue_constants.restypes + unknown residue type at the end. For chi angles which are not defined on the residue, the positions indices are by default set to 0. 
""" chi_atom_indices = [] for residue_name in residue_constants.restypes: residue_name = residue_constants.restype_1to3[residue_name] residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] atom_indices = [] for chi_angle in residue_chi_angles: atom_indices.append([residue_constants.atom_order[atom] for atom in chi_angle]) for _ in range(4 - len(atom_indices)): atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. chi_atom_indices.append(atom_indices) chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. return np.asarray(chi_atom_indices) def to_tensor(quaternion, translation): return mnp.concatenate([quaternion, translation], axis=-1) def from_tensor(tensor, normalize=False): quaternion, tx, ty, tz = mnp.split(tensor, [4, 5, 6], axis=-1) return quat_affine(quaternion, mnp.stack([tx[..., 0], ty[..., 0], tz[..., 0]], axis=-1), normalize=normalize) # return quat_affine(quaternion, [tx[..., 0], ty[..., 0], tz[..., 0]], normalize=normalize) def generate_new_affine(sequence_mask): num_residues, _ = sequence_mask.shape quaternion = mnp.tile(mnp.reshape(mnp.asarray([1., 0., 0., 0.]), [1, 4]), [num_residues, 1]) translation = mnp.zeros([num_residues, 3]) return quat_affine(quaternion, translation, unstack_inputs=True) def pre_compose(quaternion, rotation, translation, update): """Return a new QuatAffine which applies the transformation update first. Args: update: Length-6 vector. 3-vector of x, y, and z such that the quaternion update is (1, x, y, z) and zero for the 3-vector is the identity quaternion. 3-vector for translation concatenated. Returns: New QuatAffine object. 
""" vector_quaternion_update, x, y, z = mnp.split(update, [3, 4, 5], axis=-1) trans_update = [mnp.squeeze(x, axis=-1), mnp.squeeze(y, axis=-1), mnp.squeeze(z, axis=-1)] new_quaternion = (quaternion + quat_multiply_by_vec(quaternion, vector_quaternion_update)) trans_update = apply_rot_to_vec(rotation, trans_update) new_translation = [translation[0] + trans_update[0], translation[1] + trans_update[1], translation[2] + trans_update[2]] return quat_affine(new_quaternion, mnp.stack(new_translation, axis=-1)) def scale_translation(quaternion, translation, rotation, position_scale): """Return a new quat affine with a different scale for translation.""" return quat_affine(quaternion, mnp.stack([translation[0] * position_scale, translation[1] * position_scale, translation[2] * position_scale], axis=-1), rotation=rotation, normalize=False) def rigids_from_tensor4x4(m): """Construct Rigids object from an 4x4 array. Here the 4x4 is representing the transformation in homogeneous coordinates. Args: m: Array representing transformations in homogeneous coordinates. 
Returns: Rigids object corresponding to transformations m """ return m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], m[..., 2, 0], \ m[..., 2, 1], m[..., 2, 2], m[..., 0, 3], m[..., 1, 3], m[..., 2, 3] def apply_to_point(rotation, translation, point): """apply to point func""" r00 = mnp.expand_dims(rotation[0][0], axis=-1) r01 = mnp.expand_dims(rotation[0][1], axis=-1) r02 = mnp.expand_dims(rotation[0][2], axis=-1) r10 = mnp.expand_dims(rotation[1][0], axis=-1) r11 = mnp.expand_dims(rotation[1][1], axis=-1) r12 = mnp.expand_dims(rotation[1][2], axis=-1) r20 = mnp.expand_dims(rotation[2][0], axis=-1) r21 = mnp.expand_dims(rotation[2][1], axis=-1) r22 = mnp.expand_dims(rotation[2][2], axis=-1) t0 = mnp.expand_dims(translation[0], axis=-1) t1 = mnp.expand_dims(translation[1], axis=-1) t2 = mnp.expand_dims(translation[2], axis=-1) p0 = point[0] p1 = point[1] p2 = point[2] rot_point = [r00 * p0 + r01 * p1 + r02 * p2, r10 * p0 + r11 * p1 + r12 * p2, r20 * p0 + r21 * p1 + r22 * p2] result = [rot_point[0] + t0, rot_point[1] + t1, rot_point[2] + t2] return result def frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, restype_atom14_to_rigid_group, restype_atom14_rigid_group_positions, restype_atom14_mask): # (N, 14) """Put atom literature positions (atom14 encoding) in each rigid group. Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11 Args: aatype: aatype for each residue. all_frames_to_global: All per residue coordinate frames. Returns: Positions of all atom coordinates in global frame. """ # Pick the appropriate transform for every atom. residx_to_group_idx = P.Gather()(restype_atom14_to_rigid_group, aatype, 0) group_mask = nn.OneHot(depth=8, axis=-1)(residx_to_group_idx) # # r3.Rigids with shape (N, 14) map_atoms_to_global = map_atoms_to_global_func(all_frames_to_global, group_mask) # Gather the literature atom positions for each residue. 
# r3.Vecs with shape (N, 14) lit_positions = vecs_from_tensor(P.Gather()(restype_atom14_rigid_group_positions, aatype, 0)) # Transform each atom from its local frame to the global frame. # r3.Vecs with shape (N, 14) pred_positions = rigids_mul_vecs(map_atoms_to_global, lit_positions) # Mask out non-existing atoms. mask = P.Gather()(restype_atom14_mask, aatype, 0) pred_positions = pred_map_mul(pred_positions, mask) return pred_positions def pred_map_mul(pred_positions, mask): return [pred_positions[0] * mask, pred_positions[1] * mask, pred_positions[2] * mask] def rots_mul_vecs(m, v): """Apply rotations 'm' to vectors 'v'.""" return [m[0] * v[0] + m[1] * v[1] + m[2] * v[2], m[3] * v[0] + m[4] * v[1] + m[5] * v[2], m[6] * v[0] + m[7] * v[1] + m[8] * v[2]] def rigids_mul_vecs(r, v): """Apply rigid transforms 'r' to points 'v'.""" rots = rots_mul_vecs(r, v) vecs_add_r = [rots[0] + r[9], rots[1] + r[10], rots[2] + r[11]] return vecs_add_r def vecs_from_tensor(x): # shape (...) """Converts from tensor of shape (3,) to Vecs.""" # num_components = x.shape[-1] # assert num_components == 3 return x[..., 0], x[..., 1], x[..., 2] def get_exp_atom_pos(atom_pos): return [mnp.expand_dims(atom_pos[0], axis=0), mnp.expand_dims(atom_pos[1], axis=0), mnp.expand_dims(atom_pos[2], axis=0) ] def to_tensor_new(quaternion, translation): tr_new = [mnp.expand_dims(translation[0], axis=-1), mnp.expand_dims(translation[1], axis=-1), mnp.expand_dims(translation[2], axis=-1)] return mnp.concatenate([quaternion, tr_new[0], tr_new[1], tr_new[2]], axis=-1) def quat_multiply_by_vec(quat, vec): """Multiply a quaternion by a pure-vector quaternion.""" return mnp.sum(residue_constants.QUAT_MULTIPLY_BY_VEC * quat[..., :, None, None] * vec[..., None, :, None], axis=(-3, -2)) def rigids_mul_rots(xx, xy, xz, yx, yy, yz, zx, zy, zz, ones, zeros, cos_angles, sin_angles): """Compose rigid transformations 'r' with rotations 'm'.""" c00 = xx * ones + xy * zeros + xz * zeros c01 = yx * ones + yy * zeros + yz * 
zeros c02 = zx * ones + zy * zeros + zz * zeros c10 = xx * zeros + xy * cos_angles + xz * sin_angles c11 = yx * zeros + yy * cos_angles + yz * sin_angles c12 = zx * zeros + zy * cos_angles + zz * sin_angles c20 = xx * zeros + xy * (-sin_angles) + xz * cos_angles c21 = yx * zeros + yy * (-sin_angles) + yz * cos_angles c22 = zx * zeros + zy * (-sin_angles) + zz * cos_angles return c00, c10, c20, c01, c11, c21, c02, c12, c22 def rigids_mul_rigids(a, b): """Group composition of Rigids 'a' and 'b'.""" c00 = a[0] * b[0] + a[1] * b[3] + a[2] * b[6] c01 = a[3] * b[0] + a[4] * b[3] + a[5] * b[6] c02 = a[6] * b[0] + a[7] * b[3] + a[8] * b[6] c10 = a[0] * b[1] + a[1] * b[4] + a[2] * b[7] c11 = a[3] * b[1] + a[4] * b[4] + a[5] * b[7] c12 = a[6] * b[1] + a[7] * b[4] + a[8] * b[7] c20 = a[0] * b[2] + a[1] * b[5] + a[2] * b[8] c21 = a[3] * b[2] + a[4] * b[5] + a[5] * b[8] c22 = a[6] * b[2] + a[7] * b[5] + a[8] * b[8] tr0 = a[0] * b[9] + a[1] * b[10] + a[2] * b[11] tr1 = a[3] * b[9] + a[4] * b[10] + a[5] * b[11] tr2 = a[6] * b[9] + a[7] * b[10] + a[8] * b[11] new_tr0 = a[9] + tr0 new_tr1 = a[10] + tr1 new_tr2 = a[11] + tr2 return [c00, c10, c20, c01, c11, c21, c02, c12, c22, new_tr0, new_tr1, new_tr2] def rigits_concate_all(xall, x5, x6, x7): return [mnp.concatenate([xall[0][:, 0:5], x5[0][:, None], x6[0][:, None], x7[0][:, None]], axis=-1), mnp.concatenate([xall[1][:, 0:5], x5[1][:, None], x6[1][:, None], x7[1][:, None]], axis=-1), mnp.concatenate([xall[2][:, 0:5], x5[2][:, None], x6[2][:, None], x7[2][:, None]], axis=-1), mnp.concatenate([xall[3][:, 0:5], x5[3][:, None], x6[3][:, None], x7[3][:, None]], axis=-1), mnp.concatenate([xall[4][:, 0:5], x5[4][:, None], x6[4][:, None], x7[4][:, None]], axis=-1), mnp.concatenate([xall[5][:, 0:5], x5[5][:, None], x6[5][:, None], x7[5][:, None]], axis=-1), mnp.concatenate([xall[6][:, 0:5], x5[6][:, None], x6[6][:, None], x7[6][:, None]], axis=-1), mnp.concatenate([xall[7][:, 0:5], x5[7][:, None], x6[7][:, None], x7[7][:, None]], axis=-1), 
mnp.concatenate([xall[8][:, 0:5], x5[8][:, None], x6[8][:, None], x7[8][:, None]], axis=-1), mnp.concatenate([xall[9][:, 0:5], x5[9][:, None], x6[9][:, None], x7[9][:, None]], axis=-1), mnp.concatenate([xall[10][:, 0:5], x5[10][:, None], x6[10][:, None], x7[10][:, None]], axis=-1), mnp.concatenate([xall[11][:, 0:5], x5[11][:, None], x6[11][:, None], x7[11][:, None]], axis=-1) ] def reshape_back(backb): return [backb[0][:, None], backb[1][:, None], backb[2][:, None], backb[3][:, None], backb[4][:, None], backb[5][:, None], backb[6][:, None], backb[7][:, None], backb[8][:, None], backb[9][:, None], backb[10][:, None], backb[11][:, None] ] def l2_normalize(x, axis=-1): return x / mnp.sqrt(mnp.sum(x ** 2, axis=axis, keepdims=True)) def torsion_angles_to_frames(aatype, backb_to_global, torsion_angles_sin_cos, restype_rigid_group_default_frame): """Compute rigid group frames from torsion angles.""" # Gather the default frames for all rigid groups. m = P.Gather()(restype_rigid_group_default_frame, aatype, 0) xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1, x1, y1, z1 = rigids_from_tensor4x4(m) # Create the rotation matrices according to the given angles (each frame is # defined such that its rotation is around the x-axis). sin_angles = torsion_angles_sin_cos[..., 0] cos_angles = torsion_angles_sin_cos[..., 1] # insert zero rotation for backbone group. num_residues, = aatype.shape sin_angles = mnp.concatenate([mnp.zeros([num_residues, 1]), sin_angles], axis=-1) cos_angles = mnp.concatenate([mnp.ones([num_residues, 1]), cos_angles], axis=-1) zeros = mnp.zeros_like(sin_angles) ones = mnp.ones_like(sin_angles) # Apply rotations to the frames. xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2 = rigids_mul_rots(xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1, ones, zeros, cos_angles, sin_angles) all_frames = [xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2, x1, y1, z1] # chi2, chi3, and chi4 frames do not transform to the backbone frame but to # the previous frame. So chain them up accordingly. 
chi2_frame_to_frame = [xx2[:, 5], xy2[:, 5], xz2[:, 5], yx2[:, 5], yy2[:, 5], yz2[:, 5], zx2[:, 5], zy2[:, 5], zz2[:, 5], x1[:, 5], y1[:, 5], z1[:, 5]] chi3_frame_to_frame = [xx2[:, 6], xy2[:, 6], xz2[:, 6], yx2[:, 6], yy2[:, 6], yz2[:, 6], zx2[:, 6], zy2[:, 6], zz2[:, 6], x1[:, 6], y1[:, 6], z1[:, 6]] chi4_frame_to_frame = [xx2[:, 7], xy2[:, 7], xz2[:, 7], yx2[:, 7], yy2[:, 7], yz2[:, 7], zx2[:, 7], zy2[:, 7], zz2[:, 7], x1[:, 7], y1[:, 7], z1[:, 7]] # chi1_frame_to_backb = [xx2[:, 4], xy2[:, 4], xz2[:, 4], yx2[:, 4], yy2[:, 4], yz2[:, 4], zx2[:, 4], zy2[:, 4], zz2[:, 4], x1[:, 4], y1[:, 4], z1[:, 4]] chi2_frame_to_backb = rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame) chi3_frame_to_backb = rigids_mul_rigids(chi2_frame_to_backb, chi3_frame_to_frame) chi4_frame_to_backb = rigids_mul_rigids(chi3_frame_to_backb, chi4_frame_to_frame) # Recombine them to a r3.Rigids with shape (N, 8). all_frames_to_backb = rigits_concate_all(all_frames, chi2_frame_to_backb, chi3_frame_to_backb, chi4_frame_to_backb) backb_to_global_new = reshape_back(backb_to_global) # Create the global frames. 
# shape (N, 8) all_frames_to_global = rigids_mul_rigids(backb_to_global_new, all_frames_to_backb) # all_frames_to_global = rigids_mul_rigids(all_frames_to_backb, backb_to_global) return all_frames_to_global
# NOTE(review): the line above is the tail of a definition that begins before
# this section; it is reproduced unchanged.


def map_atoms_to_global_func(all_frames, group_mask):
    """Combine each of the 12 frame components with ``group_mask``.

    For every index 0..11, ``all_frames[idx]`` gets a new axis inserted at
    position 1, is multiplied by ``group_mask`` and is summed over the last
    axis.

    Args:
        all_frames: indexable collection holding at least 12 rank-2 tensors.
        group_mask: tensor that broadcasts against
            ``all_frames[idx][:, None, :]``.

    Returns:
        A list of 12 tensors, in the same order as the input components.
    """
    # NOTE(review): presumably the 12 entries are the 9 rotation and 3
    # translation components of a rigid transform -- confirm against caller.
    return [mnp.sum(all_frames[idx][:, None, :] * group_mask, axis=-1)
            for idx in range(12)]


def get_exp_frames(frames):
    """Return the first 12 entries of ``frames`` with a leading axis of size 1.

    Args:
        frames: indexable collection holding at least 12 tensors.

    Returns:
        A list of 12 tensors, each expanded with ``axis=0``.
    """
    return [mnp.expand_dims(frames[idx], axis=0) for idx in range(12)]


def vecs_to_tensor(v):
    """Stack ``v[0]``, ``v[1]`` and ``v[2]`` along a new trailing axis.

    The original documentation describes this as the inverse of
    ``vecs_from_tensor`` (defined outside this section).
    """
    first, second, third = v[0], v[1], v[2]
    return mnp.stack([first, second, third], axis=-1)


def atom14_to_atom37(atom14_data, residx_atom37_to_atom14, atom37_atom_exists, indices0):
    """Convert atom14 to atom37 representation.

    Args:
        atom14_data: rank-2 or rank-3 tensor whose first axis is the sequence.
        residx_atom37_to_atom14: index tensor; reshaped here to
            ``(seq_length, 37, 1)``.
        atom37_atom_exists: mask multiplied into the gathered result.
        indices0: index tensor concatenated (on axis 2) in front of the
            reshaped ``residx_atom37_to_atom14`` to form the ``GatherNd``
            indices.
            NOTE(review): presumably the per-residue row index -- confirm.

    Returns:
        The gathered data, masked by ``atom37_atom_exists`` when
        ``atom14_data`` has rank 2 or 3; returned unmasked for other ranks.
    """
    seq_length = atom14_data.shape[0]
    rank = len(atom14_data.shape)

    atom_index = residx_atom37_to_atom14.reshape((seq_length, 37, 1))
    gather_indices = P.Concat(2)((indices0, atom_index))
    atom37_data = P.GatherNd()(atom14_data, gather_indices)
    # atom37_data = P.GatherBatch()(atom14_data, residx_atom37_to_atom14)

    if rank == 2:
        atom37_data *= atom37_atom_exists
    elif rank == 3:
        # Add a trailing axis so the mask broadcasts over the last dimension.
        atom37_data *= atom37_atom_exists[:, :, None].astype(atom37_data.dtype)
    return atom37_data


def batch_apply_rot_to_vec(rot, vec, unstack=False):
    """Multiply a batched rotation matrix by a vector.

    Args:
        rot: rank-4 tensor indexed as ``rot[:, row, col, :]``.
        vec: when ``unstack`` is true, a rank-3 tensor whose last axis holds
            the x, y, z components; otherwise something that unpacks into
            three component tensors.
        unstack: whether ``vec`` must be split along its last axis.

    Returns:
        A list of three tensors (one per output row), each with a new axis
        inserted at position 1.
    """
    if unstack:
        x, y, z = vec[:, :, 0], vec[:, :, 1], vec[:, :, 2]
    else:
        x, y, z = vec
    return [
        (rot[:, row, 0, :] * x + rot[:, row, 1, :] * y + rot[:, row, 2, :] * z)[:, None, :]
        for row in range(3)
    ]


def _batch_multiply(a, b):
    """Batched 3x3 matrix product of ``a`` and ``b``.

    Both inputs are rank-4 tensors indexed as ``m[:, row, col, :]``; the
    result uses the same layout.
    """
    product_rows = []
    for i in range(3):
        row_entries = [
            (a[:, i, 0, :] * b[:, 0, j, :]
             + a[:, i, 1, :] * b[:, 1, j, :]
             + a[:, i, 2, :] * b[:, 2, j, :])[:, None, :]
            for j in range(3)
        ]
        product_rows.append(mnp.concatenate(row_entries, axis=1)[:, None, :, :])
    return mnp.concatenate(product_rows, axis=1)


def _batch_matrix_row(first, second, third):
    """Join three entries into one matrix row with a leading row axis.

    Each entry gets a new axis at position 1, the three are concatenated on
    that axis, and another new axis is inserted at position 1 so that rows
    can be concatenated there.
    """
    row = mnp.concatenate(
        (first[:, None, ...], second[:, None, ...], third[:, None, ...]), axis=1)
    return row[:, None, :, :]


def batch_make_canonical_transform(n_xyz, ca_xyz, c_xyz):
    """Returns translation and rotation matrices to canonicalize residue atoms.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
        n_xyz: nitrogen xyz coordinates; indexed here as a rank-3 tensor
            whose last axis holds x, y, z.
        ca_xyz: carbon alpha xyz coordinates, same layout.
        c_xyz: carbon xyz coordinates, same layout.

    Returns:
        A tuple (translation, rotation) where:
          translation is ``-ca_xyz``.
          rotation is the combined rotation, permuted with
            ``[0, 3, 1, 2]`` so the two matrix axes come last.

    After applying the translation and rotation to all atoms in a residue:
      * All atoms will be shifted so that CA is at the origin,
      * All atoms will be rotated so that C is at the x-axis,
      * All atoms will be shifted so that N is in the xy plane.
    """
    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Place C on the x-axis.
    c_x, c_y, c_z = c_xyz[:, :, 0], c_xyz[:, :, 1], c_xyz[:, :, 2]

    # Rotate by angle c1 in the x-y plane (around the z-axis).
    # The 1e-20 term keeps the denominators away from zero.
    norm_xy = mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
    sin_c1 = -c_y / norm_xy
    cos_c1 = c_x / norm_xy
    zeros = mnp.zeros_like(sin_c1).astype("float32")
    ones = mnp.ones_like(sin_c1).astype("float32")
    c1_rot_matrix = mnp.concatenate(
        [_batch_matrix_row(cos_c1, -sin_c1, zeros),
         _batch_matrix_row(sin_c1, cos_c1, zeros),
         _batch_matrix_row(zeros, zeros, ones)],
        axis=1)

    # Rotate by angle c2 in the x-z plane (around the y-axis).
    norm_xyz = mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
    sin_c2 = c_z / norm_xyz
    cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / norm_xyz
    c2_rot_matrix = mnp.concatenate(
        [_batch_matrix_row(cos_c2, zeros, sin_c2),
         _batch_matrix_row(zeros, ones, zeros),
         _batch_matrix_row(-sin_c2, zeros, cos_c2)],
        axis=1)

    c_rot_matrix = _batch_multiply(c2_rot_matrix, c1_rot_matrix)
    rotated_n = batch_apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)
    n_xyz = mnp.transpose(mnp.concatenate(rotated_n, axis=1), (0, 2, 1))

    # Place N in the x-y plane.
    n_y, n_z = n_xyz[:, :, 1], n_xyz[:, :, 2]

    # Rotate by angle alpha in the y-z plane (around the x-axis).
    norm_yz = mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
    sin_n = -n_z / norm_yz
    cos_n = n_y / norm_yz
    n_rot_matrix = mnp.concatenate(
        [_batch_matrix_row(ones, zeros, zeros),
         _batch_matrix_row(zeros, cos_n, -sin_n),
         _batch_matrix_row(zeros, sin_n, cos_n)],
        axis=1)

    full_rotation = _batch_multiply(n_rot_matrix, c_rot_matrix)
    return translation, mnp.transpose(full_rotation, [0, 3, 1, 2])


def batch_make_transform_from_reference(n_xyz, ca_xyz, c_xyz):
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
        n_xyz: nitrogen xyz coordinates, forwarded unchanged to
            ``batch_make_canonical_transform``.
        ca_xyz: carbon alpha xyz coordinates, forwarded likewise.
        c_xyz: carbon xyz coordinates, forwarded likewise.

    Returns:
        A tuple (rotation, translation) where:
          rotation is the canonical rotation with its last two axes swapped.
          translation is the negated canonical translation.

    After applying the translation and rotation to the reference backbone,
    the coordinates will approximately equal to the input coordinates.

    The order of translation and rotation differs from make_canonical_transform
    because the rotation from this function should be applied before the
    translation, unlike make_canonical_transform.
    """
    canonical_translation, canonical_rotation = batch_make_canonical_transform(
        n_xyz, ca_xyz, c_xyz)
    # Swapping the two matrix axes transposes each rotation matrix.
    rotation = mnp.transpose(canonical_rotation, (0, 1, 3, 2))
    return rotation, -canonical_translation


def batch_rot_to_quat(rot, unstack_inputs=False):
    """Convert rotation matrix to quaternion.

    A 4x4 matrix is assembled from the nine rotation entries, scaled by 1/3,
    and the slice at index 0 of its last axis is returned.

    NOTE(review): the original documentation mentions ``self_adjoint_eig``,
    but no eigen-decomposition is performed here -- confirm that the first
    slice is what callers expect.

    Args:
        rot: rank-4 rotation tensor. After the optional permutation it is
            indexed as ``rot[:, i, j, :]``.
        unstack_inputs: if true, ``rot`` is first permuted with
            ``[0, 3, 2, 1]``.

    Returns:
        Tensor whose last axis has the 4 quaternion-like entries.
    """
    if unstack_inputs:
        rot = mnp.transpose(rot, [0, 3, 2, 1])

    xx, xy, xz = rot[:, 0, 0, :], rot[:, 0, 1, :], rot[:, 0, 2, :]
    yx, yy, yz = rot[:, 1, 0, :], rot[:, 1, 1, :], rot[:, 1, 2, :]
    zx, zy, zz = rot[:, 2, 0, :], rot[:, 2, 1, :], rot[:, 2, 2, :]

    row_r = mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1)
    row_i = mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1)
    row_j = mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1)
    row_k = mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)

    k = mnp.stack((row_r, row_i, row_j, row_k), axis=-2)
    k = (1. / 3.) * k
    return k[:, :, :, 0]


def batch_quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False):
    """Prepare the (quaternion, rotation, translation) triple of an affine.

    Args:
        quaternion: quaternion tensor or None.
        translation: translation tensor.
        rotation: optional rotation tensor.
        normalize: if true and ``quaternion`` is not None, divide it by its
            norm over the last axis.
        unstack_inputs: if true, permute ``rotation`` (when given) with
            ``[0, 3, 2, 1]`` and move the last axis of ``translation`` to
            position 1.

    Returns:
        The tuple ``(quaternion, rotation, translation)`` after the steps
        above.
    """
    if unstack_inputs:
        if rotation is not None:
            rotation = mnp.transpose(rotation, [0, 3, 2, 1])
        translation = mnp.moveaxis(translation, -1, 1)  # Unstack.

    if normalize and quaternion is not None:
        quat_norm = mnp.norm(quaternion, axis=-1, keepdims=True)
        quaternion = quaternion / quat_norm
    return quaternion, rotation, translation


def batch_apply_inverse_rot_to_vec(rot, vec):
    """Multiply the inverse of a rotation matrix by a vector.

    Args:
        rot: rotation tensor indexed as ``rot[:, row, col, :]``.
        vec: tensor whose components are read as ``vec[:, 0]``,
            ``vec[:, 1]`` and ``vec[:, 2]``.

    Returns:
        The three output components concatenated along axis 1.
    """
    v_x, v_y, v_z = vec[:, 0], vec[:, 1], vec[:, 2]
    # Inverse rotation is just transpose, so the row index runs over the sum.
    components = [
        (rot[:, 0, col, :] * v_x + rot[:, 1, col, :] * v_y + rot[:, 2, col, :] * v_z)[:, None, ...]
        for col in range(3)
    ]
    return mnp.concatenate(components, axis=1)


def batch_invert_point(transformed_point, rotation, translation, extra_dims=0):
    """Apply inverse of transformation to a point.

    Args:
        transformed_point: tensor to which the inverse affine is applied;
            ``translation`` is subtracted from it directly.
        rotation: rotation tensor of the affine.
        translation: translation tensor of the affine.
        extra_dims: Number of dimensions at the end of the transformed_point
            shape that are not present in the rotation and translation. The
            most common use is rotation N points at once with extra_dims=1
            for use in a network.

    Returns:
        Transformed point after applying the inverse affine.
    """
    for _ in range(extra_dims):
        # One trailing axis per extra dimension so broadcasting lines up.
        translation = mnp.expand_dims(translation, axis=-1)
        rotation = mnp.expand_dims(rotation, axis=-1)
    shifted_point = transformed_point - translation
    return batch_apply_inverse_rot_to_vec(rotation, shifted_point)


def compute_confidence(predicted_lddt_logits):
    """Return the mean pLDDT over all entries of ``predicted_lddt_logits``.

    The number of bins is the size of the last axis; bins are taken to be
    equally wide over [0, 1] and the first bin centre sits at half a bin
    width.
    """
    bin_count = predicted_lddt_logits.shape[-1]
    bin_width = 1 / bin_count
    first_center = bin_width / 2
    per_residue = compute_plddt(predicted_lddt_logits, first_center, bin_width)
    return np.mean(per_residue)


def compute_plddt(logits, start_n, bin_width):
    """Computes per-residue pLDDT from logits.

    Args:
        logits: [num_res, num_bins] output from the PredictedLDDTHead.
        start_n: centre of the first bin.
        bin_width: spacing between consecutive bin centres.

    Returns:
        plddt: [num_res] per-residue pLDDT, i.e. the softmax-weighted mean of
        the bin centres multiplied by 100.
    """
    centers = np.arange(start=start_n, stop=1.0, step=bin_width)
    bin_probs = softmax(logits, axis=-1)
    expected_lddt = np.sum(bin_probs * centers[None, :], axis=-1)
    return expected_lddt * 100