You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

config.py 15 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. """Model config."""
  2. import copy
  3. import ml_collections
  4. NUM_RES = 'num residues placeholder'
  5. NUM_MSA_SEQ = 'msa placeholder'
  6. NUM_EXTRA_SEQ = 'extra msa placeholder'
  7. NUM_TEMPLATES = 'num templates placeholder'
  8. def model_config(name: str) -> ml_collections.ConfigDict:
  9. """Get the ConfigDict of a CASP14 model."""
  10. if name not in CONFIG_DIFFS:
  11. raise ValueError(f'Invalid model name {name}.')
  12. cfg = copy.deepcopy(CONFIG)
  13. cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
  14. return cfg
  15. CONFIG_DIFFS = {
  16. 'model_1': {
  17. # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
  18. 'data.common.max_extra_msa': 5120,
  19. 'data.common.reduce_msa_clusters_by_max_templates': True,
  20. 'data.common.use_templates': True,
  21. 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
  22. 'model.embeddings_and_evoformer.template.enabled': True
  23. },
  24. 'model_2': {
  25. # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
  26. 'data.common.reduce_msa_clusters_by_max_templates': True,
  27. 'data.common.use_templates': True,
  28. 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
  29. 'model.embeddings_and_evoformer.template.enabled': True
  30. },
  31. 'model_3': {
  32. # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
  33. 'data.common.max_extra_msa': 5120,
  34. },
  35. 'model_4': {
  36. # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
  37. 'data.common.max_extra_msa': 5120,
  38. },
  39. 'model_5': {
  40. # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
  41. },
  42. # The following models are fine-tuned from the corresponding models above
  43. # with an additional predicted_aligned_error head that can produce
  44. # predicted TM-score (pTM) and predicted aligned errors.
  45. 'model_1_ptm': {
  46. 'data.common.max_extra_msa': 5120,
  47. 'data.common.reduce_msa_clusters_by_max_templates': True,
  48. 'data.common.use_templates': True,
  49. 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
  50. 'model.embeddings_and_evoformer.template.enabled': True,
  51. 'model.heads.predicted_aligned_error.weight': 0.1
  52. },
  53. 'model_2_ptm': {
  54. 'data.common.reduce_msa_clusters_by_max_templates': True,
  55. 'data.common.use_templates': True,
  56. 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
  57. 'model.embeddings_and_evoformer.template.enabled': True,
  58. 'model.heads.predicted_aligned_error.weight': 0.1
  59. },
  60. 'model_3_ptm': {
  61. 'data.common.max_extra_msa': 5120,
  62. 'model.heads.predicted_aligned_error.weight': 0.1
  63. },
  64. 'model_4_ptm': {
  65. 'data.common.max_extra_msa': 5120,
  66. 'model.heads.predicted_aligned_error.weight': 0.1
  67. },
  68. 'model_5_ptm': {
  69. 'model.heads.predicted_aligned_error.weight': 0.1
  70. }
  71. }
  72. CONFIG = ml_collections.ConfigDict({
  73. 'data': {
  74. 'common': {
  75. 'masked_msa': {
  76. 'profile_prob': 0.1,
  77. 'same_prob': 0.1,
  78. 'uniform_prob': 0.1
  79. },
  80. 'max_extra_msa': 1024,
  81. 'msa_cluster_features': True,
  82. 'num_recycle': 3,
  83. 'reduce_msa_clusters_by_max_templates': False,
  84. 'resample_msa_in_recycling': True,
  85. 'template_features': [
  86. 'template_all_atom_positions', 'template_sum_probs',
  87. 'template_aatype', 'template_all_atom_masks',
  88. 'template_domain_names'
  89. ],
  90. 'unsupervised_features': [
  91. 'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
  92. 'num_alignments', 'seq_length', 'between_segment_residues',
  93. 'deletion_matrix'
  94. ],
  95. 'use_templates': False,
  96. },
  97. 'eval': {
  98. 'feat': {
  99. 'aatype': [NUM_RES],
  100. 'all_atom_mask': [NUM_RES, None],
  101. 'all_atom_positions': [NUM_RES, None, None],
  102. 'alt_chi_angles': [NUM_RES, None],
  103. 'atom14_alt_gt_exists': [NUM_RES, None],
  104. 'atom14_alt_gt_positions': [NUM_RES, None, None],
  105. 'atom14_atom_exists': [NUM_RES, None],
  106. 'atom14_atom_is_ambiguous': [NUM_RES, None],
  107. 'atom14_gt_exists': [NUM_RES, None],
  108. 'atom14_gt_positions': [NUM_RES, None, None],
  109. 'atom37_atom_exists': [NUM_RES, None],
  110. 'backbone_affine_mask': [NUM_RES],
  111. 'backbone_affine_tensor': [NUM_RES, None],
  112. 'bert_mask': [NUM_MSA_SEQ, NUM_RES],
  113. 'chi_angles': [NUM_RES, None],
  114. 'chi_mask': [NUM_RES, None],
  115. 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
  116. 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
  117. 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
  118. 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
  119. 'extra_msa_row_mask': [NUM_EXTRA_SEQ],
  120. 'is_distillation': [],
  121. 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
  122. 'msa_mask': [NUM_MSA_SEQ, NUM_RES],
  123. 'msa_row_mask': [NUM_MSA_SEQ],
  124. 'pseudo_beta': [NUM_RES, None],
  125. 'pseudo_beta_mask': [NUM_RES],
  126. 'random_crop_to_size_seed': [None],
  127. 'residue_index': [NUM_RES],
  128. 'residx_atom14_to_atom37': [NUM_RES, None],
  129. 'residx_atom37_to_atom14': [NUM_RES, None],
  130. 'resolution': [],
  131. 'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
  132. 'rigidgroups_group_exists': [NUM_RES, None],
  133. 'rigidgroups_group_is_ambiguous': [NUM_RES, None],
  134. 'rigidgroups_gt_exists': [NUM_RES, None],
  135. 'rigidgroups_gt_frames': [NUM_RES, None, None],
  136. 'seq_length': [],
  137. 'seq_mask': [NUM_RES],
  138. 'target_feat': [NUM_RES, None],
  139. 'template_aatype': [NUM_TEMPLATES, NUM_RES],
  140. 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
  141. 'template_all_atom_positions': [
  142. NUM_TEMPLATES, NUM_RES, None, None],
  143. 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
  144. 'template_backbone_affine_tensor': [
  145. NUM_TEMPLATES, NUM_RES, None],
  146. 'template_mask': [NUM_TEMPLATES],
  147. 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
  148. 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
  149. 'template_sum_probs': [NUM_TEMPLATES, None],
  150. 'true_msa': [NUM_MSA_SEQ, NUM_RES]
  151. },
  152. 'fixed_size': True,
  153. 'subsample_templates': False, # We want top templates.
  154. 'masked_msa_replace_fraction': 0.15,
  155. 'max_msa_clusters': 512,
  156. 'max_templates': 4,
  157. 'num_ensemble': 1,
  158. },
  159. },
  160. 'model': {
  161. 'embeddings_and_evoformer': {
  162. 'evoformer_num_block': 48,
  163. 'evoformer': {
  164. 'msa_row_attention_with_pair_bias': {
  165. 'dropout_rate': 0.15,
  166. 'gating': True,
  167. 'num_head': 8,
  168. 'orientation': 'per_row',
  169. 'shared_dropout': True
  170. },
  171. 'msa_column_attention': {
  172. 'dropout_rate': 0.0,
  173. 'gating': True,
  174. 'num_head': 8,
  175. 'orientation': 'per_column',
  176. 'shared_dropout': True
  177. },
  178. 'msa_transition': {
  179. 'dropout_rate': 0.0,
  180. 'num_intermediate_factor': 4,
  181. 'orientation': 'per_row',
  182. 'shared_dropout': True
  183. },
  184. 'outer_product_mean': {
  185. 'chunk_size': 128,
  186. 'dropout_rate': 0.0,
  187. 'num_outer_channel': 32,
  188. 'orientation': 'per_row',
  189. 'shared_dropout': True
  190. },
  191. 'triangle_attention_starting_node': {
  192. 'dropout_rate': 0.25,
  193. 'gating': True,
  194. 'num_head': 4,
  195. 'orientation': 'per_row',
  196. 'shared_dropout': True
  197. },
  198. 'triangle_attention_ending_node': {
  199. 'dropout_rate': 0.25,
  200. 'gating': True,
  201. 'num_head': 4,
  202. 'orientation': 'per_column',
  203. 'shared_dropout': True
  204. },
  205. 'triangle_multiplication_outgoing': {
  206. 'dropout_rate': 0.25,
  207. 'equation': 'ikc,jkc->ijc',
  208. 'num_intermediate_channel': 128,
  209. 'orientation': 'per_row',
  210. 'shared_dropout': True
  211. },
  212. 'triangle_multiplication_incoming': {
  213. 'dropout_rate': 0.25,
  214. 'equation': 'kjc,kic->ijc',
  215. 'num_intermediate_channel': 128,
  216. 'orientation': 'per_row',
  217. 'shared_dropout': True
  218. },
  219. 'pair_transition': {
  220. 'dropout_rate': 0.0,
  221. 'num_intermediate_factor': 4,
  222. 'orientation': 'per_row',
  223. 'shared_dropout': True
  224. }
  225. },
  226. 'extra_msa_channel': 64,
  227. 'extra_msa_stack_num_block': 4,
  228. 'max_relative_feature': 32,
  229. 'msa_channel': 256,
  230. 'pair_channel': 128,
  231. 'prev_pos': {
  232. 'min_bin': 3.25,
  233. 'max_bin': 20.75,
  234. 'num_bins': 15
  235. },
  236. 'recycle_features': True,
  237. 'recycle_pos': True,
  238. 'seq_channel': 384,
  239. 'template': {
  240. 'attention': {
  241. 'gating': False,
  242. 'key_dim': 64,
  243. 'num_head': 4,
  244. 'value_dim': 64
  245. },
  246. 'dgram_features': {
  247. 'min_bin': 3.25,
  248. 'max_bin': 50.75,
  249. 'num_bins': 39
  250. },
  251. 'embed_torsion_angles': False,
  252. 'enabled': False,
  253. 'template_pair_stack': {
  254. 'num_block': 2,
  255. 'triangle_attention_starting_node': {
  256. 'dropout_rate': 0.25,
  257. 'gating': True,
  258. 'key_dim': 64,
  259. 'num_head': 4,
  260. 'orientation': 'per_row',
  261. 'shared_dropout': True,
  262. 'value_dim': 64
  263. },
  264. 'triangle_attention_ending_node': {
  265. 'dropout_rate': 0.25,
  266. 'gating': True,
  267. 'key_dim': 64,
  268. 'num_head': 4,
  269. 'orientation': 'per_column',
  270. 'shared_dropout': True,
  271. 'value_dim': 64
  272. },
  273. 'triangle_multiplication_outgoing': {
  274. 'dropout_rate': 0.25,
  275. 'equation': 'ikc,jkc->ijc',
  276. 'num_intermediate_channel': 64,
  277. 'orientation': 'per_row',
  278. 'shared_dropout': True
  279. },
  280. 'triangle_multiplication_incoming': {
  281. 'dropout_rate': 0.25,
  282. 'equation': 'kjc,kic->ijc',
  283. 'num_intermediate_channel': 64,
  284. 'orientation': 'per_row',
  285. 'shared_dropout': True
  286. },
  287. 'pair_transition': {
  288. 'dropout_rate': 0.0,
  289. 'num_intermediate_factor': 2,
  290. 'orientation': 'per_row',
  291. 'shared_dropout': True
  292. }
  293. },
  294. 'max_templates': 4,
  295. 'subbatch_size': 128,
  296. 'use_template_unit_vector': False,
  297. }
  298. },
  299. 'heads': {
  300. 'distogram': {
  301. 'first_break': 2.3125,
  302. 'last_break': 21.6875,
  303. 'num_bins': 64,
  304. 'weight': 0.3
  305. },
  306. 'predicted_aligned_error': {
  307. # `num_bins - 1` bins uniformly space the
  308. # [0, max_error_bin A] range.
  309. # The final bin covers [max_error_bin A, +infty]
  310. # 31A gives bins with 0.5A width.
  311. 'max_error_bin': 31.,
  312. 'num_bins': 64,
  313. 'num_channels': 128,
  314. 'filter_by_resolution': True,
  315. 'min_resolution': 0.1,
  316. 'max_resolution': 3.0,
  317. 'weight': 0.0,
  318. },
  319. 'experimentally_resolved': {
  320. 'filter_by_resolution': True,
  321. 'max_resolution': 3.0,
  322. 'min_resolution': 0.1,
  323. 'weight': 0.01
  324. },
  325. 'structure_module': {
  326. 'num_layer': 8,
  327. 'fape': {
  328. 'clamp_distance': 10.0,
  329. 'clamp_type': 'relu',
  330. 'loss_unit_distance': 10.0
  331. },
  332. 'angle_norm_weight': 0.01,
  333. 'chi_weight': 0.5,
  334. 'clash_overlap_tolerance': 1.5,
  335. 'compute_in_graph_metrics': True,
  336. 'dropout': 0.1,
  337. 'num_channel': 384,
  338. 'num_head': 12,
  339. 'num_layer_in_transition': 3,
  340. 'num_point_qk': 4,
  341. 'num_point_v': 8,
  342. 'num_scalar_qk': 16,
  343. 'num_scalar_v': 16,
  344. 'position_scale': 10.0,
  345. 'sidechain': {
  346. 'atom_clamp_distance': 10.0,
  347. 'num_channel': 128,
  348. 'num_residual_block': 2,
  349. 'weight_frac': 0.5,
  350. 'length_scale': 10.,
  351. },
  352. 'structural_violation_loss_weight': 1.0,
  353. 'violation_tolerance_factor': 12.0,
  354. 'weight': 1.0
  355. },
  356. 'predicted_lddt': {
  357. 'filter_by_resolution': True,
  358. 'max_resolution': 3.0,
  359. 'min_resolution': 0.1,
  360. 'num_bins': 50,
  361. 'num_channels': 128,
  362. 'weight': 0.01
  363. },
  364. 'masked_msa': {
  365. 'num_output': 23,
  366. 'weight': 2.0
  367. },
  368. },
  369. 'num_recycle': 3,
  370. 'resample_msa_in_recycling': True
  371. },
  372. })