You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

structure_module.py 22 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """structure module"""
  16. import numpy as np
  17. import mindspore.ops as ops
  18. import mindspore.common.dtype as mstype
  19. import mindspore.numpy as mnp
  20. from mindspore import Parameter, ms_function, Tensor
  21. from mindspore import nn
  22. from commons import residue_constants
  23. from commons.utils import generate_new_affine, to_tensor, from_tensor, vecs_to_tensor, atom14_to_atom37, \
  24. get_exp_atom_pos, get_exp_frames, pre_compose, scale_translation, to_tensor_new, l2_normalize, \
  25. torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, apply_to_point, _invert_point
  class InvariantPointAttention(nn.Cell):
      """Invariant Point Attention (IPA) module.

      The attention logits are the sum of three terms: a scalar query/key
      product, a weighted squared distance between query and key points that
      have been passed through `apply_to_point` with the supplied
      rotation/translation, and a bias projected from the 2D input. The attended
      scalar values, point values (mapped back with `_invert_point`), point
      norms and 2D features are concatenated and linearly projected.
      """

      def __init__(self, config, global_config, pair_dim):
          """Initialize.

          Args:
              config: Structure Module Config. Attributes read here: num_head,
                  num_scalar_qk, num_scalar_v, num_point_v, num_point_qk and
                  num_channel.
              global_config: Global Config of Model (only stored on the cell).
              pair_dim: pair representation dimension.
          """
          super().__init__()
          self._dist_epsilon = 1e-8
          self.config = config
          self.num_head = config.num_head
          self.num_scalar_qk = config.num_scalar_qk
          self.num_scalar_v = config.num_scalar_v
          self.num_point_v = config.num_point_v
          self.num_point_qk = config.num_point_qk
          self.num_channel = config.num_channel
          # Width of the vector fed to `output_projection`: scalar values, point
          # values (three coordinates plus their norm, hence the factor 4) and
          # the attention-weighted 2D features.
          self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \
              self.num_head * pair_dim
          self.global_config = global_config
          self.q_scalar = nn.Dense(config.num_channel, self.num_head * self.num_scalar_qk).to_float(mstype.float16)
          self.kv_scalar = nn.Dense(config.num_channel, self.num_head * (self.num_scalar_qk + self.num_scalar_v)
                                    ).to_float(mstype.float16)
          self.q_point_local = nn.Dense(config.num_channel, self.num_head * 3 * self.num_point_qk
                                        ).to_float(mstype.float16)
          self.kv_point_local = nn.Dense(config.num_channel,
                                         self.num_head * 3 * (self.num_point_qk + self.num_point_v)
                                         ).to_float(mstype.float16)
          self.soft_max = nn.Softmax()
          self.soft_plus = ops.Softplus()
          # One learnable weight per attention head. The length must equal
          # `num_head` because `construct` broadcasts it against the head axis of
          # the squared distances.
          self.trainable_point_weights = Parameter(Tensor(np.ones((self.num_head,)), mstype.float32),
                                                   name="trainable_point_weights")
          self.attention_2d = nn.Dense(pair_dim, self.num_head).to_float(mstype.float16)
          self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros'
                                            ).to_float(mstype.float16)
          # Scale each of the three logit terms so that they contribute equally.
          # The variances follow the reference AlphaFold IPA formulation; with
          # num_scalar_qk == 16 and num_point_qk == 4 they evaluate to 16 and 18.
          num_logit_terms = 3
          scalar_variance = max(self.num_scalar_qk, 1) * 1.0
          point_variance = max(self.num_point_qk, 1) * 9.0 / 2
          self.scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
          self.point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
          self.attention_2d_weights = np.sqrt(1.0 / num_logit_terms)

      def construct(self, inputs_1d, inputs_2d, mask, rotation, translation):
          """Compute geometry-aware attention.

          Args:
              inputs_1d: (N, C) 1D input embedding that is the basis for the
                  scalar queries.
              inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
              mask: (N, 1) mask to indicate which elements of inputs_1d participate
                  in the attention.
              rotation: describe the orientation of every element in inputs_1d
              translation: describe the position of every element in inputs_1d

          Returns:
              Transformation of the input embedding.
          """
          num_residues, _ = inputs_1d.shape

          # Local aliases keep the expressions below short.
          num_head = self.num_head
          num_scalar_qk = self.num_scalar_qk
          num_point_qk = self.num_point_qk
          num_scalar_v = self.num_scalar_v
          num_point_v = self.num_point_v

          # Scalar queries: [num_residues, num_head, num_scalar_qk].
          q_scalar = self.q_scalar(inputs_1d)
          q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])

          # Scalar keys and values share one projection,
          # [num_residues, num_head, num_scalar_v + num_scalar_qk], and are split
          # on the last axis: the first num_scalar_qk channels are the keys.
          kv_scalar = self.kv_scalar(inputs_1d)
          kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
          k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1)

          # Query points: project, split the last axis into three coordinate
          # components, then hand them to `apply_to_point` together with the frames.
          q_point_local = self.q_point_local(inputs_1d)
          q_point_local = mnp.stack(mnp.split(q_point_local, 3, axis=-1), axis=0)
          q_point_global = apply_to_point(rotation, translation, q_point_local)
          # One [num_residues, num_head, num_point_qk] tensor per coordinate.
          q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk))
          q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk))
          q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk))

          # Key and value points are produced by one projection and treated the
          # same way; each coordinate is reshaped to
          # [num_residues, num_head, num_point_qk + num_point_v].
          kv_point_local = self.kv_point_local(inputs_1d)
          kv_point_local = mnp.split(kv_point_local, 3, axis=-1)
          kv_point_global = apply_to_point(rotation, translation, kv_point_local)
          kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v)))
          kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v)))
          kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v)))

          # Split into key points (first num_point_qk) and value points (rest).
          k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk,], axis=-1)
          k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk,], axis=-1)
          k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk,], axis=-1)

          # Softplus keeps the learnable per-head point weights positive.
          trainable_point_weights = self.soft_plus(self.trainable_point_weights)
          point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1)

          # Move the head axis to the front: [num_head, num_residues, num_points].
          v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)]
          q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)]
          k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)]

          # Squared query/key point distances, summed over the three coordinates:
          # [num_head, num_query_residues, num_target_residues, num_point_qk].
          dist2 = mnp.square(q_point[0][:, :, None, :] - k_point[0][:, None, :, :]) + \
              mnp.square(q_point[1][:, :, None, :] - k_point[1][:, None, :, :]) + \
              mnp.square(q_point[2][:, :, None, :] - k_point[2][:, None, :, :])
          attn_qk_point = -0.5 * mnp.sum(point_weights[:, None, None, :] * dist2, axis=-1)

          # Scalar attention term, with heads leading.
          v = mnp.swapaxes(v_scalar, -2, -3)
          q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3)
          k = mnp.swapaxes(k_scalar, -2, -3)
          attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1))
          attn_logits = attn_qk_scalar + attn_qk_point

          # Bias from the 2D input; the transpose moves the head axis first.
          attention_2d = self.attention_2d(inputs_2d)
          attention_2d = mnp.transpose(attention_2d, [2, 0, 1])
          attention_2d = self.attention_2d_weights * attention_2d
          attn_logits += attention_2d

          # Push logits of masked query/target pairs far down before the softmax.
          mask_2d = mask * mnp.swapaxes(mask, -1, -2)
          attn_logits -= 1e5 * (1. - mask_2d)

          # [num_head, num_query_residues, num_target_residues]
          attn = self.soft_max(attn_logits)

          # [num_head, num_query_residues, num_scalar_v]
          result_scalar = ops.matmul(attn, v)

          # Attention-weighted value points per coordinate, swapped back to
          # [num_query_residues, num_head, num_point_v].
          result_point_global = [
              mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3),
              mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3),
              mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3)
          ]
          result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]),
                                 mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]),
                                 mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])]

          result_scalar = mnp.swapaxes(result_scalar, -2, -3)
          result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v])

          # Map the attended points back with the inverse helper.
          result_point_local = _invert_point(result_point_global, rotation, translation)
          output_feature1 = result_scalar
          output_feature20 = result_point_local[0]
          output_feature21 = result_point_local[1]
          output_feature22 = result_point_local[2]
          # Point norms; the epsilon keeps the sqrt away from zero.
          output_feature3 = mnp.sqrt(self._dist_epsilon +
                                     mnp.square(result_point_local[0]) +
                                     mnp.square(result_point_local[1]) +
                                     mnp.square(result_point_local[2]))

          # Attend over the 2D input with the residue axis as the batch axis.
          result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d)
          num_out = num_head * result_attention_over_2d.shape[-1]
          output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out])

          final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21,
                                       output_feature22, output_feature3, output_feature4], axis=-1)
          final_result = self.output_projection(final_act)
          return final_result
  class MultiRigidSidechain(nn.Cell):
      """Class to make side chain atoms."""

      def __init__(self, config, global_config, single_repr_dim):
          """Create the projection, residual and torsion-angle layers.

          Args:
              config: side-chain config; `num_channel` is read from it.
              global_config: Global Config of Model (only stored on the cell).
              single_repr_dim: width of the activations fed to both input
                  projections.
          """
          super().__init__()
          self.config = config
          self.global_config = global_config

          # Separate projections for the current and the initial activations.
          self.input_projection = nn.Dense(
              single_repr_dim, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.input_projection_1 = nn.Dense(
              single_repr_dim, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.relu = nn.ReLU()

          # Two residual blocks; the second layer of each block starts at zero.
          self.resblock1 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.resblock2 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='zeros').to_float(mstype.float16)
          self.resblock1_1 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.resblock2_1 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='zeros').to_float(mstype.float16)

          # 14 outputs per residue, reshaped to 7 pairs in `construct`.
          self.unnormalized_angles = nn.Dense(
              config.num_channel, 14, weight_init='normal').to_float(mstype.float16)
          self.print = ops.Print()

          # Constant lookup tables taken from residue_constants.
          self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
          self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
          self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
          self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)

      def construct(self, rotation, translation, act, initial_act, aatype):
          """Predict side chains using rotation and translation representations.

          Args:
              rotation: rotation components, indexed as rotation[i][j] for
                  i, j in 0..2.
              translation: translation components, indexed as translation[i]
                  for i in 0..2.
              act: current activations from the structure module.
              initial_act: initial activations (input of structure module).
              aatype: Amino acid type representations.

          Returns:
              Tuple of (normalized angles, unnormalized angles, atom positions,
              frames).
          """
          projected_act = self.input_projection(self.relu(act.astype(mstype.float32)))
          projected_init = self.input_projection_1(self.relu(initial_act.astype(mstype.float32)))
          # Summing the two projections is equivalent to concat followed by Linear.
          hidden = projected_act + projected_init

          # First residual block.
          branch = self.resblock1(self.relu(hidden.astype(mstype.float32)))
          branch = self.resblock2(self.relu(branch.astype(mstype.float32)))
          hidden = branch + hidden

          # Second residual block.
          branch = self.resblock1_1(self.relu(hidden.astype(mstype.float32)))
          branch = self.resblock2_1(self.relu(branch.astype(mstype.float32)))
          hidden = branch + hidden

          # Map activations to torsion angles: (num_res, 14) -> (num_res, 7, 2).
          num_res = hidden.shape[0]
          unnormalized_angles = self.unnormalized_angles(self.relu(hidden.astype(mstype.float32)))
          unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
          angles = l2_normalize(unnormalized_angles, axis=-1)

          # Flatten the backbone frame into the 12-entry list the helper takes:
          # nine rotation entries in row-major order, then three translation
          # entries.
          rot_row0 = rotation[0]
          rot_row1 = rotation[1]
          rot_row2 = rotation[2]
          backb_to_global = [
              rot_row0[0], rot_row0[1], rot_row0[2],
              rot_row1[0], rot_row1[1], rot_row1[2],
              rot_row2[0], rot_row2[1], rot_row2[2],
              translation[0], translation[1], translation[2],
          ]

          frames = torsion_angles_to_frames(
              aatype, backb_to_global, angles, self.restype_rigid_group_default_frame)
          atom_pos = frames_and_literature_positions_to_atom14_pos(
              aatype,
              frames,
              self.restype_atom14_to_rigid_group,
              self.restype_atom14_rigid_group_positions,
              self.restype_atom14_mask)
          return angles, unnormalized_angles, atom_pos, frames
  class FoldIteration(nn.Cell):
      """A single iteration of the main structure module loop."""

      def __init__(self, config, global_config, pair_dim, single_repr_dim):
          """Build the attention, transition, affine-update and side-chain cells.

          Args:
              config: Structure Module Config; `num_channel` and `sidechain` are
                  read here.
              global_config: Global Config of Model, forwarded to the sub-cells.
              pair_dim: pair representation dimension.
              single_repr_dim: single representation dimension, used by the
                  side-chain cell.
          """
          super().__init__()
          self.config = config
          self.global_config = global_config
          self.drop_out = nn.Dropout(keep_prob=0.9)
          self.attention_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)
          self.transition_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)

          # Three-layer transition block.
          self.transition = nn.Dense(
              config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.transition_1 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.transition_2 = nn.Dense(
              config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
          self.relu = nn.ReLU()

          # Six-channel affine update, zero-initialised.
          self.affine_update = nn.Dense(config.num_channel, 6, weight_init='zeros').to_float(mstype.float16)
          self.attention_module = InvariantPointAttention(self.config, self.global_config, pair_dim)
          self.mu_side_chain = MultiRigidSidechain(config.sidechain, global_config, single_repr_dim)
          self.print = ops.Print()

      def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype):
          """Run attention, transition, backbone update and side-chain prediction.

          Returns:
              Tuple of (act, quaternion, translation, rotation, affine_output,
              angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames).
          """
          # Attention with residual connection, dropout and layer norm.
          attn_out = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation)
          hidden = act + attn_out
          hidden = self.drop_out(hidden)
          hidden = self.attention_layer_norm(hidden.astype(mstype.float32))

          # Transition with residual connection, dropout and layer norm.
          residual = hidden
          hidden = self.transition(hidden)
          hidden = self.relu(hidden.astype(mstype.float32))
          hidden = self.transition_1(hidden)
          hidden = self.relu(hidden.astype(mstype.float32))
          hidden = self.transition_2(hidden)
          hidden = hidden + residual
          hidden = self.drop_out(hidden)
          hidden = self.transition_layer_norm(hidden.astype(mstype.float32))

          # This block corresponds to
          # Jumper et al. (2021) Alg. 23 "Backbone update"
          update = self.affine_update(hidden)
          quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, update)

          # The side-chain cell receives a frame whose translation has been
          # rescaled by 10.0; the unscaled frame is what gets returned.
          _, scaled_rotation, scaled_translation = scale_translation(quaternion, translation, rotation, 10.0)
          side_chain_out = self.mu_side_chain(scaled_rotation, scaled_translation, hidden, initial_act, aatype)
          angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = side_chain_out

          affine_output = to_tensor_new(quaternion, translation)
          return hidden, quaternion, translation, rotation, affine_output, angles_sin_cos, \
              unnormalized_angles_sin_cos, atom_pos, frames
  class StructureModule(nn.Cell):
      """StructureModule as a network head."""

      def __init__(self, config, single_repr_dim, pair_dim, global_config=None, compute_loss=True):
          """Build the structure module.

          Args:
              config: Structure Module Config; `num_channel` and `num_layer` are
                  read here.
              single_repr_dim: single representation dimension.
              pair_dim: pair representation dimension.
              global_config: Global Config of Model. Required in practice even
                  though it defaults to None, because `seq_length` is read from it.
              compute_loss: stored on the cell; not used by the code in this class.

          Raises:
              ValueError: if `global_config` is None.
          """
          super().__init__()
          if global_config is None:
              # Fail early with a clear message rather than with an
              # AttributeError on `global_config.seq_length` further down.
              raise ValueError("StructureModule requires a global_config that provides seq_length.")
          self.config = config
          self.global_config = global_config
          self.compute_loss = compute_loss
          self.fold_iteration = FoldIteration(self.config, global_config, pair_dim, single_repr_dim)
          self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
          self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel).to_float(mstype.float16)
          self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
          self.num_layer = config.num_layer
          # int32 index tensor of shape (seq_length, 37, 1) in which entry
          # [i, :, 0] equals i; it is handed to `atom14_to_atom37`.
          self.indice0 = Tensor(
              np.arange(global_config.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))

      @ms_function
      def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None):
          """Predict atom positions from the single and pair representations.

          Args:
              single: single representation; layer-normalised and projected here.
              pair: pair representation; layer-normalised here.
              seq_mask: sequence mask; a trailing axis is added before use.
              aatype: Amino acid type representations.
              residx_atom37_to_atom14: passed through to `atom14_to_atom37`.
              atom37_atom_exists: passed through to `atom14_to_atom37` and
                  returned as the final atom mask.
              NOTE(review): the last two default to None but are forwarded
              unchanged; presumably the helper needs real tensors - confirm.

          Returns:
              Tuple of (final atom positions, final atom mask, final activations
              of the fold iterations).
          """
          sequence_mask = seq_mask[:, None]
          act = self.single_layer_norm(single.astype(mstype.float32))
          initial_act = act
          act = self.initial_projection(act)

          # Build the starting affine and round-trip it through the tensor form.
          quaternion, rotation, translation = generate_new_affine(sequence_mask)
          aff_to_tensor = to_tensor(quaternion, mnp.transpose(translation))
          act_2d = self.pair_layer_norm(pair.astype(mstype.float32))

          # folder iteration
          quaternion, rotation, translation = from_tensor(aff_to_tensor)
          act_new, atom_pos, _, _, _, _ = self.iteration_operation(
              act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)

          # Keep only the last entry along the leading axis.
          atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]
          atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
                                                   residx_atom37_to_atom14,
                                                   atom37_atom_exists,
                                                   self.indice0)
          return atom37_pred_positions, atom37_atom_exists, act_new

      def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
                              aatype):
          """Apply the shared fold iteration `num_layer` times.

          The affine outputs and both angle tensors of every iteration are
          collected and concatenated along a new leading axis. Atom positions and
          frames are taken from the last iteration only, so `num_layer` is
          expected to be at least 1.

          Returns:
              Tuple of (act, atom_pos, affine outputs, normalized angles,
              unnormalized angles, frames).
          """
          affine_init = ()
          angles_sin_cos_init = ()
          um_angles_sin_cos_init = ()
          # Placeholders so both names exist before the loop runs.
          atom_pos = ()
          frames = ()
          for _ in range(self.num_layer):
              act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
                  atom_pos, frames = \
                  self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
                                      aatype)
              affine_init = affine_init + (affine_output[None, ...],)
              angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
              um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
          atom_pos = get_exp_atom_pos(atom_pos)
          frames = get_exp_frames(frames)
          affine_output_new = mnp.concatenate(affine_init, axis=0)
          angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
          um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)
          return act, atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames
  class PredictedLDDTHead(nn.Cell):
      """Head to predict the per-residue LDDT to be used as a confidence measure."""

      def __init__(self, config, global_config, seq_channel):
          """Create the layer norm and the three dense layers of the head.

          Args:
              config: head config; `num_channels` and `num_bins` are read here.
              global_config: Global Config of Model (only stored on the cell).
              seq_channel: width of the incoming activations.
          """
          super().__init__()
          self.config = config
          self.global_config = global_config
          self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
          hidden_dim = self.config.num_channels
          # NOTE(review): all three layers are zero-initialised; presumably the
          # weights are always loaded from a checkpoint - confirm before
          # training this head from scratch.
          self.act_0 = nn.Dense(seq_channel, hidden_dim, weight_init='zeros').to_float(mstype.float16)
          self.act_1 = nn.Dense(hidden_dim, hidden_dim, weight_init='zeros').to_float(mstype.float16)
          self.logits = nn.Dense(hidden_dim, self.config.num_bins, weight_init='zeros').to_float(mstype.float16)
          self.relu = nn.ReLU()

      def construct(self, rp_structure_module):
          """Compute logits over `num_bins` bins from the structure-module activations.

          Args:
              rp_structure_module: activations returned by the structure module.

          Returns:
              Output of the final dense layer.
          """
          normed = self.input_layer_norm(rp_structure_module.astype(mstype.float32))
          hidden = self.relu(self.act_0(normed).astype(mstype.float32))
          hidden = self.relu(self.act_1(hidden).astype(mstype.float32))
          return self.logits(hidden)