You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 47 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038
  1. """utils module"""
  2. import numpy as np
  3. from scipy.special import softmax
  4. from mindspore.ops import operations as P
  5. import mindspore.numpy as mnp
  6. import mindspore.nn as nn
  7. from mindspore.common.tensor import Tensor
  8. from commons import residue_constants
  9. import commons.r3 as r3
  10. QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
  11. QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr
  12. QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii
  13. QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj
  14. QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk
  15. QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij
  16. QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik
  17. QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk
  18. QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir
  19. QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr
  20. QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr
  21. QUAT_TO_ROT = Tensor(QUAT_TO_ROT)
  22. def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  23. """Create pseudo beta features."""
  24. is_gly = mnp.equal(aatype, residue_constants.restype_order['G'])
  25. ca_idx = residue_constants.atom_order['CA']
  26. cb_idx = residue_constants.atom_order['CB']
  27. pseudo_beta = mnp.where(
  28. mnp.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"),
  29. all_atom_positions[..., ca_idx, :],
  30. all_atom_positions[..., cb_idx, :])
  31. if all_atom_masks is not None:
  32. pseudo_beta_mask = mnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
  33. pseudo_beta_mask = pseudo_beta_mask.astype(mnp.float32)
  34. return pseudo_beta, pseudo_beta_mask
  35. return pseudo_beta
  36. def dgram_from_positions(positions, num_bins, min_bin, max_bin):
  37. """Compute distogram from amino acid positions.
  38. Arguments:
  39. positions: [N_res, 3] Position coordinates.
  40. num_bins: The number of bins in the distogram.
  41. min_bin: The left edge of the first bin.
  42. max_bin: The left edge of the final bin. The final bin catches
  43. everything larger than `max_bin`.
  44. Returns:
  45. Distogram with the specified number of bins.
  46. """
  47. def squared_difference(x, y):
  48. return mnp.square(x - y)
  49. lower_breaks = mnp.linspace(min_bin, max_bin, num_bins)
  50. lower_breaks = mnp.square(lower_breaks)
  51. upper_breaks = mnp.concatenate([lower_breaks[1:], mnp.array([1e8], dtype=mnp.float32)], axis=-1)
  52. dist2 = mnp.sum(squared_difference(mnp.expand_dims(positions, axis=-2),
  53. mnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True)
  54. dgram = ((dist2 > lower_breaks).astype(mnp.float32) * (dist2 < upper_breaks).astype(mnp.float32))
  55. return dgram
  56. def _multiply(a, b):
  57. return mnp.stack([mnp.concatenate([(a[0][0] * b[0][0] + a[0][1] * b[1][0] + a[0][2] * b[2][0])[None, ...],
  58. (a[0][0] * b[0][1] + a[0][1] * b[1][1] + a[0][2] * b[2][1])[None, ...],
  59. (a[0][0] * b[0][2] + a[0][1] * b[1][2] + a[0][2] * b[2][2])[None, ...]], axis=0),
  60. mnp.concatenate([(a[1][0] * b[0][0] + a[1][1] * b[1][0] + a[1][2] * b[2][0])[None, ...],
  61. (a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1])[None, ...],
  62. (a[1][0] * b[0][2] + a[1][1] * b[1][2] + a[1][2] * b[2][2])[None, ...]], axis=0),
  63. mnp.concatenate([(a[2][0] * b[0][0] + a[2][1] * b[1][0] + a[2][2] * b[2][0])[None, ...],
  64. (a[2][0] * b[0][1] + a[2][1] * b[1][1] + a[2][2] * b[2][1])[None, ...],
  65. (a[2][0] * b[0][2] + a[2][1] * b[1][2] + a[2][2] * b[2][2])[None, ...]],
  66. axis=0)])
  67. def apply_rot_to_vec(rot, vec, unstack=False):
  68. """Multiply rotation matrix by a vector."""
  69. if unstack:
  70. x, y, z = vec[:, 0], vec[:, 1], vec[:, 2]
  71. else:
  72. x, y, z = vec
  73. return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
  74. rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
  75. rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
def make_canonical_transform(n_xyz, ca_xyz, c_xyz):
    """Returns translation and rotation matrices to canonicalize residue atoms.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
        n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
        ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
        c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

    Returns:
        A tuple (translation, rotation) where:
        translation is an array of shape [batch, 3] defining the translation.
        rotation is an array of shape [batch, 3, 3] defining the rotation.
        After applying the translation and rotation to all atoms in a residue:
        * All atoms will be shifted so that CA is at the origin,
        * All atoms will be rotated so that C is at the x-axis,
        * All atoms will be shifted so that N is in the xy plane.
        The translation is applied first, then the rotation.
    """
    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation
    # Place C on the x-axis.
    c_x, c_y, c_z = c_xyz[:, 0], c_xyz[:, 1], c_xyz[:, 2]
    # Rotate by angle c1 in the x-y plane (around the z-axis).
    # The 1e-20 keeps the denominators non-zero when C lies on the z-axis.
    sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
    cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
    zeros = mnp.zeros_like(sin_c1).astype("float32")
    ones = mnp.ones_like(sin_c1).astype("float32")
    # pylint: disable=bad-whitespace
    # Rotation matrices are built component-wise: each row concatenates three
    # (1, batch) pieces and the rows are stacked, giving shape (3, 3, batch).
    c1_rot_matrix = mnp.stack([mnp.concatenate((cos_c1[None, ...], (-sin_c1)[None, ...], zeros[None, ...]), axis=0),
                               mnp.concatenate((sin_c1[None, ...], cos_c1[None, ...], zeros[None, ...]), axis=0),
                               mnp.concatenate((zeros[None, ...], zeros[None, ...], ones[None, ...]), axis=0)])
    # Rotate by angle c2 in the x-z plane (around the y-axis).
    sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
    cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
    c2_rot_matrix = mnp.stack([mnp.concatenate((cos_c2[None, ...], zeros[None, ...], sin_c2[None, ...]), axis=0),
                               mnp.concatenate((zeros[None, ...], ones[None, ...], zeros[None, ...]), axis=0),
                               mnp.concatenate(((-sin_c2)[None, ...], zeros[None, ...], cos_c2[None, ...]), axis=0)])
    # Combined rotation: first c1 (about z), then c2 (about y).
    c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
    # apply_rot_to_vec returns [x, y, z] rows of shape (batch,); stack + .T
    # brings N back to shape (batch, 3).
    n_xyz = mnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
    # Place N in the x-y plane.
    _, n_y, n_z = n_xyz[:, 0], n_xyz[:, 1], n_xyz[:, 2]
    # Rotate by angle alpha in the y-z plane (around the x-axis).
    sin_n = -n_z / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
    cos_n = n_y / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
    n_rot_matrix = mnp.stack([mnp.concatenate([ones[None, ...], zeros[None, ...], zeros[None, ...]], axis=0),
                              mnp.concatenate([zeros[None, ...], cos_n[None, ...], (-sin_n)[None, ...]], axis=0),
                              mnp.concatenate([zeros[None, ...], sin_n[None, ...], cos_n[None, ...]], axis=0)])
    # Compose all three rotations and move the batch axis first:
    # (3, 3, batch) -> (batch, 3, 3).
    return translation, mnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1])
  127. def make_transform_from_reference(n_xyz, ca_xyz, c_xyz):
  128. """Returns rotation and translation matrices to convert from reference.
  129. Note that this method does not take care of symmetries. If you provide the
  130. atom positions in the non-standard way, the N atom will end up not at
  131. [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
  132. need to take care of such cases in your code.
  133. Args:
  134. n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
  135. ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
  136. c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
  137. Returns:
  138. A tuple (rotation, translation) where:
  139. rotation is an array of shape [batch, 3, 3] defining the rotation.
  140. translation is an array of shape [batch, 3] defining the translation.
  141. After applying the translation and rotation to the reference backbone,
  142. the coordinates will approximately equal to the input coordinates.
  143. The order of translation and rotation differs from make_canonical_transform
  144. because the rotation from this function should be applied before the
  145. translation, unlike make_canonical_transform.
  146. """
  147. translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
  148. return mnp.transpose(rotation, (0, 2, 1)), -translation
def rot_to_quat(rot, unstack_inputs=False):
    """Build the scaled quaternion matrix K of a rotation and slice it.

    NOTE(review): despite the name, this port does NOT perform an
    eigendecomposition (nothing like eigh / self_adjoint_eig is called). It
    forms K (scaled by 1/3) and returns k[:, :, 0]. When the rotation entries
    are rank-1 this is the first column of each 4x4 K, i.e.
    (xx + yy + zz, zy - yz, xz - zx, yx - xy) / 3. Reference implementations
    take the eigenvector of K with the largest eigenvalue, which is not the
    same in general. Confirm that callers expect this value.

    Args:
        rot: rotation matrix. When unstack_inputs is False it must be indexable
            as rot[row][col] (e.g. a list of lists of tensors).
        unstack_inputs: If true, rot is a single tensor that is first reordered
            with mnp.transpose(rot, [2, 1, 0]).
            NOTE(review): for an (N, 3, 3) input this gives
            rot[i][j] == input[:, j, i], i.e. entries are read transposed
            relative to a [..., row, col] layout. Confirm the intended input
            layout.

    Returns:
        k[:, :, 0] of the stacked K tensor.
    """
    if unstack_inputs:
        rot = mnp.transpose(rot, [2, 1, 0])
    xx, xy, xz = rot[0][0], rot[0][1], rot[0][2]
    yx, yy, yz = rot[1][0], rot[1][1], rot[1][2]
    zx, zy, zz = rot[2][0], rot[2][1], rot[2][2]
    # Symmetric 4x4 matrix K; inner stacks build rows (axis=-1), the outer
    # stack places the rows on axis=-2.
    k = mnp.stack((mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1),
                   mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1),
                   mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1),
                   mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)), axis=-2)
    k = (1. / 3.) * k
    # Assumes K is rank 3 (one batch axis); see the NOTE(review) above.
    k = k[:, :, 0]
    return k
  172. def quat_to_rot(normalized_quat):
  173. """Convert a normalized quaternion to a rotation matrix."""
  174. rot_tensor = mnp.sum(mnp.reshape(QUAT_TO_ROT, (4, 4, 9)) * normalized_quat[..., :, None, None] *
  175. normalized_quat[..., None, :, None], axis=(-3, -2))
  176. rot = mnp.moveaxis(rot_tensor, -1, 0) # Unstack.
  177. return [[rot[0], rot[1], rot[2]],
  178. [rot[3], rot[4], rot[5]],
  179. [rot[6], rot[7], rot[8]]]
  180. def quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False):
  181. """create quat affine representations"""
  182. if unstack_inputs and rotation is not None:
  183. rotation = mnp.transpose(rotation, [2, 1, 0])
  184. translation = mnp.moveaxis(translation, -1, 0) # Unstack.
  185. if normalize and quaternion is not None:
  186. quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True)
  187. if rotation is None:
  188. rotation = quat_to_rot(quaternion)
  189. return quaternion, rotation, translation
  190. def apply_inverse_rot_to_vec(rot, vec):
  191. """Multiply the inverse of a rotation matrix by a vector."""
  192. # Inverse rotation is just transpose
  193. return mnp.concatenate(((rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2])[None, ...],
  194. (rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2])[None, ...],
  195. (rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2])[None, ...]), axis=0)
  196. def invert_point(transformed_point, rotation, translation, extra_dims=0):
  197. """Apply inverse of transformation to a point.
  198. Args:
  199. transformed_point: List of 3 tensors to apply affine
  200. extra_dims: Number of dimensions at the end of the transformed_point
  201. shape that are not present in the rotation and translation. The most
  202. common use is rotation N points at once with extra_dims=1 for use in a
  203. network.
  204. Returns:
  205. Transformed point after applying affine.
  206. """
  207. for _ in range(extra_dims):
  208. rotation = mnp.expand_dims(rotation, axis=-1)
  209. translation = mnp.expand_dims(translation, axis=-1)
  210. rot_point = transformed_point - translation
  211. return apply_inverse_rot_to_vec(rotation, rot_point)
  212. def _invert_point(transformed_point, rotation, translation):
  213. """Apply inverse of transformation to a point.
  214. Args:
  215. transformed_point: List of 3 tensors to apply affine
  216. extra_dims: Number of dimensions at the end of the transformed_point
  217. shape that are not present in the rotation and translation. The most
  218. common use is rotation N points at once with extra_dims=1 for use in a
  219. network.
  220. Returns:
  221. Transformed point after applying affine.
  222. """
  223. r00 = mnp.expand_dims(rotation[0][0], axis=-1)
  224. r01 = mnp.expand_dims(rotation[0][1], axis=-1)
  225. r02 = mnp.expand_dims(rotation[0][2], axis=-1)
  226. r10 = mnp.expand_dims(rotation[1][0], axis=-1)
  227. r11 = mnp.expand_dims(rotation[1][1], axis=-1)
  228. r12 = mnp.expand_dims(rotation[1][2], axis=-1)
  229. r20 = mnp.expand_dims(rotation[2][0], axis=-1)
  230. r21 = mnp.expand_dims(rotation[2][1], axis=-1)
  231. r22 = mnp.expand_dims(rotation[2][2], axis=-1)
  232. t0 = mnp.expand_dims(translation[0], axis=-1)
  233. t1 = mnp.expand_dims(translation[1], axis=-1)
  234. t2 = mnp.expand_dims(translation[2], axis=-1)
  235. rot_point = [transformed_point[0] - t0, transformed_point[1] - t1, transformed_point[2] - t2]
  236. result = [r00 * rot_point[0] + r10 * rot_point[1] + r20 * rot_point[2],
  237. r01 * rot_point[0] + r11 * rot_point[1] + r21 * rot_point[2],
  238. r02 * rot_point[0] + r12 * rot_point[1] + r22 * rot_point[2]]
  239. return result
  240. def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
  241. """Masked mean."""
  242. if drop_mask_channel:
  243. mask = mask[..., 0]
  244. mask_shape = mask.shape
  245. value_shape = value.shape
  246. broadcast_factor = 1.
  247. value_size = value_shape[axis]
  248. mask_size = mask_shape[axis]
  249. if mask_size == 1:
  250. broadcast_factor *= value_size
  251. return mnp.sum(mask * value, axis=axis) / (mnp.sum(mask, axis=axis) * broadcast_factor + eps)
  252. def atom37_to_torsion_angles(
  253. aatype, # (B, N)
  254. all_atom_pos, # (B, N, 37, 3)
  255. all_atom_mask, # (B, N, 37)
  256. chi_atom_indices,
  257. chi_angles_mask,
  258. mirror_psi_mask,
  259. chi_pi_periodic,
  260. indices0,
  261. indices1
  262. ):
  263. """Computes the 7 torsion angles (in sin, cos encoding) for each residue.
  264. The 7 torsion angles are in the order
  265. '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
  266. here pre_omega denotes the omega torsion angle between the given amino acid
  267. and the previous amino acid.
  268. Args:
  269. aatype: Amino acid type, given as array with integers.
  270. all_atom_pos: atom37 representation of all atom coordinates.
  271. all_atom_mask: atom37 representation of mask on all atom coordinates.
  272. placeholder_for_undefined: flag denoting whether to set masked torsion
  273. angles to zero.
  274. Returns:
  275. Dict containing:
  276. * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
  277. 2 dimensions denote sin and cos respectively
  278. * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
  279. with the angle shifted by pi for all chi angles affected by the naming
  280. ambiguities.
  281. * 'torsion_angles_mask': Mask for which chi angles are present.
  282. """
  283. # Map aatype > 20 to 'Unknown' (20).
  284. aatype = mnp.minimum(aatype, 20)
  285. # Compute the backbone angles.
  286. num_batch, num_res = aatype.shape
  287. pad = mnp.zeros([num_batch, 1, 37, 3], mnp.float32)
  288. prev_all_atom_pos = mnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
  289. pad = mnp.zeros([num_batch, 1, 37], mnp.float32)
  290. prev_all_atom_mask = mnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
  291. # For each torsion angle collect the 4 atom positions that define this angle.
  292. # shape (B, N, atoms=4, xyz=3)
  293. pre_omega_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2)
  294. phi_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2)
  295. psi_atom_pos = mnp.concatenate([all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2)
  296. # # Collect the masks from these atoms.
  297. # # Shape [batch, num_res]
  298. # ERROR NO PROD
  299. pre_omega_mask = (P.ReduceProd()(prev_all_atom_mask[:, :, 1:3], -1) # prev CA, C
  300. * P.ReduceProd()(all_atom_mask[:, :, 0:2], -1)) # this N, CA
  301. phi_mask = (prev_all_atom_mask[:, :, 2] # prev C
  302. * P.ReduceProd()(all_atom_mask[:, :, 0:3], -1)) # this N, CA, C
  303. psi_mask = (P.ReduceProd()(all_atom_mask[:, :, 0:3], -1) * # this N, CA, C
  304. all_atom_mask[:, :, 4]) # this O
  305. # Collect the atoms for the chi-angles.
  306. # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
  307. # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
  308. atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)
  309. # # Gather atom positions Batch Gather. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
  310. # 4 seq_length 4 4 batch, sequence length, chis, atoms
  311. seq_length = all_atom_pos.shape[1]
  312. atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int64")
  313. new_indices = P.Concat(4)((indices0, indices1, atom_indices)) # 4, seq_length, 4, 4, 3
  314. chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices)
  315. chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
  316. chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices)
  317. # chis_atom_pos = P.GatherBatch(axis=0, batch=2)(all_atom_pos, atom_indices)
  318. # chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
  319. # chi_angle_atoms_mask = P.GatherBatch(axis=0, batch=2)(all_atom_mask, atom_indices)
  320. # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
  321. chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1)
  322. chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32)
  323. # Stack all torsion angle atom positions.
  324. # Shape (B, N, torsions=7, atoms=4, xyz=3)ls
  325. torsions_atom_pos = mnp.concatenate([pre_omega_atom_pos[:, :, None, :, :],
  326. phi_atom_pos[:, :, None, :, :],
  327. psi_atom_pos[:, :, None, :, :],
  328. chis_atom_pos], axis=2)
  329. # Stack up masks for all torsion angles.
  330. # shape (B, N, torsions=7)
  331. torsion_angles_mask = mnp.concatenate([pre_omega_mask[:, :, None],
  332. phi_mask[:, :, None],
  333. psi_mask[:, :, None],
  334. chis_mask], axis=2)
  335. torsion_frames_rots, torsion_frames_trans = r3.rigids_from_3_points(
  336. torsions_atom_pos[:, :, :, 1, :],
  337. torsions_atom_pos[:, :, :, 2, :],
  338. torsions_atom_pos[:, :, :, 0, :])
  339. inv_torsion_rots, inv_torsion_trans = r3.invert_rigids(torsion_frames_rots, torsion_frames_trans)
  340. forth_atom_rel_pos = r3.rigids_mul_vecs(inv_torsion_rots, inv_torsion_trans, torsions_atom_pos[:, :, :, 3, :])
  341. # Compute the position of the forth atom in this frame (y and z coordinate
  342. torsion_angles_sin_cos = mnp.stack([forth_atom_rel_pos[..., 2], forth_atom_rel_pos[..., 1]], axis=-1)
  343. torsion_angles_sin_cos /= mnp.sqrt(mnp.sum(mnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)
  344. # Mirror psi, because we computed it from the Oxygen-atom.
  345. torsion_angles_sin_cos *= mirror_psi_mask
  346. chi_is_ambiguous = mnp.take(chi_pi_periodic, aatype, axis=0)
  347. mirror_torsion_angles = mnp.concatenate([mnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1)
  348. alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
  349. return torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask
  350. def get_chi_atom_indices():
  351. """Returns atom indices needed to compute chi angles for all residue types.
  352. Returns:
  353. A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
  354. in the order specified in residue_constants.restypes + unknown residue type
  355. at the end. For chi angles which are not defined on the residue, the
  356. positions indices are by default set to 0.
  357. """
  358. chi_atom_indices = []
  359. for residue_name in residue_constants.restypes:
  360. residue_name = residue_constants.restype_1to3[residue_name]
  361. residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
  362. atom_indices = []
  363. for chi_angle in residue_chi_angles:
  364. atom_indices.append([residue_constants.atom_order[atom] for atom in chi_angle])
  365. for _ in range(4 - len(atom_indices)):
  366. atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
  367. chi_atom_indices.append(atom_indices)
  368. chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
  369. return np.asarray(chi_atom_indices)
  370. def to_tensor(quaternion, translation):
  371. return mnp.concatenate([quaternion, translation], axis=-1)
  372. def from_tensor(tensor, normalize=False):
  373. quaternion, tx, ty, tz = mnp.split(tensor, [4, 5, 6], axis=-1)
  374. return quat_affine(quaternion, mnp.stack([tx[..., 0], ty[..., 0], tz[..., 0]], axis=-1), normalize=normalize)
  375. # return quat_affine(quaternion, [tx[..., 0], ty[..., 0], tz[..., 0]], normalize=normalize)
  376. def generate_new_affine(sequence_mask):
  377. num_residues, _ = sequence_mask.shape
  378. quaternion = mnp.tile(mnp.reshape(mnp.asarray([1., 0., 0., 0.]), [1, 4]), [num_residues, 1])
  379. translation = mnp.zeros([num_residues, 3])
  380. return quat_affine(quaternion, translation, unstack_inputs=True)
  381. def pre_compose(quaternion, rotation, translation, update):
  382. """Return a new QuatAffine which applies the transformation update first.
  383. Args:
  384. update: Length-6 vector. 3-vector of x, y, and z such that the quaternion
  385. update is (1, x, y, z) and zero for the 3-vector is the identity
  386. quaternion. 3-vector for translation concatenated.
  387. Returns:
  388. New QuatAffine object.
  389. """
  390. vector_quaternion_update, x, y, z = mnp.split(update, [3, 4, 5], axis=-1)
  391. trans_update = [mnp.squeeze(x, axis=-1), mnp.squeeze(y, axis=-1), mnp.squeeze(z, axis=-1)]
  392. new_quaternion = (quaternion + quat_multiply_by_vec(quaternion, vector_quaternion_update))
  393. trans_update = apply_rot_to_vec(rotation, trans_update)
  394. new_translation = [translation[0] + trans_update[0],
  395. translation[1] + trans_update[1],
  396. translation[2] + trans_update[2]]
  397. return quat_affine(new_quaternion, mnp.stack(new_translation, axis=-1))
  398. def scale_translation(quaternion, translation, rotation, position_scale):
  399. """Return a new quat affine with a different scale for translation."""
  400. return quat_affine(quaternion,
  401. mnp.stack([translation[0] * position_scale, translation[1] * position_scale,
  402. translation[2] * position_scale], axis=-1),
  403. rotation=rotation,
  404. normalize=False)
  405. def rigids_from_tensor4x4(m):
  406. """Construct Rigids object from an 4x4 array.
  407. Here the 4x4 is representing the transformation in homogeneous coordinates.
  408. Args:
  409. m: Array representing transformations in homogeneous coordinates.
  410. Returns:
  411. Rigids object corresponding to transformations m
  412. """
  413. return m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], m[..., 2, 0], \
  414. m[..., 2, 1], m[..., 2, 2], m[..., 0, 3], m[..., 1, 3], m[..., 2, 3]
  415. def apply_to_point(rotation, translation, point):
  416. """apply to point func"""
  417. r00 = mnp.expand_dims(rotation[0][0], axis=-1)
  418. r01 = mnp.expand_dims(rotation[0][1], axis=-1)
  419. r02 = mnp.expand_dims(rotation[0][2], axis=-1)
  420. r10 = mnp.expand_dims(rotation[1][0], axis=-1)
  421. r11 = mnp.expand_dims(rotation[1][1], axis=-1)
  422. r12 = mnp.expand_dims(rotation[1][2], axis=-1)
  423. r20 = mnp.expand_dims(rotation[2][0], axis=-1)
  424. r21 = mnp.expand_dims(rotation[2][1], axis=-1)
  425. r22 = mnp.expand_dims(rotation[2][2], axis=-1)
  426. t0 = mnp.expand_dims(translation[0], axis=-1)
  427. t1 = mnp.expand_dims(translation[1], axis=-1)
  428. t2 = mnp.expand_dims(translation[2], axis=-1)
  429. p0 = point[0]
  430. p1 = point[1]
  431. p2 = point[2]
  432. rot_point = [r00 * p0 + r01 * p1 + r02 * p2,
  433. r10 * p0 + r11 * p1 + r12 * p2,
  434. r20 * p0 + r21 * p1 + r22 * p2]
  435. result = [rot_point[0] + t0,
  436. rot_point[1] + t1,
  437. rot_point[2] + t2]
  438. return result
  439. def frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, restype_atom14_to_rigid_group,
  440. restype_atom14_rigid_group_positions, restype_atom14_mask): # (N, 14)
  441. """Put atom literature positions (atom14 encoding) in each rigid group.
  442. Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11
  443. Args:
  444. aatype: aatype for each residue.
  445. all_frames_to_global: All per residue coordinate frames.
  446. Returns:
  447. Positions of all atom coordinates in global frame.
  448. """
  449. # Pick the appropriate transform for every atom.
  450. residx_to_group_idx = P.Gather()(restype_atom14_to_rigid_group, aatype, 0)
  451. group_mask = nn.OneHot(depth=8, axis=-1)(residx_to_group_idx)
  452. # # r3.Rigids with shape (N, 14)
  453. map_atoms_to_global = map_atoms_to_global_func(all_frames_to_global, group_mask)
  454. # Gather the literature atom positions for each residue.
  455. # r3.Vecs with shape (N, 14)
  456. lit_positions = vecs_from_tensor(P.Gather()(restype_atom14_rigid_group_positions, aatype, 0))
  457. # Transform each atom from its local frame to the global frame.
  458. # r3.Vecs with shape (N, 14)
  459. pred_positions = rigids_mul_vecs(map_atoms_to_global, lit_positions)
  460. # Mask out non-existing atoms.
  461. mask = P.Gather()(restype_atom14_mask, aatype, 0)
  462. pred_positions = pred_map_mul(pred_positions, mask)
  463. return pred_positions
  464. def pred_map_mul(pred_positions, mask):
  465. return [pred_positions[0] * mask,
  466. pred_positions[1] * mask,
  467. pred_positions[2] * mask]
  468. def rots_mul_vecs(m, v):
  469. """Apply rotations 'm' to vectors 'v'."""
  470. return [m[0] * v[0] + m[1] * v[1] + m[2] * v[2],
  471. m[3] * v[0] + m[4] * v[1] + m[5] * v[2],
  472. m[6] * v[0] + m[7] * v[1] + m[8] * v[2]]
  473. def rigids_mul_vecs(r, v):
  474. """Apply rigid transforms 'r' to points 'v'."""
  475. rots = rots_mul_vecs(r, v)
  476. vecs_add_r = [rots[0] + r[9],
  477. rots[1] + r[10],
  478. rots[2] + r[11]]
  479. return vecs_add_r
  480. def vecs_from_tensor(x): # shape (...)
  481. """Converts from tensor of shape (3,) to Vecs."""
  482. # num_components = x.shape[-1]
  483. # assert num_components == 3
  484. return x[..., 0], x[..., 1], x[..., 2]
  485. def get_exp_atom_pos(atom_pos):
  486. return [mnp.expand_dims(atom_pos[0], axis=0),
  487. mnp.expand_dims(atom_pos[1], axis=0),
  488. mnp.expand_dims(atom_pos[2], axis=0)
  489. ]
  490. def to_tensor_new(quaternion, translation):
  491. tr_new = [mnp.expand_dims(translation[0], axis=-1),
  492. mnp.expand_dims(translation[1], axis=-1),
  493. mnp.expand_dims(translation[2], axis=-1)]
  494. return mnp.concatenate([quaternion, tr_new[0], tr_new[1], tr_new[2]], axis=-1)
  495. def quat_multiply_by_vec(quat, vec):
  496. """Multiply a quaternion by a pure-vector quaternion."""
  497. return mnp.sum(residue_constants.QUAT_MULTIPLY_BY_VEC * quat[..., :, None, None] * vec[..., None, :, None],
  498. axis=(-3, -2))
  499. def rigids_mul_rots(xx, xy, xz, yx, yy, yz, zx, zy, zz, ones, zeros, cos_angles, sin_angles):
  500. """Compose rigid transformations 'r' with rotations 'm'."""
  501. c00 = xx * ones + xy * zeros + xz * zeros
  502. c01 = yx * ones + yy * zeros + yz * zeros
  503. c02 = zx * ones + zy * zeros + zz * zeros
  504. c10 = xx * zeros + xy * cos_angles + xz * sin_angles
  505. c11 = yx * zeros + yy * cos_angles + yz * sin_angles
  506. c12 = zx * zeros + zy * cos_angles + zz * sin_angles
  507. c20 = xx * zeros + xy * (-sin_angles) + xz * cos_angles
  508. c21 = yx * zeros + yy * (-sin_angles) + yz * cos_angles
  509. c22 = zx * zeros + zy * (-sin_angles) + zz * cos_angles
  510. return c00, c10, c20, c01, c11, c21, c02, c12, c22
  511. def rigids_mul_rigids(a, b):
  512. """Group composition of Rigids 'a' and 'b'."""
  513. c00 = a[0] * b[0] + a[1] * b[3] + a[2] * b[6]
  514. c01 = a[3] * b[0] + a[4] * b[3] + a[5] * b[6]
  515. c02 = a[6] * b[0] + a[7] * b[3] + a[8] * b[6]
  516. c10 = a[0] * b[1] + a[1] * b[4] + a[2] * b[7]
  517. c11 = a[3] * b[1] + a[4] * b[4] + a[5] * b[7]
  518. c12 = a[6] * b[1] + a[7] * b[4] + a[8] * b[7]
  519. c20 = a[0] * b[2] + a[1] * b[5] + a[2] * b[8]
  520. c21 = a[3] * b[2] + a[4] * b[5] + a[5] * b[8]
  521. c22 = a[6] * b[2] + a[7] * b[5] + a[8] * b[8]
  522. tr0 = a[0] * b[9] + a[1] * b[10] + a[2] * b[11]
  523. tr1 = a[3] * b[9] + a[4] * b[10] + a[5] * b[11]
  524. tr2 = a[6] * b[9] + a[7] * b[10] + a[8] * b[11]
  525. new_tr0 = a[9] + tr0
  526. new_tr1 = a[10] + tr1
  527. new_tr2 = a[11] + tr2
  528. return [c00, c10, c20, c01, c11, c21, c02, c12, c22, new_tr0, new_tr1, new_tr2]
  529. def rigits_concate_all(xall, x5, x6, x7):
  530. return [mnp.concatenate([xall[0][:, 0:5], x5[0][:, None], x6[0][:, None], x7[0][:, None]], axis=-1),
  531. mnp.concatenate([xall[1][:, 0:5], x5[1][:, None], x6[1][:, None], x7[1][:, None]], axis=-1),
  532. mnp.concatenate([xall[2][:, 0:5], x5[2][:, None], x6[2][:, None], x7[2][:, None]], axis=-1),
  533. mnp.concatenate([xall[3][:, 0:5], x5[3][:, None], x6[3][:, None], x7[3][:, None]], axis=-1),
  534. mnp.concatenate([xall[4][:, 0:5], x5[4][:, None], x6[4][:, None], x7[4][:, None]], axis=-1),
  535. mnp.concatenate([xall[5][:, 0:5], x5[5][:, None], x6[5][:, None], x7[5][:, None]], axis=-1),
  536. mnp.concatenate([xall[6][:, 0:5], x5[6][:, None], x6[6][:, None], x7[6][:, None]], axis=-1),
  537. mnp.concatenate([xall[7][:, 0:5], x5[7][:, None], x6[7][:, None], x7[7][:, None]], axis=-1),
  538. mnp.concatenate([xall[8][:, 0:5], x5[8][:, None], x6[8][:, None], x7[8][:, None]], axis=-1),
  539. mnp.concatenate([xall[9][:, 0:5], x5[9][:, None], x6[9][:, None], x7[9][:, None]], axis=-1),
  540. mnp.concatenate([xall[10][:, 0:5], x5[10][:, None], x6[10][:, None], x7[10][:, None]], axis=-1),
  541. mnp.concatenate([xall[11][:, 0:5], x5[11][:, None], x6[11][:, None], x7[11][:, None]], axis=-1)
  542. ]
  543. def reshape_back(backb):
  544. return [backb[0][:, None],
  545. backb[1][:, None],
  546. backb[2][:, None],
  547. backb[3][:, None],
  548. backb[4][:, None],
  549. backb[5][:, None],
  550. backb[6][:, None],
  551. backb[7][:, None],
  552. backb[8][:, None],
  553. backb[9][:, None],
  554. backb[10][:, None],
  555. backb[11][:, None]
  556. ]
  557. def l2_normalize(x, axis=-1):
  558. return x / mnp.sqrt(mnp.sum(x ** 2, axis=axis, keepdims=True))
def torsion_angles_to_frames(aatype, backb_to_global, torsion_angles_sin_cos, restype_rigid_group_default_frame):
    """Compute rigid group frames from torsion angles.

    Args:
        aatype: residue type indices; ``aatype.shape`` is unpacked as
            ``(num_residues,)``, so this must be rank 1.
        backb_to_global: 12-element list of backbone-to-global rigid components
            (nine rotation entries then three translation entries); every entry
            is given a trailing axis by ``reshape_back``.
        torsion_angles_sin_cos: sin/cos of the torsion angles; ``[..., 0]`` is
            read as sin and ``[..., 1]`` as cos. Presumably shaped
            (num_residues, 7, 2) — TODO confirm against the caller.
        restype_rigid_group_default_frame: per-residue-type default frames,
            gathered along axis 0 with ``aatype``.

    Returns:
        12-element list of rigid components mapping every rigid group to the
        global frame (shape (N, 8) per the recombination comment below).
    """
    # Gather the default frames for all rigid groups.
    m = P.Gather()(restype_rigid_group_default_frame, aatype, 0)
    # rigids_from_tensor4x4 is defined elsewhere; it is unpacked into nine
    # rotation entries followed by three translation entries.
    xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1, x1, y1, z1 = rigids_from_tensor4x4(m)
    # Create the rotation matrices according to the given angles (each frame is
    # defined such that its rotation is around the x-axis).
    sin_angles = torsion_angles_sin_cos[..., 0]
    cos_angles = torsion_angles_sin_cos[..., 1]
    # Insert a zero rotation (sin=0, cos=1) for the backbone group as column 0.
    num_residues, = aatype.shape
    sin_angles = mnp.concatenate([mnp.zeros([num_residues, 1]), sin_angles], axis=-1)
    cos_angles = mnp.concatenate([mnp.ones([num_residues, 1]), cos_angles], axis=-1)
    zeros = mnp.zeros_like(sin_angles)
    ones = mnp.ones_like(sin_angles)
    # Apply rotations to the frames (translations x1, y1, z1 are unchanged).
    xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2 = rigids_mul_rots(xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1,
                                                                  ones, zeros, cos_angles, sin_angles)
    all_frames = [xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2, x1, y1, z1]
    # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
    # the previous frame. So chain them up accordingly.
    chi2_frame_to_frame = [xx2[:, 5], xy2[:, 5], xz2[:, 5], yx2[:, 5], yy2[:, 5], yz2[:, 5], zx2[:, 5], zy2[:, 5],
                           zz2[:, 5], x1[:, 5], y1[:, 5], z1[:, 5]]
    chi3_frame_to_frame = [xx2[:, 6], xy2[:, 6], xz2[:, 6], yx2[:, 6], yy2[:, 6], yz2[:, 6], zx2[:, 6], zy2[:, 6],
                           zz2[:, 6], x1[:, 6], y1[:, 6], z1[:, 6]]
    chi4_frame_to_frame = [xx2[:, 7], xy2[:, 7], xz2[:, 7], yx2[:, 7], yy2[:, 7], yz2[:, 7], zx2[:, 7], zy2[:, 7],
                           zz2[:, 7], x1[:, 7], y1[:, 7], z1[:, 7]]
    # Group 4 (chi1) already maps directly to the backbone frame; chain the
    # later chi frames onto it one at a time.
    chi1_frame_to_backb = [xx2[:, 4], xy2[:, 4], xz2[:, 4], yx2[:, 4], yy2[:, 4], yz2[:, 4], zx2[:, 4], zy2[:, 4],
                           zz2[:, 4], x1[:, 4], y1[:, 4], z1[:, 4]]
    chi2_frame_to_backb = rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame)
    chi3_frame_to_backb = rigids_mul_rigids(chi2_frame_to_backb, chi3_frame_to_frame)
    chi4_frame_to_backb = rigids_mul_rigids(chi3_frame_to_backb, chi4_frame_to_frame)
    # Recombine them to a r3.Rigids with shape (N, 8): groups 0-4 from
    # all_frames, groups 5-7 replaced by the chained chi2..chi4 frames.
    all_frames_to_backb = rigits_concate_all(all_frames, chi2_frame_to_backb,
                                             chi3_frame_to_backb, chi4_frame_to_backb)
    backb_to_global_new = reshape_back(backb_to_global)
    # Create the global frames: backbone-to-global composed with every
    # group-to-backbone frame, shape (N, 8).
    all_frames_to_global = rigids_mul_rigids(backb_to_global_new, all_frames_to_backb)
    return all_frames_to_global
  601. def map_atoms_to_global_func(all_frames, group_mask):
  602. return [mnp.sum(all_frames[0][:, None, :] * group_mask, axis=-1),
  603. mnp.sum(all_frames[1][:, None, :] * group_mask, axis=-1),
  604. mnp.sum(all_frames[2][:, None, :] * group_mask, axis=-1),
  605. mnp.sum(all_frames[3][:, None, :] * group_mask, axis=-1),
  606. mnp.sum(all_frames[4][:, None, :] * group_mask, axis=-1),
  607. mnp.sum(all_frames[5][:, None, :] * group_mask, axis=-1),
  608. mnp.sum(all_frames[6][:, None, :] * group_mask, axis=-1),
  609. mnp.sum(all_frames[7][:, None, :] * group_mask, axis=-1),
  610. mnp.sum(all_frames[8][:, None, :] * group_mask, axis=-1),
  611. mnp.sum(all_frames[9][:, None, :] * group_mask, axis=-1),
  612. mnp.sum(all_frames[10][:, None, :] * group_mask, axis=-1),
  613. mnp.sum(all_frames[11][:, None, :] * group_mask, axis=-1)
  614. ]
  615. def get_exp_frames(frames):
  616. return [mnp.expand_dims(frames[0], axis=0),
  617. mnp.expand_dims(frames[1], axis=0),
  618. mnp.expand_dims(frames[2], axis=0),
  619. mnp.expand_dims(frames[3], axis=0),
  620. mnp.expand_dims(frames[4], axis=0),
  621. mnp.expand_dims(frames[5], axis=0),
  622. mnp.expand_dims(frames[6], axis=0),
  623. mnp.expand_dims(frames[7], axis=0),
  624. mnp.expand_dims(frames[8], axis=0),
  625. mnp.expand_dims(frames[9], axis=0),
  626. mnp.expand_dims(frames[10], axis=0),
  627. mnp.expand_dims(frames[11], axis=0)
  628. ]
  629. def vecs_to_tensor(v):
  630. """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
  631. return mnp.stack([v[0], v[1], v[2]], axis=-1)
  632. def atom14_to_atom37(atom14_data, residx_atom37_to_atom14, atom37_atom_exists, indices0):
  633. """Convert atom14 to atom37 representation."""
  634. seq_length = atom14_data.shape[0]
  635. residx_atom37_to_atom14 = residx_atom37_to_atom14.reshape((seq_length, 37, 1))
  636. new_indices = P.Concat(2)((indices0, residx_atom37_to_atom14))
  637. atom37_data = P.GatherNd()(atom14_data, new_indices)
  638. # atom37_data = P.GatherBatch()(atom14_data, residx_atom37_to_atom14)
  639. if len(atom14_data.shape) == 2:
  640. atom37_data *= atom37_atom_exists
  641. elif len(atom14_data.shape) == 3:
  642. atom37_data *= atom37_atom_exists[:, :, None].astype(atom37_data.dtype)
  643. return atom37_data
  644. def batch_apply_rot_to_vec(rot, vec, unstack=False):
  645. """Multiply rotation matrix by a vector."""
  646. if unstack:
  647. x, y, z = vec[:, :, 0], vec[:, :, 1], vec[:, :, 2]
  648. else:
  649. x, y, z = vec
  650. return [(rot[:, 0, 0, :] * x + rot[:, 0, 1, :] * y + rot[:, 0, 2, :] * z)[:, None, :],
  651. (rot[:, 1, 0, :] * x + rot[:, 1, 1, :] * y + rot[:, 1, 2, :] * z)[:, None, :],
  652. (rot[:, 2, 0, :] * x + rot[:, 2, 1, :] * y + rot[:, 2, 2, :] * z)[:, None, :]]
  653. def _batch_multiply(a, b):
  654. """ batch multiply operation"""
  655. x1 = mnp.concatenate(
  656. [(a[:, 0, 0, :] * b[:, 0, 0, :] + a[:, 0, 1, :] * b[:, 1, 0, :] + a[:, 0, 2, :] * b[:, 2, 0, :])[:, None, :],
  657. (a[:, 0, 0, :] * b[:, 0, 1, :] + a[:, 0, 1, :] * b[:, 1, 1, :] + a[:, 0, 2, :] * b[:, 2, 1, :])[:, None, :],
  658. (a[:, 0, 0, :] * b[:, 0, 2, :] + a[:, 0, 1, :] * b[:, 1, 2, :] + a[:, 0, 2, :] * b[:, 2, 2, :])[:, None, :]],
  659. axis=1)[:, None, :, :]
  660. x2 = mnp.concatenate(
  661. [(a[:, 1, 0, :] * b[:, 0, 0, :] + a[:, 1, 1, :] * b[:, 1, 0, :] + a[:, 1, 2, :] * b[:, 2, 0, :])[:, None, :],
  662. (a[:, 1, 0, :] * b[:, 0, 1, :] + a[:, 1, 1, :] * b[:, 1, 1, :] + a[:, 1, 2, :] * b[:, 2, 1, :])[:, None, :],
  663. (a[:, 1, 0, :] * b[:, 0, 2, :] + a[:, 1, 1, :] * b[:, 1, 2, :] + a[:, 1, 2, :] * b[:, 2, 2, :])[:, None, :]],
  664. axis=1)[:, None, :, :]
  665. x3 = mnp.concatenate(
  666. [(a[:, 2, 0, :] * b[:, 0, 0, :] + a[:, 2, 1, :] * b[:, 1, 0, :] + a[:, 2, 2, :] * b[:, 2, 0, :])[:, None, :],
  667. (a[:, 2, 0, :] * b[:, 0, 1, :] + a[:, 2, 1, :] * b[:, 1, 1, :] + a[:, 2, 2, :] * b[:, 2, 1, :])[:, None, :],
  668. (a[:, 2, 0, :] * b[:, 0, 2, :] + a[:, 2, 1, :] * b[:, 1, 2, :] + a[:, 2, 2, :] * b[:, 2, 2, :])[:, None, :]],
  669. axis=1)[:, None, :, :]
  670. return mnp.concatenate([x1, x2, x3], axis=1)
def batch_make_canonical_transform(n_xyz, ca_xyz, c_xyz):
    """Returns translation and rotation matrices to canonicalize residue atoms.

    Batched variant: coordinates are read as ``xyz[:, :, k]``, so every input
    is rank 3 with xyz on the last axis (presumably (batch, num_res, 3) —
    TODO confirm against callers).

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
        n_xyz: nitrogen xyz coordinates (xyz on the last axis).
        ca_xyz: carbon alpha xyz coordinates (xyz on the last axis).
        c_xyz: carbon xyz coordinates (xyz on the last axis).

    Returns:
        A tuple (translation, rotation) where:
          translation is ``-ca_xyz``.
          rotation is the combined canonicalizing rotation. Internally the
          matrices are laid out as (batch, row, col, num_res); the final
          ``[0, 3, 1, 2]`` transpose returns them as (batch, num_res, row, col).

    After applying the translation and rotation to all atoms in a residue:
      * All atoms will be shifted so that CA is at the origin,
      * All atoms will be rotated so that C is at the x-axis,
      * All atoms will be rotated so that N is in the xy plane.
    """
    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation
    # Place C on the x-axis.
    c_x, c_y, c_z = c_xyz[:, :, 0], c_xyz[:, :, 1], c_xyz[:, :, 2]
    # Rotate by angle c1 in the x-y plane (around the z-axis).
    # The 1e-20 term keeps the denominators non-zero for degenerate inputs.
    sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
    cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
    # NOTE(review): zeros/ones are forced to float32 while sin/cos keep the
    # input dtype; for non-float32 inputs the concatenations below may upcast
    # or fail — confirm the expected input dtype.
    zeros = mnp.zeros_like(sin_c1).astype("float32")
    ones = mnp.ones_like(sin_c1).astype("float32")
    # Each inner concatenate builds one matrix row; rows are stacked on axis 1.
    c1_rot_matrix = mnp.concatenate(
        [mnp.concatenate((cos_c1[:, None, ...], (-sin_c1)[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
         mnp.concatenate((sin_c1[:, None, ...], cos_c1[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
         mnp.concatenate((zeros[:, None, ...], zeros[:, None, ...], ones[:, None, ...]), axis=1)[:, None, :, :]],
        axis=1)
    # Rotate by angle c2 in the x-z plane (around the y-axis).
    sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
    cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
    c2_rot_matrix = mnp.concatenate(
        [mnp.concatenate((cos_c2[:, None, ...], zeros[:, None, ...], sin_c2[:, None, ...]), axis=1)[:, None, :, :],
         mnp.concatenate((zeros[:, None, ...], ones[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
         mnp.concatenate(((-sin_c2)[:, None, ...], zeros[:, None, ...], cos_c2[:, None, ...]), axis=1)[:, None, :, :]],
        axis=1)
    c_rot_matrix = _batch_multiply(c2_rot_matrix, c1_rot_matrix)
    # Rotate N by the combined C rotation; the three (batch, 1, num_res)
    # components are concatenated on axis 1 and transposed back to xyz-last.
    n_xyz = mnp.transpose(mnp.concatenate(batch_apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True), axis=1), (0, 2, 1))
    # Place N in the x-y plane.
    _, n_y, n_z = n_xyz[:, :, 0], n_xyz[:, :, 1], n_xyz[:, :, 2]
    # Rotate by angle alpha in the y-z plane (around the x-axis).
    sin_n = -n_z / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
    cos_n = n_y / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
    n_rot_matrix = mnp.concatenate(
        [mnp.concatenate([ones[:, None, ...], zeros[:, None, ...], zeros[:, None, ...]], axis=1)[:, None, :, :],
         mnp.concatenate([zeros[:, None, ...], cos_n[:, None, ...], (-sin_n)[:, None, ...]], axis=1)[:, None, :, :],
         mnp.concatenate([zeros[:, None, ...], sin_n[:, None, ...], cos_n[:, None, ...]], axis=1)[:, None, :, :]],
        axis=1)
    return translation, mnp.transpose(_batch_multiply(n_rot_matrix, c_rot_matrix), [0, 3, 1, 2])
  728. def batch_make_transform_from_reference(n_xyz, ca_xyz, c_xyz):
  729. """Returns rotation and translation matrices to convert from reference.
  730. Note that this method does not take care of symmetries. If you provide the
  731. atom positions in the non-standard way, the N atom will end up not at
  732. [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
  733. need to take care of such cases in your code.
  734. Args:
  735. n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
  736. ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
  737. c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
  738. Returns:
  739. A tuple (rotation, translation) where:
  740. rotation is an array of shape [batch, 3, 3] defining the rotation.
  741. translation is an array of shape [batch, 3] defining the translation.
  742. After applying the translation and rotation to the reference backbone,
  743. the coordinates will approximately equal to the input coordinates.
  744. The order of translation and rotation differs from make_canonical_transform
  745. because the rotation from this function should be applied before the
  746. translation, unlike make_canonical_transform.
  747. """
  748. translation, rotation = batch_make_canonical_transform(n_xyz, ca_xyz, c_xyz)
  749. return mnp.transpose(rotation, (0, 1, 3, 2)), -translation
  750. def batch_rot_to_quat(rot, unstack_inputs=False):
  751. """Convert rotation matrix to quaternion.
  752. Note that this function calls self_adjoint_eig which is extremely expensive on
  753. the GPU. If at all possible, this function should run on the CPU.
  754. Args:
  755. rot: rotation matrix (see below for format).
  756. unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
  757. otherwise the rotation matrix should be a list of lists of tensors.
  758. Returns:
  759. Quaternion as (..., 4) tensor.
  760. """
  761. if unstack_inputs:
  762. rot = mnp.transpose(rot, [0, 3, 2, 1])
  763. xx, xy, xz = rot[:, 0, 0, :], rot[:, 0, 1, :], rot[:, 0, 2, :]
  764. yx, yy, yz = rot[:, 1, 0, :], rot[:, 1, 1, :], rot[:, 1, 2, :]
  765. zx, zy, zz = rot[:, 2, 0, :], rot[:, 2, 1, :], rot[:, 2, 2, :]
  766. k = mnp.stack((mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1),
  767. mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1),
  768. mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1),
  769. mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)), axis=-2)
  770. k = (1. / 3.) * k
  771. k = k[:, :, :, 0]
  772. return k
  773. def batch_quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False):
  774. if unstack_inputs:
  775. if rotation is not None:
  776. rotation = mnp.transpose(rotation, [0, 3, 2, 1])
  777. translation = mnp.moveaxis(translation, -1, 1) # Unstack.
  778. if normalize and quaternion is not None:
  779. quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True)
  780. return quaternion, rotation, translation
  781. def batch_apply_inverse_rot_to_vec(rot, vec):
  782. """Multiply the inverse of a rotation matrix by a vector."""
  783. # Inverse rotation is just transpose
  784. return mnp.concatenate(
  785. ((rot[:, 0, 0, :] * vec[:, 0] + rot[:, 1, 0, :] * vec[:, 1] + rot[:, 2, 0, :] * vec[:, 2])[:, None, ...],
  786. (rot[:, 0, 1, :] * vec[:, 0] + rot[:, 1, 1, :] * vec[:, 1] + rot[:, 2, 1, :] * vec[:, 2])[:, None, ...],
  787. (rot[:, 0, 2, :] * vec[:, 0] + rot[:, 1, 2, :] * vec[:, 1] + rot[:, 2, 2, :] * vec[:, 2])[:, None, ...]),
  788. axis=1)
  789. def batch_invert_point(transformed_point, rotation, translation, extra_dims=0):
  790. """Apply inverse of transformation to a point.
  791. Args:
  792. transformed_point: List of 3 tensors to apply affine
  793. extra_dims: Number of dimensions at the end of the transformed_point
  794. shape that are not present in the rotation and translation. The most
  795. common use is rotation N points at once with extra_dims=1 for use in a
  796. network.
  797. Returns:
  798. Transformed point after applying affine.
  799. """
  800. for _ in range(extra_dims):
  801. rotation = mnp.expand_dims(rotation, axis=-1)
  802. translation = mnp.expand_dims(translation, axis=-1)
  803. rot_point = transformed_point - translation
  804. return batch_apply_inverse_rot_to_vec(rotation, rot_point)
  805. def compute_confidence(predicted_lddt_logits):
  806. """compute confidence"""
  807. num_bins = predicted_lddt_logits.shape[-1]
  808. bin_width = 1 / num_bins
  809. start_n = bin_width / 2
  810. plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width)
  811. confidence = np.mean(plddt)
  812. return confidence
  813. def compute_plddt(logits, start_n, bin_width):
  814. """Computes per-residue pLDDT from logits.
  815. Args:
  816. logits: [num_res, num_bins] output from the PredictedLDDTHead.
  817. Returns:
  818. plddt: [num_res] per-residue pLDDT.
  819. """
  820. bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width)
  821. probs = softmax(logits, axis=-1)
  822. predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
  823. return predicted_lddt_ca * 100