
servable_config.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""serving config for mindspore serving"""
import time
import os
import json

import numpy as np
from mindspore_serving.server import register
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint
from data.feature.feature_extraction import process_features
from data.tools.data_process import data_process
from commons.utils import compute_confidence
from commons.generate_pdb import to_pdb, from_prediction
from model import AlphaFold
from config import config, global_config
from fold_service.config import config as serving_config

# Run in graph mode on the Ascend device selected by the servable config.
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend",
                    variable_memory_max_size="31GB",
                    device_id=serving_config.device_id,
                    save_graphs=False)

# Build the AlphaFold network for the configured sequence length and load its
# checkpoint once, when the servable is started.
model_name = "model_1"
model_config = config.model_config(model_name)
num_recycle = model_config.model.num_recycle
global_config = global_config.global_config(serving_config.seq_length)
extra_msa_length = global_config.extra_msa_length
fold_net = AlphaFold(model_config, global_config)
load_checkpoint(serving_config.ckpt_path, fold_net)


def fold_model(input_fasta_path):
    """Run structure prediction for every FASTA file under input_fasta_path and write unrelaxed PDB results."""
    seq_files = os.listdir(input_fasta_path)
    for seq_file in seq_files:
        print(seq_file)
        t1 = time.time()
        seq_name = seq_file.split('.')[0]
        # Build the raw input features for this sequence, then convert them to
        # padded model inputs stacked per recycling iteration.
        input_features = data_process(seq_name, serving_config)
        tensors, aatype, residue_index, ori_res_length = process_features(
            raw_features=input_features, config=model_config, global_config=global_config)
        # Zero-initialized recycling inputs: previous atom positions, first MSA
        # row embedding and pair representation.
        prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
        prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
        prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
        t2 = time.time()
        # Recycling loop: feed each iteration's outputs back into the network.
        for i in range(num_recycle + 1):
            tensors_i = [tensor[i] for tensor in tensors]
            input_feats = [Tensor(tensor) for tensor in tensors_i]
            final_atom_positions, final_atom_mask, predicted_lddt_logits, \
                prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
                                                                   prev_pos,
                                                                   prev_msa_first_row,
                                                                   prev_pair)
        t3 = time.time()
        # Trim padding back to the original residue length and compute the
        # pLDDT-based confidence.
        final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
        final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
        predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
        confidence = compute_confidence(predicted_lddt_logits)
        # Assemble the unrelaxed protein and write it out in PDB format.
        unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
        pdb_file = to_pdb(unrelaxed_protein)
        seq_length = aatype.shape[-1]
        os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)
        with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
            f.write(pdb_file)
        t4 = time.time()
        timings = {"pre_process_time": round(t2 - t1, 2),
                   "model_time": round(t3 - t2, 2),
                   "pos_process_time": round(t4 - t3, 2),
                   "all_time": round(t4 - t1, 2),
                   "confidence": confidence}
        print(timings)
        with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
            f.write(json.dumps(timings))
    return True


@register.register_method(output_names=["res"])
def folding(input_fasta_path):
    res = register.add_stage(fold_model, input_fasta_path, outputs_count=1)
    return res
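
For reference, below is a minimal sketch of how a servable built around this file could be started and queried. It assumes the MindSpore Serving 1.3+ Python API (server.ServableStartConfig, server.start_servables, server.start_grpc_server, client.Client; older releases pass ip and port separately), a directory layout of <servable_directory>/fold_service/servable_config.py, and placeholder values for the servable name "fold_service", the gRPC address 127.0.0.1:5500 and the FASTA directory. None of these values come from the file above; adjust them to the actual deployment.

# serving_server.py: start the fold_service servable (sketch, assumptions as noted above).
import os
import sys
from mindspore_serving import server

def start():
    servable_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    servable_config = server.ServableStartConfig(servable_directory=servable_dir,
                                                 servable_name="fold_service",
                                                 device_ids=0)
    server.start_servables(servable_configs=servable_config)
    server.start_grpc_server(address="127.0.0.1:5500")

if __name__ == "__main__":
    start()

# serving_client.py: call the registered "folding" method with a FASTA directory path.
from mindspore_serving.client import Client

client = Client("127.0.0.1:5500", "fold_service", "folding")
instances = [{"input_fasta_path": "/path/to/fasta_dir"}]
result = client.infer(instances)
print(result)

The client sends only the path string; the heavy lifting (feature extraction, recycling inference, PDB writing) happens in fold_model on the serving side, and the "res" output simply reports completion.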