
main.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run script for AlphaFold inference."""
import time
import os
import json
import argparse
import numpy as np
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint
from data.feature.feature_extraction import process_features
from data.tools.data_process import data_process
from commons.generate_pdb import to_pdb, from_prediction
from commons.utils import compute_confidence
from model import AlphaFold
from config import config, global_config
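
# Command-line options: the padded sequence length, the paths to the MSA/template
# search tools and databases used during feature generation (data_process), and
# the checkpoint and device used for inference.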
parser = argparse.ArgumentParser(description='Inputs for main.py')
parser.add_argument('--seq_length', help='Padding sequence length.')
parser.add_argument('--input_fasta_path', help='Path of the folder containing the FASTA files to be predicted.')
parser.add_argument('--msa_result_path', help='Path to save the MSA result.')
parser.add_argument('--database_dir', help='Path of the database used to generate the MSA.')
parser.add_argument('--database_envdb_dir', help='Path of the expandable database used to generate the MSA.')
parser.add_argument('--hhsearch_binary_path', help='Path of the hhsearch executable.')
parser.add_argument('--pdb70_database_path', help='Path to the pdb70 database.')
parser.add_argument('--template_mmcif_dir', help='Path of the template mmCIF directory.')
parser.add_argument('--max_template_date', help='Maximum template release date.')
parser.add_argument('--kalign_binary_path', help='Path to the kalign executable.')
parser.add_argument('--obsolete_pdbs_path', help='Path to the obsolete PDBs file.')
parser.add_argument('--checkpoint_path', help='Path of the checkpoint.')
parser.add_argument('--device_id', default=0, type=int, help='Device id to be used.')
args = parser.parse_args()
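
# Example invocation (all paths and values below are placeholders):
#   python main.py --seq_length 256 \
#       --input_fasta_path ./fasta --msa_result_path ./msa_result \
#       --database_dir /path/to/uniref --database_envdb_dir /path/to/envdb \
#       --hhsearch_binary_path /usr/bin/hhsearch --pdb70_database_path /path/to/pdb70/pdb70 \
#       --template_mmcif_dir /path/to/mmcif --max_template_date 2021-11-01 \
#       --kalign_binary_path /usr/bin/kalign --obsolete_pdbs_path /path/to/obsolete.dat \
#       --checkpoint_path /path/to/checkpoint.ckpt --device_id 0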

if __name__ == "__main__":
    # Run in graph mode on Ascend with a fixed device memory budget.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        variable_memory_max_size="31GB",
                        device_id=args.device_id,
                        save_graphs=False)
    model_name = "model_1"
    model_config = config.model_config(model_name)
    num_recycle = model_config.model.num_recycle
    global_config = global_config.global_config(args.seq_length)
    extra_msa_length = global_config.extra_msa_length
    fold_net = AlphaFold(model_config, global_config)
    load_checkpoint(args.checkpoint_path, fold_net)
    seq_files = os.listdir(args.input_fasta_path)
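
    # Run inference for every FASTA file in the input folder; results for each
    # sequence are written under ./result/seq_<name>_<length>/.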
    for seq_file in seq_files:
        t1 = time.time()
        seq_name = seq_file.split('.')[0]
        input_features = data_process(seq_name, args)
        tensors, aatype, residue_index, ori_res_length = process_features(
            raw_features=input_features, config=model_config, global_config=global_config)
        # Recycling buffers, padded to the configured sequence length.
        prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
        prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
        prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
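
        # Recycling: run the network num_recycle + 1 times, feeding the previous
        # pass's atom positions, first MSA row and pair activations back in.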
        t2 = time.time()
        for i in range(num_recycle + 1):
            tensors_i = [tensor[i] for tensor in tensors]
            input_feats = [Tensor(tensor) for tensor in tensors_i]
            final_atom_positions, final_atom_mask, predicted_lddt_logits, \
                prev_pos, prev_msa_first_row, prev_pair = fold_net(
                    *input_feats, prev_pos, prev_msa_first_row, prev_pair)
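        # Post-processing: strip the padding, compute the confidence score and
        # write the unrelaxed structure as a PDB file.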
        t3 = time.time()
        final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
        final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
        predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
        confidence = compute_confidence(predicted_lddt_logits)
        unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
        pdb_file = to_pdb(unrelaxed_protein)
        seq_length = aatype.shape[-1]
        os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)
        with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
            f.write(pdb_file)
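        # Record per-stage timings together with the confidence score.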
        t4 = time.time()
        timings = {"pre_process_time": round(t2 - t1, 2),
                   "model_time": round(t3 - t2, 2),
                   "pos_process_time": round(t4 - t3, 2),
                   "all_time": round(t4 - t1, 2),
                   "confidence": confidence}
        print(timings)
        with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
            f.write(json.dumps(timings))
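
Note on confidence: compute_confidence is imported from commons.utils and its implementation is not part of this file. As a rough reference only, the sketch below shows one standard way to turn per-residue pLDDT logits into a confidence score (softmax over bins spaced uniformly on [0, 1], expected value scaled to 0-100, then averaged over residues); the function name plddt_from_logits and the uniform binning are assumptions here, so the repository's compute_confidence may differ.

import numpy as np

def plddt_from_logits(predicted_lddt_logits):
    """Map [num_res, num_bins] pLDDT logits to per-residue scores in [0, 100]."""
    num_bins = predicted_lddt_logits.shape[-1]
    bin_width = 1.0 / num_bins
    # Centers of num_bins equal-width bins covering [0, 1].
    bin_centers = np.arange(0.5 * bin_width, 1.0, bin_width)
    # Softmax over the bin dimension.
    shifted = predicted_lddt_logits - predicted_lddt_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    per_residue = 100.0 * (probs * bin_centers).sum(axis=-1)
    return per_residue, float(per_residue.mean())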