
'merge'

pull/272/head
Turing's Cat 2 years ago
parent commit a2d06b2ef3
100 changed files with 16506 additions and 0 deletions
  1. +21 -0  reproduce/AlphaFold2-Chinese/.gitignore
  2. +479 -0  reproduce/AlphaFold2-Chinese/AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb
  3. +479 -0  reproduce/AlphaFold2-Chinese/Fold_CN.ipynb
  4. +478 -0  reproduce/AlphaFold2-Chinese/Fold_CN.ipynb.txt
  5. +201 -0  reproduce/AlphaFold2-Chinese/LICENSE
  6. +221 -0  reproduce/AlphaFold2-Chinese/README.md
  7. +118 -0  reproduce/AlphaFold2-Chinese/commons/generate_pdb.py
  8. +104 -0  reproduce/AlphaFold2-Chinese/commons/r3.py
  9. +842 -0  reproduce/AlphaFold2-Chinese/commons/residue_constants.py
  10. +1038 -0  reproduce/AlphaFold2-Chinese/commons/utils.py
  11. +382 -0  reproduce/AlphaFold2-Chinese/config/config.py
  12. +341 -0  reproduce/AlphaFold2-Chinese/config/global_config.py
  13. +517 -0  reproduce/AlphaFold2-Chinese/data/feature/data_transforms.py
  14. +294 -0  reproduce/AlphaFold2-Chinese/data/feature/feature_extraction.py
  15. +205 -0  reproduce/AlphaFold2-Chinese/data/tools/data_process.py
  16. +428 -0  reproduce/AlphaFold2-Chinese/data/tools/data_tools.py
  17. +393 -0  reproduce/AlphaFold2-Chinese/data/tools/mmcif_parsing.py
  18. +61 -0  reproduce/AlphaFold2-Chinese/data/tools/msa_search.sh
  19. +389 -0  reproduce/AlphaFold2-Chinese/data/tools/parsers.py
  20. +999 -0  reproduce/AlphaFold2-Chinese/data/tools/templates.py
  21. BIN  reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture.jpg
  22. BIN  reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture_en.jpg
  23. BIN  reproduce/AlphaFold2-Chinese/docs/all_experiment_data.jpg
  24. BIN  reproduce/AlphaFold2-Chinese/docs/seq_21.jpg
  25. BIN  reproduce/AlphaFold2-Chinese/docs/seq_64.gif
  26. +112 -0  reproduce/AlphaFold2-Chinese/main.py
  27. +936 -0  reproduce/AlphaFold2-Chinese/module/basic_module.py
  28. +304 -0  reproduce/AlphaFold2-Chinese/module/evoformer_module.py
  29. +235 -0  reproduce/AlphaFold2-Chinese/module/model.py
  30. +443 -0  reproduce/AlphaFold2-Chinese/module/structure_module.py
  31. +5 -0  reproduce/AlphaFold2-Chinese/requirements.txt
  32. +34 -0  reproduce/AlphaFold2-Chinese/serving/fold_service/config.py
  33. +104 -0  reproduce/AlphaFold2-Chinese/serving/fold_service/servable_config.py
  34. +32 -0  reproduce/AlphaFold2-Chinese/serving/serving_client.py
  35. +31 -0  reproduce/AlphaFold2-Chinese/serving/serving_server.py
  36. +14 -0  reproduce/AlphaFold2-Chinese/tests/st/__init__.py
  37. +15 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/__init__.py
  38. +85 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_activation.py
  39. +224 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_block.py
  40. +43 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_mlt.py
  41. +92 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_lr_scheduler.py
  42. +39 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_metrics.py
  43. +82 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/config.py
  44. +110 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_boundary.py
  45. +195 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_data_base.py
  46. +112 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_dataset.py
  47. +77 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_equation.py
  48. +77 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_existed_data.py
  49. +84 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_src_td.py
  50. +133 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_1d.py
  51. +247 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_2d.py
  52. +264 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_base.py
  53. +240 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_csg.py
  54. +142 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_nd.py
  55. +293 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_td.py
  56. +39 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/loss/test_constraints.py
  57. +217 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/net_with_loss/test_netwithloss.py
  58. +30 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/config.py
  59. +82 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/dataset.py
  60. +30 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/lr_generator.py
  61. +49 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/metric.py
  62. +298 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/model.py
  63. +137 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/test_data_compression.py
  64. +124 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/callback.py
  65. +44 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/config.py
  66. +42 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/dataset.py
  67. +55 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/model.py
  68. +144 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/test_frequency_domain_maxwell.py
  69. +31 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/config.py
  70. +35 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/dataset.py
  71. +80 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/loss.py
  72. +176 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/maxwell_model.py
  73. +37 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/sample.py
  74. +152 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/test_full_em.py
  75. +44 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/pretrain.json
  76. +28 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/__init__.py
  77. +57 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/dataset.py
  78. +73 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/lr_scheduler.py
  79. +184 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/maxwell.py
  80. +52 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/sampling_config.py
  81. +192 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/test_incremental_learning.py
  82. BIN  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_input.npy
  83. BIN  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_label.npy
  84. BIN  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_input.npy
  85. BIN  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_label.npy
  86. +130 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/dataset.py
  87. +98 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/loss.py
  88. +47 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/maxwell_model.py
  89. +166 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/test_parameterization.py
  90. +26 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/config.py
  91. +79 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/dataset.py
  92. +29 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/lr_generator.py
  93. +108 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/metric.py
  94. +89 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/model.py
  95. +144 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/test_s_parameter.py
  96. +37 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/config.json
  97. +32 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/__init__.py
  98. +127 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/callback.py
  99. +95 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/dataset.py
  100. +73 -0  reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/lr_scheduler.py

+ 21  - 0  reproduce/AlphaFold2-Chinese/.gitignore

@@ -0,0 +1,21 @@
/tmp/
*.coverage
/.idea
/.vscode
.vscode
__pycache__/
*.pyc
*.so
*.so.*
*.o
*.out
*.gch
build
*.egg-info
dist
version.py
local_script/
output
.ipynb_checkpoints
somas_meta
analyze_fail.dat

+ 479  - 0  reproduce/AlphaFold2-Chinese/AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb

@@ -0,0 +1,479 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G4yBrceuFbf3"
},
"source": [
"<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"200\" align=\"right\" style=\"height:240px\">\n",
"\n",
"##AlphaFold2_CN: AlphaFold2 with MMseqs2\n",
"\n",
"简单的中文版蛋白质结构预测操作指南(中文版),基于[AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2)和[Alphafold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). 序列比对方式基于[MMseqs2](mmseqs.com)和[HHsearch](https://github.com/soedinglab/hh-suite)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kOblAo-xetgx",
"cellView": "form"
},
"source": [
"#@title 输入蛋白质序列(默认为视频中的测试序列)\n",
"from google.colab import files\n",
"import os.path\n",
"import re\n",
"import hashlib\n",
"import random\n",
"\n",
"def add_hash(x,y):\n",
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetro-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
"\n",
"# remove whitespaces\n",
"query_sequence = \"\".join(query_sequence.split())\n",
"\n",
"jobname = 'test' #@param {type:\"string\"}\n",
"# remove whitespaces\n",
"basejobname = \"\".join(jobname.split())\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"jobname = add_hash(basejobname, query_sequence)\n",
"while os.path.isfile(f\"{jobname}.csv\"):\n",
" jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
"\n",
"with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
"\n",
"queries_path=f\"{jobname}.csv\"\n",
"\n",
"# number of models to use\n",
"use_amber = False #@param {type:\"boolean\"}\n",
"template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
"#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
"\n",
"if template_mode == \"pdb70\":\n",
" use_templates = True\n",
" custom_template_path = None\n",
"elif template_mode == \"custom\":\n",
" custom_template_path = f\"{jobname}_template\"\n",
" os.mkdir(custom_template_path)\n",
" uploaded = files.upload()\n",
" use_templates = True\n",
" for fn in uploaded.keys():\n",
" os.rename(fn, f\"{jobname}_template/{fn}\")\n",
"else:\n",
" custom_template_path = None\n",
" use_templates = False\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### MSA选项(custom MSA upload, single sequence, pairing mode)\n",
"msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
"pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
"#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = seperate MSA for each chain, \"paired\" - only use paired sequences.\n",
"\n",
"# decide which a3m to use\n",
"if msa_mode.startswith(\"MMseqs2\"):\n",
" a3m_file = f\"{jobname}.a3m\"\n",
"elif msa_mode == \"custom\":\n",
" a3m_file = f\"{jobname}.custom.a3m\"\n",
" if not os.path.isfile(a3m_file):\n",
" custom_msa_dict = files.upload()\n",
" custom_msa = list(custom_msa_dict.keys())[0]\n",
" header = 0\n",
" import fileinput\n",
" for line in fileinput.FileInput(custom_msa,inplace=1):\n",
" if line.startswith(\">\"):\n",
" header = header + 1\n",
" if not line.rstrip():\n",
" continue\n",
" if line.startswith(\">\") == False and header == 1:\n",
" query_sequence = line.rstrip()\n",
" print(line, end='')\n",
"\n",
" os.rename(custom_msa, a3m_file)\n",
" queries_path=a3m_file\n",
" print(f\"moving {custom_msa} to {a3m_file}\")\n",
"else:\n",
" a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
" with open(a3m_file, \"w\") as text_file:\n",
" text_file.write(\">1\\n%s\" % query_sequence)"
],
"metadata": {
"cellView": "form",
"id": "C2_sh2uAonJH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### 参数设置\n",
"model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
"#@markdown - \"auto\" = protein structure prediction using \"AlphaFold2-ptm\" and complex prediction \"AlphaFold-multimer-v2\". For complexes \"AlphaFold-multimer-v[1,2]\" and \"AlphaFold-ptm\" can be used.\n",
"num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
"save_to_google_drive = False #@param {type:\"boolean\"}\n",
"\n",
"#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
"dpi = 200 #@param {type:\"integer\"}\n",
"#@markdown - set dpi for image resolution\n",
"\n",
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
"\n",
"\n",
"if save_to_google_drive:\n",
" from pydrive.drive import GoogleDrive\n",
" from pydrive.auth import GoogleAuth\n",
" from google.colab import auth\n",
" from oauth2client.client import GoogleCredentials\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" drive = GoogleDrive(gauth)\n",
" print(\"You are logged into Google Drive and are good to go!\")"
],
"metadata": {
"cellView": "form",
"id": "ADDuaolKmjGW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iccGdbe_Pmt9",
"pycharm": {
"name": "#%%\n"
},
"cellView": "form"
},
"source": [
"#@title 环境安装\n",
"%%bash -s $use_amber $use_templates\n",
"\n",
"set -e\n",
"\n",
"USE_AMBER=$1\n",
"USE_TEMPLATES=$2\n",
"\n",
"if [ ! -f COLABFOLD_READY ]; then\n",
" # install dependencies\n",
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
" # high risk high gain\n",
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
" touch COLABFOLD_READY\n",
"fi\n",
"\n",
"# setup conda\n",
"if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
" if [ ! -f CONDA_READY ]; then\n",
" wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
" bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
" rm Miniconda3-latest-Linux-x86_64.sh\n",
" touch CONDA_READY\n",
" fi\n",
"fi\n",
"# setup template search\n",
"if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
" conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
" touch HH_READY\n",
"fi\n",
"# setup openmm for amber refinement\n",
"if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
" conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
" touch AMBER_READY\n",
"fi"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_sztQyz29DIC",
"cellView": "form"
},
"source": [
"#@title 模型预测\n",
"\n",
"import sys\n",
"\n",
"from colabfold.download import download_alphafold_params, default_data_dir\n",
"from colabfold.utils import setup_logging\n",
"from colabfold.batch import get_queries, run, set_model_type\n",
"K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
"if \"1\" in K80_chk:\n",
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
"\n",
"from colabfold.colabfold import plot_protein\n",
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# For some reason we need that to get pdbfixer to import\n",
"if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
" sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"\n",
"def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
" fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"result_dir=\".\"\n",
"setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
"queries, is_complex = get_queries(queries_path)\n",
"model_type = set_model_type(is_complex, model_type)\n",
"download_alphafold_params(model_type, Path(\".\"))\n",
"run(\n",
" queries=queries,\n",
" result_dir=result_dir,\n",
" use_templates=use_templates,\n",
" custom_template_path=custom_template_path,\n",
" use_amber=use_amber,\n",
" msa_mode=msa_mode, \n",
" model_type=model_type,\n",
" num_models=5,\n",
" num_recycles=num_recycles,\n",
" model_order=[1, 2, 3, 4, 5],\n",
" is_complex=is_complex,\n",
" data_dir=Path(\".\"),\n",
" keep_existing_results=False,\n",
" recompile_padding=1.0,\n",
" rank_by=\"auto\",\n",
" pair_mode=pair_mode,\n",
" stop_at_score=float(100),\n",
" prediction_callback=prediction_callback,\n",
" dpi=dpi\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KK7X9T44pWb7",
"cellView": "form"
},
"source": [
"#@title 展示3维结构 {run: \"auto\"}\n",
"import py3Dmol\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"from colabfold.colabfold import plot_plddt_legend\n",
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
"show_sidechains = False #@param {type:\"boolean\"}\n",
"show_mainchains = False #@param {type:\"boolean\"}\n",
"\n",
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
"if use_amber:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
"else:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
"\n",
"pdb_file = glob.glob(pdb_filename)\n",
"\n",
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
" model_name = f\"rank_{rank_num}\"\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
"\n",
" if color == \"lDDT\":\n",
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
" elif color == \"rainbow\":\n",
" view.setStyle({'cartoon': {'color':'spectrum'}})\n",
" elif color == \"chain\":\n",
" chains = len(queries[0][1]) + 1 if is_complex else 1\n",
" for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
" [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
" if show_sidechains:\n",
" BB = ['C','O','N']\n",
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
"\n",
" view.zoomTo()\n",
" return view\n",
"\n",
"\n",
"show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
"if color == \"lDDT\":\n",
" plot_plddt_legend().show() "
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "11l8k--10q0C",
"cellView": "form"
},
"source": [
"#@title 图表展示 {run: \"auto\"}\n",
"from IPython.display import display, HTML\n",
"import base64\n",
"from html import escape\n",
"\n",
"# see: https://stackoverflow.com/a/53688522\n",
"def image_to_data_url(filename):\n",
" ext = filename.split('.')[-1]\n",
" prefix = f'data:image/{ext};base64,'\n",
" with open(filename, 'rb') as f:\n",
" img = f.read()\n",
" return prefix + base64.b64encode(img).decode('utf-8')\n",
"\n",
"pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
"cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
"plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
"display(HTML(f\"\"\"\n",
"<style>\n",
" img {{\n",
" float:left;\n",
" }}\n",
" .full {{\n",
" max-width:100%;\n",
" }}\n",
" .half {{\n",
" max-width:50%;\n",
" }}\n",
" @media (max-width:640px) {{\n",
" .half {{\n",
" max-width:100%;\n",
" }}\n",
" }}\n",
"</style>\n",
"<div style=\"max-width:90%; padding:2em;\">\n",
" <h1>Plots for {escape(jobname)}</h1>\n",
" <img src=\"{pae}\" class=\"full\" />\n",
" <img src=\"{cov}\" class=\"half\" />\n",
" <img src=\"{plddt}\" class=\"half\" />\n",
"</div>\n",
"\"\"\"))\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "33g5IIegij5R",
"cellView": "form"
},
"source": [
"#@title结果下载\n",
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
"\n",
"if msa_mode == \"custom\":\n",
" print(\"Don't forget to cite your custom MSA generation method.\")\n",
"\n",
"!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
"files.download(f\"{jobname}.result.zip\")\n",
"\n",
"if save_to_google_drive == True and drive:\n",
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
" uploaded.Upload()\n",
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGUBLzB3C6WN",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# 操作指南 <a name=\"Instructions\"></a>\n",
"**Quick start**\n",
"1. 把你要预测的氨基酸序列复制进输入框.\n",
"2. 点击\"Runtime\" -> \"Run all\".\n",
"3. 目前的模型预测包括5个模块,最终会生成一个3维结构图.\n",
"\n",
"**生成文件**\n",
"\n",
"1. PDB格式的模型结构文件.\n",
"2. 模型质量图\n",
"3. 模型MSA覆盖率.\n",
"4. 其他.\n",
"\n",
"**Acknowledgments**\n",
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
"\n",
"- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
"\n",
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support\n",
"\n",
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n",
"\n",
"- Do-Yoon Kim for creating the ColabFold logo.\n",
"\n",
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
]
}
]
}
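
Aside: the job-naming logic in the notebook's first cell is easy to lose among the Colab form annotations. Stripped of the form markup, it reduces to the following self-contained sketch (same scheme as above: base name plus the first five hex digits of the sequence's SHA-1, re-hashed over a shuffled copy of the sequence on filename collision; no Colab dependencies):

import hashlib
import os.path
import random

def add_hash(base, sequence):
    # Job name = base name + "_" + first 5 hex chars of SHA-1(sequence).
    return base + "_" + hashlib.sha1(sequence.encode()).hexdigest()[:5]

query_sequence = "MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH"
jobname = add_hash("test", query_sequence)

# On collision, hash a shuffled copy of the sequence until a free name appears.
while os.path.isfile(f"{jobname}.csv"):
    shuffled = "".join(random.sample(query_sequence, len(query_sequence)))
    jobname = add_hash("test", shuffled)

# Write the single-entry queries CSV exactly as the cell above does.
with open(f"{jobname}.csv", "w") as fh:
    fh.write(f"id,sequence\n{jobname},{query_sequence}")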

+ 479  - 0  reproduce/AlphaFold2-Chinese/Fold_CN.ipynb

@@ -0,0 +1,479 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G4yBrceuFbf3"
},
"source": [
"<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"200\" align=\"right\" style=\"height:240px\">\n",
"\n",
"##AlphaFold2_CN: AlphaFold2 with MMseqs2\n",
"\n",
"简单的中文版蛋白质结构预测操作指南(中文版),基于[AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2)和[Alphafold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). 序列比对方式基于[MMseqs2](mmseqs.com)和[HHsearch](https://github.com/soedinglab/hh-suite)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kOblAo-xetgx",
"cellView": "form"
},
"source": [
"#@title 输入蛋白质序列(默认为视频中的测试序列)\n",
"from google.colab import files\n",
"import os.path\n",
"import re\n",
"import hashlib\n",
"import random\n",
"\n",
"def add_hash(x,y):\n",
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetro-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
"\n",
"# remove whitespaces\n",
"query_sequence = \"\".join(query_sequence.split())\n",
"\n",
"jobname = 'test' #@param {type:\"string\"}\n",
"# remove whitespaces\n",
"basejobname = \"\".join(jobname.split())\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"jobname = add_hash(basejobname, query_sequence)\n",
"while os.path.isfile(f\"{jobname}.csv\"):\n",
" jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
"\n",
"with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
"\n",
"queries_path=f\"{jobname}.csv\"\n",
"\n",
"# number of models to use\n",
"use_amber = False #@param {type:\"boolean\"}\n",
"template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
"#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
"\n",
"if template_mode == \"pdb70\":\n",
" use_templates = True\n",
" custom_template_path = None\n",
"elif template_mode == \"custom\":\n",
" custom_template_path = f\"{jobname}_template\"\n",
" os.mkdir(custom_template_path)\n",
" uploaded = files.upload()\n",
" use_templates = True\n",
" for fn in uploaded.keys():\n",
" os.rename(fn, f\"{jobname}_template/{fn}\")\n",
"else:\n",
" custom_template_path = None\n",
" use_templates = False\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### MSA选项(custom MSA upload, single sequence, pairing mode)\n",
"msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
"pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
"#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = seperate MSA for each chain, \"paired\" - only use paired sequences.\n",
"\n",
"# decide which a3m to use\n",
"if msa_mode.startswith(\"MMseqs2\"):\n",
" a3m_file = f\"{jobname}.a3m\"\n",
"elif msa_mode == \"custom\":\n",
" a3m_file = f\"{jobname}.custom.a3m\"\n",
" if not os.path.isfile(a3m_file):\n",
" custom_msa_dict = files.upload()\n",
" custom_msa = list(custom_msa_dict.keys())[0]\n",
" header = 0\n",
" import fileinput\n",
" for line in fileinput.FileInput(custom_msa,inplace=1):\n",
" if line.startswith(\">\"):\n",
" header = header + 1\n",
" if not line.rstrip():\n",
" continue\n",
" if line.startswith(\">\") == False and header == 1:\n",
" query_sequence = line.rstrip()\n",
" print(line, end='')\n",
"\n",
" os.rename(custom_msa, a3m_file)\n",
" queries_path=a3m_file\n",
" print(f\"moving {custom_msa} to {a3m_file}\")\n",
"else:\n",
" a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
" with open(a3m_file, \"w\") as text_file:\n",
" text_file.write(\">1\\n%s\" % query_sequence)"
],
"metadata": {
"cellView": "form",
"id": "C2_sh2uAonJH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### 参数设置\n",
"model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
"#@markdown - \"auto\" = protein structure prediction using \"AlphaFold2-ptm\" and complex prediction \"AlphaFold-multimer-v2\". For complexes \"AlphaFold-multimer-v[1,2]\" and \"AlphaFold-ptm\" can be used.\n",
"num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
"save_to_google_drive = False #@param {type:\"boolean\"}\n",
"\n",
"#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
"dpi = 200 #@param {type:\"integer\"}\n",
"#@markdown - set dpi for image resolution\n",
"\n",
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
"\n",
"\n",
"if save_to_google_drive:\n",
" from pydrive.drive import GoogleDrive\n",
" from pydrive.auth import GoogleAuth\n",
" from google.colab import auth\n",
" from oauth2client.client import GoogleCredentials\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" drive = GoogleDrive(gauth)\n",
" print(\"You are logged into Google Drive and are good to go!\")"
],
"metadata": {
"cellView": "form",
"id": "ADDuaolKmjGW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iccGdbe_Pmt9",
"pycharm": {
"name": "#%%\n"
},
"cellView": "form"
},
"source": [
"#@title 环境安装\n",
"%%bash -s $use_amber $use_templates\n",
"\n",
"set -e\n",
"\n",
"USE_AMBER=$1\n",
"USE_TEMPLATES=$2\n",
"\n",
"if [ ! -f COLABFOLD_READY ]; then\n",
" # install dependencies\n",
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
" # high risk high gain\n",
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
" touch COLABFOLD_READY\n",
"fi\n",
"\n",
"# setup conda\n",
"if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
" if [ ! -f CONDA_READY ]; then\n",
" wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
" bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
" rm Miniconda3-latest-Linux-x86_64.sh\n",
" touch CONDA_READY\n",
" fi\n",
"fi\n",
"# setup template search\n",
"if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
" conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
" touch HH_READY\n",
"fi\n",
"# setup openmm for amber refinement\n",
"if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
" conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
" touch AMBER_READY\n",
"fi"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_sztQyz29DIC",
"cellView": "form"
},
"source": [
"#@title 模型预测\n",
"\n",
"import sys\n",
"\n",
"from colabfold.download import download_alphafold_params, default_data_dir\n",
"from colabfold.utils import setup_logging\n",
"from colabfold.batch import get_queries, run, set_model_type\n",
"K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
"if \"1\" in K80_chk:\n",
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
"\n",
"from colabfold.colabfold import plot_protein\n",
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# For some reason we need that to get pdbfixer to import\n",
"if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
" sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"\n",
"def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
" fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"result_dir=\".\"\n",
"setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
"queries, is_complex = get_queries(queries_path)\n",
"model_type = set_model_type(is_complex, model_type)\n",
"download_alphafold_params(model_type, Path(\".\"))\n",
"run(\n",
" queries=queries,\n",
" result_dir=result_dir,\n",
" use_templates=use_templates,\n",
" custom_template_path=custom_template_path,\n",
" use_amber=use_amber,\n",
" msa_mode=msa_mode, \n",
" model_type=model_type,\n",
" num_models=5,\n",
" num_recycles=num_recycles,\n",
" model_order=[1, 2, 3, 4, 5],\n",
" is_complex=is_complex,\n",
" data_dir=Path(\".\"),\n",
" keep_existing_results=False,\n",
" recompile_padding=1.0,\n",
" rank_by=\"auto\",\n",
" pair_mode=pair_mode,\n",
" stop_at_score=float(100),\n",
" prediction_callback=prediction_callback,\n",
" dpi=dpi\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KK7X9T44pWb7",
"cellView": "form"
},
"source": [
"#@title 展示3维结构 {run: \"auto\"}\n",
"import py3Dmol\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"from colabfold.colabfold import plot_plddt_legend\n",
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
"show_sidechains = False #@param {type:\"boolean\"}\n",
"show_mainchains = False #@param {type:\"boolean\"}\n",
"\n",
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
"if use_amber:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
"else:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
"\n",
"pdb_file = glob.glob(pdb_filename)\n",
"\n",
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
" model_name = f\"rank_{rank_num}\"\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
"\n",
" if color == \"lDDT\":\n",
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
" elif color == \"rainbow\":\n",
" view.setStyle({'cartoon': {'color':'spectrum'}})\n",
" elif color == \"chain\":\n",
" chains = len(queries[0][1]) + 1 if is_complex else 1\n",
" for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
" [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
" if show_sidechains:\n",
" BB = ['C','O','N']\n",
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
"\n",
" view.zoomTo()\n",
" return view\n",
"\n",
"\n",
"show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
"if color == \"lDDT\":\n",
" plot_plddt_legend().show() "
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "11l8k--10q0C",
"cellView": "form"
},
"source": [
"#@title 图表展示 {run: \"auto\"}\n",
"from IPython.display import display, HTML\n",
"import base64\n",
"from html import escape\n",
"\n",
"# see: https://stackoverflow.com/a/53688522\n",
"def image_to_data_url(filename):\n",
" ext = filename.split('.')[-1]\n",
" prefix = f'data:image/{ext};base64,'\n",
" with open(filename, 'rb') as f:\n",
" img = f.read()\n",
" return prefix + base64.b64encode(img).decode('utf-8')\n",
"\n",
"pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
"cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
"plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
"display(HTML(f\"\"\"\n",
"<style>\n",
" img {{\n",
" float:left;\n",
" }}\n",
" .full {{\n",
" max-width:100%;\n",
" }}\n",
" .half {{\n",
" max-width:50%;\n",
" }}\n",
" @media (max-width:640px) {{\n",
" .half {{\n",
" max-width:100%;\n",
" }}\n",
" }}\n",
"</style>\n",
"<div style=\"max-width:90%; padding:2em;\">\n",
" <h1>Plots for {escape(jobname)}</h1>\n",
" <img src=\"{pae}\" class=\"full\" />\n",
" <img src=\"{cov}\" class=\"half\" />\n",
" <img src=\"{plddt}\" class=\"half\" />\n",
"</div>\n",
"\"\"\"))\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "33g5IIegij5R",
"cellView": "form"
},
"source": [
"#@title结果下载\n",
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
"\n",
"if msa_mode == \"custom\":\n",
" print(\"Don't forget to cite your custom MSA generation method.\")\n",
"\n",
"!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
"files.download(f\"{jobname}.result.zip\")\n",
"\n",
"if save_to_google_drive == True and drive:\n",
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
" uploaded.Upload()\n",
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGUBLzB3C6WN",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# 操作指南 <a name=\"Instructions\"></a>\n",
"**Quick start**\n",
"1. 把你要预测的氨基酸序列复制进输入框.\n",
"2. 点击\"Runtime\" -> \"Run all\".\n",
"3. 目前的模型预测包括5个模块,最终会生成一个3维结构图.\n",
"\n",
"**生成文件**\n",
"\n",
"1. PDB格式的模型结构文件.\n",
"2. 模型质量图\n",
"3. 模型MSA覆盖率.\n",
"4. 其他.\n",
"\n",
"**Acknowledgments**\n",
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
"\n",
"- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
"\n",
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support\n",
"\n",
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n",
"\n",
"- Do-Yoon Kim for creating the ColabFold logo.\n",
"\n",
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
]
}
]
}
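
Aside: the "lDDT" coloring mode in the display cell works because AlphaFold-family models store per-residue pLDDT in the PDB B-factor column; py3Dmol then maps that property ('prop': 'b') onto a roygb gradient clamped between 50 and 90. A minimal sketch of just that step, assuming a predicted file named prediction.pdb exists (the filename is illustrative, not produced by this commit):

import py3Dmol

# Hypothetical input: any AlphaFold-style PDB whose B-factor column
# holds per-residue pLDDT values (0-100).
with open("prediction.pdb") as fh:
    pdb_text = fh.read()

view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js")
view.addModel(pdb_text, "pdb")
# Color the cartoon by the B-factor property: ~50 (low confidence, red)
# up to ~90 (high confidence, blue), as in the notebook's "lDDT" mode.
view.setStyle({"cartoon": {"colorscheme": {"prop": "b", "gradient": "roygb", "min": 50, "max": 90}}})
view.zoomTo()
view.show()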

+ 478  - 0  reproduce/AlphaFold2-Chinese/Fold_CN.ipynb.txt

@@ -0,0 +1,478 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "AlphaFold2_CN.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G4yBrceuFbf3"
},
"source": [
"<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"200\" align=\"right\" style=\"height:240px\">\n",
"\n",
"##ColabFold: AlphaFold2 using MMseqs2\n",
"\n",
"简单的中文版蛋白质结构预测操作指南(中文版),基于[AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2)和[Alphafold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). 序列比对方式基于[MMseqs2](mmseqs.com)和[HHsearch](https://github.com/soedinglab/hh-suite)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kOblAo-xetgx",
"cellView": "form"
},
"source": [
"#输入蛋白质序列(默认为视频中的测试序列)然后点击上边框的 `Runtime` -> `Run all`\n",
"from google.colab import files\n",
"import os.path\n",
"import re\n",
"import hashlib\n",
"import random\n",
"\n",
"def add_hash(x,y):\n",
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetro-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
"\n",
"# remove whitespaces\n",
"query_sequence = \"\".join(query_sequence.split())\n",
"\n",
"jobname = 'test' #@param {type:\"string\"}\n",
"# remove whitespaces\n",
"basejobname = \"\".join(jobname.split())\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"jobname = add_hash(basejobname, query_sequence)\n",
"while os.path.isfile(f\"{jobname}.csv\"):\n",
" jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
"\n",
"with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
"\n",
"queries_path=f\"{jobname}.csv\"\n",
"\n",
"# number of models to use\n",
"use_amber = False #@param {type:\"boolean\"}\n",
"template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
"#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
"\n",
"if template_mode == \"pdb70\":\n",
" use_templates = True\n",
" custom_template_path = None\n",
"elif template_mode == \"custom\":\n",
" custom_template_path = f\"{jobname}_template\"\n",
" os.mkdir(custom_template_path)\n",
" uploaded = files.upload()\n",
" use_templates = True\n",
" for fn in uploaded.keys():\n",
" os.rename(fn, f\"{jobname}_template/{fn}\")\n",
"else:\n",
" custom_template_path = None\n",
" use_templates = False\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n",
"msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
"pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
"#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = seperate MSA for each chain, \"paired\" - only use paired sequences.\n",
"\n",
"# decide which a3m to use\n",
"if msa_mode.startswith(\"MMseqs2\"):\n",
" a3m_file = f\"{jobname}.a3m\"\n",
"elif msa_mode == \"custom\":\n",
" a3m_file = f\"{jobname}.custom.a3m\"\n",
" if not os.path.isfile(a3m_file):\n",
" custom_msa_dict = files.upload()\n",
" custom_msa = list(custom_msa_dict.keys())[0]\n",
" header = 0\n",
" import fileinput\n",
" for line in fileinput.FileInput(custom_msa,inplace=1):\n",
" if line.startswith(\">\"):\n",
" header = header + 1\n",
" if not line.rstrip():\n",
" continue\n",
" if line.startswith(\">\") == False and header == 1:\n",
" query_sequence = line.rstrip()\n",
" print(line, end='')\n",
"\n",
" os.rename(custom_msa, a3m_file)\n",
" queries_path=a3m_file\n",
" print(f\"moving {custom_msa} to {a3m_file}\")\n",
"else:\n",
" a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
" with open(a3m_file, \"w\") as text_file:\n",
" text_file.write(\">1\\n%s\" % query_sequence)"
],
"metadata": {
"cellView": "form",
"id": "C2_sh2uAonJH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown ### Advanced settings\n",
"model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
"#@markdown - \"auto\" = protein structure prediction using \"AlphaFold2-ptm\" and complex prediction \"AlphaFold-multimer-v2\". For complexes \"AlphaFold-multimer-v[1,2]\" and \"AlphaFold-ptm\" can be used.\n",
"num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
"save_to_google_drive = False #@param {type:\"boolean\"}\n",
"\n",
"#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
"dpi = 200 #@param {type:\"integer\"}\n",
"#@markdown - set dpi for image resolution\n",
"\n",
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
"\n",
"\n",
"if save_to_google_drive:\n",
" from pydrive.drive import GoogleDrive\n",
" from pydrive.auth import GoogleAuth\n",
" from google.colab import auth\n",
" from oauth2client.client import GoogleCredentials\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" drive = GoogleDrive(gauth)\n",
" print(\"You are logged into Google Drive and are good to go!\")"
],
"metadata": {
"cellView": "form",
"id": "ADDuaolKmjGW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iccGdbe_Pmt9",
"pycharm": {
"name": "#%%\n"
},
"cellView": "form"
},
"source": [
"#@title Install dependencies\n",
"%%bash -s $use_amber $use_templates\n",
"\n",
"set -e\n",
"\n",
"USE_AMBER=$1\n",
"USE_TEMPLATES=$2\n",
"\n",
"if [ ! -f COLABFOLD_READY ]; then\n",
" # install dependencies\n",
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
" # high risk high gain\n",
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
" touch COLABFOLD_READY\n",
"fi\n",
"\n",
"# setup conda\n",
"if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
" if [ ! -f CONDA_READY ]; then\n",
" wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
" bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
" rm Miniconda3-latest-Linux-x86_64.sh\n",
" touch CONDA_READY\n",
" fi\n",
"fi\n",
"# setup template search\n",
"if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
" conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
" touch HH_READY\n",
"fi\n",
"# setup openmm for amber refinement\n",
"if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
" conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
" touch AMBER_READY\n",
"fi"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_sztQyz29DIC",
"cellView": "form"
},
"source": [
"#@title Run Prediction\n",
"\n",
"import sys\n",
"\n",
"from colabfold.download import download_alphafold_params, default_data_dir\n",
"from colabfold.utils import setup_logging\n",
"from colabfold.batch import get_queries, run, set_model_type\n",
"K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
"if \"1\" in K80_chk:\n",
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
"\n",
"from colabfold.colabfold import plot_protein\n",
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# For some reason we need that to get pdbfixer to import\n",
"if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
" sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
"\n",
"def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
" fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"result_dir=\".\"\n",
"setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
"queries, is_complex = get_queries(queries_path)\n",
"model_type = set_model_type(is_complex, model_type)\n",
"download_alphafold_params(model_type, Path(\".\"))\n",
"run(\n",
" queries=queries,\n",
" result_dir=result_dir,\n",
" use_templates=use_templates,\n",
" custom_template_path=custom_template_path,\n",
" use_amber=use_amber,\n",
" msa_mode=msa_mode, \n",
" model_type=model_type,\n",
" num_models=5,\n",
" num_recycles=num_recycles,\n",
" model_order=[1, 2, 3, 4, 5],\n",
" is_complex=is_complex,\n",
" data_dir=Path(\".\"),\n",
" keep_existing_results=False,\n",
" recompile_padding=1.0,\n",
" rank_by=\"auto\",\n",
" pair_mode=pair_mode,\n",
" stop_at_score=float(100),\n",
" prediction_callback=prediction_callback,\n",
" dpi=dpi\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KK7X9T44pWb7",
"cellView": "form"
},
"source": [
"#@title Display 3D structure {run: \"auto\"}\n",
"import py3Dmol\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"from colabfold.colabfold import plot_plddt_legend\n",
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
"show_sidechains = False #@param {type:\"boolean\"}\n",
"show_mainchains = False #@param {type:\"boolean\"}\n",
"\n",
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
"if use_amber:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
"else:\n",
" pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
"\n",
"pdb_file = glob.glob(pdb_filename)\n",
"\n",
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
" model_name = f\"rank_{rank_num}\"\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
"\n",
" if color == \"lDDT\":\n",
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
" elif color == \"rainbow\":\n",
" view.setStyle({'cartoon': {'color':'spectrum'}})\n",
" elif color == \"chain\":\n",
" chains = len(queries[0][1]) + 1 if is_complex else 1\n",
" for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
" [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
" if show_sidechains:\n",
" BB = ['C','O','N']\n",
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
"\n",
" view.zoomTo()\n",
" return view\n",
"\n",
"\n",
"show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
"if color == \"lDDT\":\n",
" plot_plddt_legend().show() "
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "11l8k--10q0C",
"cellView": "form"
},
"source": [
"#@title Plots {run: \"auto\"}\n",
"from IPython.display import display, HTML\n",
"import base64\n",
"from html import escape\n",
"\n",
"# see: https://stackoverflow.com/a/53688522\n",
"def image_to_data_url(filename):\n",
" ext = filename.split('.')[-1]\n",
" prefix = f'data:image/{ext};base64,'\n",
" with open(filename, 'rb') as f:\n",
" img = f.read()\n",
" return prefix + base64.b64encode(img).decode('utf-8')\n",
"\n",
"pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
"cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
"plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
"display(HTML(f\"\"\"\n",
"<style>\n",
" img {{\n",
" float:left;\n",
" }}\n",
" .full {{\n",
" max-width:100%;\n",
" }}\n",
" .half {{\n",
" max-width:50%;\n",
" }}\n",
" @media (max-width:640px) {{\n",
" .half {{\n",
" max-width:100%;\n",
" }}\n",
" }}\n",
"</style>\n",
"<div style=\"max-width:90%; padding:2em;\">\n",
" <h1>Plots for {escape(jobname)}</h1>\n",
" <img src=\"{pae}\" class=\"full\" />\n",
" <img src=\"{cov}\" class=\"half\" />\n",
" <img src=\"{plddt}\" class=\"half\" />\n",
"</div>\n",
"\"\"\"))\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "33g5IIegij5R",
"cellView": "form"
},
"source": [
"#@title Package and download results\n",
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
"\n",
"if msa_mode == \"custom\":\n",
" print(\"Don't forget to cite your custom MSA generation method.\")\n",
"\n",
"!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
"files.download(f\"{jobname}.result.zip\")\n",
"\n",
"if save_to_google_drive == True and drive:\n",
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
" uploaded.Upload()\n",
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGUBLzB3C6WN",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# 操作指南 <a name=\"Instructions\"></a>\n",
"**Quick start**\n",
"1. 把你要预测的氨基酸序列复制进输入框.\n",
"2. 点击\"Runtime\" -> \"Run all\".\n",
"3. 目前的模型预测包括5个模块,最终会生成一个3维结构图.\n",
"\n",
"**生成文件**\n",
"\n",
"1. PDB格式的模型结构文件.\n",
"2. 模型质量图\n",
"3. 模型MSA覆盖率.\n",
"4. 其他.\n",
"**Acknowledgments**\n",
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
"\n",
"- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
"\n",
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support\n",
"\n",
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n",
"\n",
"- Do-Yoon Kim for creating the ColabFold logo.\n",
"\n",
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
]
}
]
}

+ 201
- 0
reproduce/AlphaFold2-Chinese/LICENSE View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

+ 221
- 0
reproduce/AlphaFold2-Chinese/README.md View File

@@ -0,0 +1,221 @@
# AlphaFold2-Chinese
A Chinese reproduction of the open-source AlphaFold2 model, based on DeepMind, ColabFold, and MindSpore:
* https://github.com/sokrypton/ColabFold
* https://github.com/deepmind/alphafold
* https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold

<div align=center>
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_64.gif" alt="T1079" width="400"/>
</div>

# Table of Contents

<!-- TOC -->

- [Table of Contents](#table-of-contents)
- [Model Description](#model-description)
- [Requirements](#requirements)
    - [Hardware and Framework](#hardware-and-framework)
    - [MMseqs2 Installation](#mmseqs2-installation)
    - [MindSpore Serving Installation](#mindspore-serving-installation)
- [Data Preparation](#data-preparation)
    - [Databases for MSA](#databases-for-msa)
    - [Tools and Data for Templates](#tools-and-data-for-templates)
        - [Data](#data)
        - [Tools](#tools)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Inference Example](#inference-example)
    - [Inference Process](#inference-process)
        - [Inference Results](#inference-results)
- [Inference Performance](#inference-performance)
    - [TM-score Comparison](#tm-score-comparison)
    - [Prediction Comparison](#prediction-comparison)
- [References](#references)

<!-- /TOC -->

## Model Description

Protein structure prediction tools use efficient computation to obtain the spatial structure of proteins. Such computational methods long suffered from limited accuracy; this gap was only closed in 2020, when [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) [1][2] from Google's DeepMind team topped the CASP14 ranking for protein 3D structure prediction. The model part of the inference tool open-sourced here is identical to AlphaFold2; in the multiple sequence alignment stage it uses [MMseqs2](https://www.biorxiv.org/content/10.1101/2021.08.15.456425v1.full.pdf) [3] for sequence searching, which gives a 2-3x end-to-end speedup over the original pipeline.

## Requirements

### Hardware and Framework

This code runs on Ascend processors with the [MindSpore](https://www.mindspore.cn/) AI framework. The current version must be [compiled](https://www.mindspore.cn/install/detail?path=install/r1.5/mindspore_ascend_install_source.md&highlight=%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91) from the latest master branch (code from 2021-11-08 or later);
see the [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) for setting up the MindSpore environment. After installation, configure the following environment variable:

``` shell
export MS_DEV_ENABLE_FALLBACK=0
```

For the remaining Python dependencies, see [requirements.txt](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold/requirements.txt).

### MMseqs2 Installation

MMseqs2 is used to generate multiple sequence alignments (MSA). For installing and using MMseqs2, see the [MMseqs2 User Guide](https://mmseqs.com/latest/userguide.pdf). After installation, configure the environment variable:

``` shell
export PATH=$(pwd)/mmseqs/bin/:$PATH
```

### MindSpore Serving Installation

We provide a service mode for inference. It uses MindSpore Serving to offer an efficient inference service: when predicting multiple sequences it avoids repeated compilation and thus greatly improves throughput. For installing and configuring MindSpore Serving, see the [MindSpore Serving installation page](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html).

## Data Preparation

### Databases for MSA

- [uniref30_2103](http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz): 375 GB (68 GB download)
- [colabfold_envdb_202108](http://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz): 949 GB (110 GB download)

See [colabfold](http://colabfold.mmseqs.com) for data processing.

### Tools and Data for Templates

#### Data

- [pdb70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz): 56 GB (19 GB download)
- [mmcif database](https://ftp.rcsb.org/pub/pdb/data/structures/divided/mmCIF/): 206 GB (48 GB download)
- [obsolete_pdbs](http://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat): 140 KB

#### Tools

- [HHsearch](https://github.com/soedinglab/hh-suite)
- [kalign](https://msa.sbc.su.se/downloads/kalign/current.tar.gz)

## Script Description

### Scripts and Sample Code

```bash
├── mindscience
├── MindSPONGE
├── mindsponge
├── fold
                ├── README_CN.md                  // Chinese documentation for fold
                ├── run.py                        // inference script
                ├── model.py                      // main model
                ├── requirements.txt              // dependency list
                ├── serving_server.py             // service-mode server script
                ├── serving_client.py             // service-mode client script
                ├── fold_service
                    ├── servable_config.py        // service-mode configuration script
                ├── module
                    ├── basic_module.py           // basic modules
                    ├── evoformer_module.py       // Evoformer module
                    ├── structure_module.py       // structure module
                ├── data
                    ├── feature
                        ├── data_transforms.py    // MSA and template data transforms
                        ├── feature_extraction.py // MSA and template feature extraction
                    ├── tools
                        ├── data_process.py       // MSA and template search
                        ├── data_tools.py         // data processing utilities
                        ├── mmcif_parsing.py      // mmCIF parsing script
                        ├── msa_search.sh         // shell script for the MMseqs2 MSA search
                        ├── parsers.py            // file parsers
                        ├── templates.py          // template search script
                ├── config
                    ├── config.py                 // parameter configuration script
                    ├── global_config.py          // global parameter configuration script
                ├── common
                    ├── generate_pdb.py           // PDB generation
                    ├── r3.py                     // 3D coordinate transformations
                    ├── residue_constants.py      // amino acid residue constants
                    ├── utils.py                  // utility functions
```

### Inference Example

```bash
usage: run.py [--seq_length PADDING_SEQUENCE_LENGTH]
              [--input_fasta_path INPUT_PATH][--msa_result_path MSA_RESULT_PATH]
              [--database_dir DATABASE_PATH][--database_envdb_dir DATABASE_ENVDB_PATH]
              [--hhsearch_binary_path HHSEARCH_PATH][--pdb70_database_path PDB70_PATH]
              [--template_mmcif_dir TEMPLATE_PATH][--max_template_date TEMPLATE_DATE]
              [--kalign_binary_path KALIGN_PATH][--obsolete_pdbs_path OBSOLETE_PATH]


options:
--seq_length             sequence length after zero-padding; 256/512/1024/2048 are currently supported
--input_fasta_path       FASTA file with the protein sequence whose structure is to be predicted
--msa_result_path        path for saving the MSA results retrieved with MMseqs2
--database_dir           database used for the MSA search
--database_envdb_dir     expanded database used for the MSA search
--hhsearch_binary_path   path to the hhsearch executable
--pdb70_database_path    path to the pdb70 database used by hhsearch
--template_mmcif_dir     path containing the mmCIF structure templates
--max_template_date      latest allowed release date of the templates
--kalign_binary_path     path to the kalign executable
--obsolete_pdbs_path     path to the mapping file of obsolete PDB IDs
```

### Inference Process

Load the AlphaFold checkpoint (download it [here](https://download.mindspore.cn/model_zoo/research/hpc/molecular_dynamics/protein_fold_1.ckpt)) and pick the protein-sequence configuration that fits your needs; four standard configurations (256/512/1024/2048) are currently provided. The inference workflow is as follows:

1. Configure the input parameters through `fold_service/config.py`; their meaning is described in [Inference Example](#inference-example).

2. With the parameters set, start the server process with `serving_server.py`. When it starts successfully, the log shows:

``` log
Serving: Serving gRPC server start success, listening on 127.0.0.1:5500
Serving: Serving RESTful server start success, listening on 127.0.0.1:1500
```

3. After the server process is up, run `serving_client.py` to perform inference; the first inference triggers model compilation. A hypothetical client sketch is shown below.
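
    The sketch is illustrative only: the servable name `fold_service`, the method name `predict`, and the input field are assumptions based on this repository's layout, and the `Client` call follows the MindSpore Serving r1.5 API; `serving_client.py` in this repository is the authoritative script.

    ```python
    # Hypothetical MindSpore Serving client sketch; the servable name, method
    # name and input fields are assumptions -- see serving_client.py.
    from mindspore_serving.client import Client

    client = Client("127.0.0.1:5500", "fold_service", "predict")
    instances = [{"fasta_path": "./input.fasta"}]  # hypothetical input field
    print(client.infer(instances))
    ```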

#### Inference Results

Inference results are saved under `./result` as two files: the pdb file contains the predicted protein structure, and the timings file records the timing and confidence information of the run.

```bash
{"pre_process_time": 418.57, "model_time": 122.86, "pos_process_time": 0.14, "all_time ": 541.56, "confidence ": 94.61789646019058}
```

## Inference Performance

| Parameter           | Fold (Ascend)               |
| ------------------- | --------------------------- |
| Model version       | AlphaFold                   |
| Resource            | Ascend 910                  |
| Upload date         | 2021-11-05                  |
| MindSpore version   | master                      |
| Dataset             | CASP14 T1079                |
| seq_length          | 505                         |
| confidence          | 94.62                       |
| TM-score            | 98.01%                      |
| Runtime             | 541.56 s                    |

### TM-score Comparison

- Comparison against AlphaFold2 on 34 CASP14 targets:

<div align=center>
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/all_experiment_data.jpg" alt="all_data" width="600"/>
</div>

### Prediction Comparison

- T1079 (length 505):

<div align=center>
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_64.gif" alt="T1079" width="400"/>
</div>

- T1044 (length 2180):

<div align=center>
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_21.jpg" alt="T1044" width="400"/>
</div>

## References

[1] Jumper J, Evans R, Pritzel A, et al. Applying and improving AlphaFold at CASP14[J]. Proteins: Structure, Function, and Bioinformatics, 2021.

[2] Jumper J, Evans R, Pritzel A, et al. Highly accurate protein structure prediction with AlphaFold[J]. Nature, 2021, 596(7873): 583-589.

[3] Mirdita M, Ovchinnikov S, Steinegger M. ColabFold - Making protein folding accessible to all[J]. bioRxiv, 2021.

+ 118
- 0
reproduce/AlphaFold2-Chinese/commons/generate_pdb.py View File

@@ -0,0 +1,118 @@
"""generate pdb file"""

import dataclasses
import numpy as np

from commons import residue_constants


@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""

# Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, C.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]

# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]

# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]

# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]

# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]


def from_prediction(final_atom_mask, aatype, final_atom_positions, residue_index):
"""Assembles a protein from a prediction.

Args:
final_atom_mask: atom mask info from structure module.
aatype: amino acid type info.
final_atom_positions: final atom positions from structure module
residue_index: from processed_features

Returns:
A protein instance.
"""
dist_per_residue = np.zeros_like(final_atom_mask)

return Protein(
aatype=aatype,
atom_positions=final_atom_positions,
atom_mask=final_atom_mask,
residue_index=residue_index + 1,
b_factors=dist_per_residue)


def to_pdb(prot: Protein):
"""Converts a `Protein` instance to a PDB string.

Args:
prot: The protein to convert to PDB.

Returns:
PDB string.
"""
restypes = residue_constants.restypes + ['X']
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
atom_types = residue_constants.atom_types

pdb_lines = []

atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors

if (aatype > residue_constants.restype_num).any():
raise ValueError('Invalid aatypes.')

pdb_lines.append('MODEL 1')
atom_index = 1
chain_id = 'A'
# Add all atom sites.
for i in range(aatype.shape[0]):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue

record_type = 'ATOM'
name = atom_name if len(atom_name) == 4 else f' {atom_name}'
alt_loc = ''
insertion_code = ''
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_id:>1}'
f'{residue_index[i]:>4}{insertion_code:>1} '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f} '
f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1

# Close the chain.
chain_end = 'TER'
chain_termination_line = (
f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
f'{chain_id:>1}{residue_index[-1]:>4}')
pdb_lines.append(chain_termination_line)
pdb_lines.append('ENDMDL')

pdb_lines.append('END')
pdb_lines.append('')
return '\n'.join(pdb_lines)
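
# A minimal usage sketch with made-up one-residue inputs; real masks and
# positions come from the structure-module outputs.
if __name__ == '__main__':
    example_positions = np.zeros((1, 37, 3), dtype=np.float32)
    example_mask = np.zeros((1, 37), dtype=np.float32)
    example_mask[0, :3] = 1.0                         # pretend N, CA, C exist
    example_aatype = np.zeros((1,), dtype=np.int32)   # 0 == 'A' (alanine)
    example_protein = from_prediction(example_mask, example_aatype,
                                      example_positions, np.arange(1))
    print(to_pdb(example_protein))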

+ 104
- 0
reproduce/AlphaFold2-Chinese/commons/r3.py View File

@@ -0,0 +1,104 @@
"""Transformations for 3D coordinates."""

import mindspore.numpy as mnp


def vecs_sub(v1, v2):
"""Computes v1 - v2."""
return v1 - v2


def vecs_robust_norm(v, epsilon=1e-8):
"""Computes norm of vectors 'v'."""

return mnp.sqrt(mnp.sum(mnp.square(v), axis=-1) + epsilon)


def vecs_robust_normalize(v, epsilon=1e-8):
"""Normalizes vectors 'v'."""

norms = vecs_robust_norm(v, epsilon)
return v / norms[..., None]


def vecs_dot_vecs(v1, v2):
"""Dot product of vectors 'v1' and 'v2'."""
return mnp.sum(v1 * v2, axis=-1)


def vecs_cross_vecs(v1, v2):
"""Cross product of vectors 'v1' and 'v2'."""

return mnp.concatenate(((v1[..., 1] * v2[..., 2] - v1[..., 2] * v2[..., 1])[..., None],
(v1[..., 2] * v2[..., 0] - v1[..., 0] * v2[..., 2])[..., None],
(v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0])[..., None]), axis=-1)


def rots_from_two_vecs(e0_unnormalized, e1_unnormalized):
"""Create rotation matrices from unnormalized vectors for the x and y-axes."""

# Normalize the unit vector for the x-axis, e0.
e0 = vecs_robust_normalize(e0_unnormalized)

# make e1 perpendicular to e0.
c = vecs_dot_vecs(e1_unnormalized, e0)
e1 = e1_unnormalized - c[..., None] * e0
e1 = vecs_robust_normalize(e1)

# Compute e2 as cross product of e0 and e1.
e2 = vecs_cross_vecs(e0, e1)

rots = mnp.concatenate(
(mnp.concatenate([e0[..., 0][None, ...], e1[..., 0][None, ...], e2[..., 0][None, ...]], axis=0)[None, ...],
mnp.concatenate([e0[..., 1][None, ...], e1[..., 1][None, ...], e2[..., 1][None, ...]], axis=0)[None, ...],
mnp.concatenate([e0[..., 2][None, ...], e1[..., 2][None, ...], e2[..., 2][None, ...]], axis=0)[None, ...]),
axis=0)
return rots


def rigids_from_3_points(
point_on_neg_x_axis, # shape (...)
origin, # shape (...)
point_on_xy_plane, # shape (...)
): # shape (...)
"""Create Rigids from 3 points. """

m = rots_from_two_vecs(
e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
e1_unnormalized=vecs_sub(point_on_xy_plane, origin))
return m, origin


def invert_rots(m):
"""Computes inverse of rotations 'm'."""

return mnp.transpose(m, (1, 0, 2, 3, 4))


def rots_mul_vecs(m, v):
"""Apply rotations 'm' to vectors 'v'."""

return mnp.concatenate(((m[0][0] * v[..., 0] + m[0][1] * v[..., 1] + m[0][2] * v[..., 2])[..., None],
(m[1][0] * v[..., 0] + m[1][1] * v[..., 1] + m[1][2] * v[..., 2])[..., None],
(m[2][0] * v[..., 0] + m[2][1] * v[..., 1] + m[2][2] * v[..., 2])[..., None]), axis=-1)


def invert_rigids(rot, trans):
"""Computes group inverse of rigid transformations 'r'."""

inv_rots = invert_rots(rot)
t = rots_mul_vecs(inv_rots, trans)
inv_trans = -t
return inv_rots, inv_trans


def vecs_add(v1, v2):
"""Add two vectors 'v1' and 'v2'."""

return v1 + v2


def rigids_mul_vecs(rot, trans, v):
"""Apply rigid transforms 'r' to points 'v'."""

return vecs_add(rots_mul_vecs(rot, v), trans)
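
# A minimal usage sketch with made-up backbone coordinates: build a rigid
# frame from three points and map a point from the local frame to global.
if __name__ == '__main__':
    n = mnp.array([-0.525, 1.363, 0.0])    # point on the negative x-axis (N)
    ca = mnp.array([0.0, 0.0, 0.0])        # frame origin (CA)
    c = mnp.array([1.526, 0.0, 0.0])       # point in the xy-plane (C)
    rot, trans = rigids_from_3_points(n, ca, c)
    print(rigids_mul_vecs(rot, trans, mnp.array([1.0, 0.0, 0.0])))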

+ 842
- 0
reproduce/AlphaFold2-Chinese/commons/residue_constants.py View File

@@ -0,0 +1,842 @@
"""residue_constants."""

import collections
import functools
from typing import Mapping, List, Tuple
import numpy as np

from mindspore.common.tensor import Tensor

QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, -1]]

QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, -1, 0]]

QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0],
[0, 0, 0, -1],
[1, 0, 0, 0],
[0, 1, 0, 0]]

QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1],
[0, 0, 1, 0],
[0, -1, 0, 0],
[1, 0, 0, 0]]

QUAT_MULTIPLY_BY_VEC = Tensor(QUAT_MULTIPLY[:, 1:, :])


# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
'ALA': [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'CYS': [['N', 'CA', 'CB', 'SG']],
'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'OE1']],
'GLY': [],
'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
['CB', 'CG', 'SD', 'CE']],
'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
'SER': [['N', 'CA', 'CB', 'OG']],
'THR': [['N', 'CA', 'CB', 'OG1']],
'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'VAL': [['N', 'CA', 'CB', 'CG1']],
}

# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]

# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]

# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-definiting atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
'ALA': [
['N', 0, (-0.525, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.529, -0.774, -1.205)],
['O', 3, (0.627, 1.062, 0.000)],
],
'ARG': [
['N', 0, (-0.524, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.524, -0.778, -1.209)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.616, 1.390, -0.000)],
['CD', 5, (0.564, 1.414, 0.000)],
['NE', 6, (0.539, 1.357, -0.000)],
['NH1', 7, (0.206, 2.301, 0.000)],
['NH2', 7, (2.078, 0.978, -0.000)],
['CZ', 7, (0.758, 1.093, -0.000)],
],
'ASN': [
['N', 0, (-0.536, 1.357, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.531, -0.787, -1.200)],
['O', 3, (0.625, 1.062, 0.000)],
['CG', 4, (0.584, 1.399, 0.000)],
['ND2', 5, (0.593, -1.188, 0.001)],
['OD1', 5, (0.633, 1.059, 0.000)],
],
'ASP': [
['N', 0, (-0.525, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, 0.000, -0.000)],
['CB', 0, (-0.526, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.593, 1.398, -0.000)],
['OD1', 5, (0.610, 1.091, 0.000)],
['OD2', 5, (0.592, -1.101, -0.003)],
],
'CYS': [
['N', 0, (-0.522, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, 0.000)],
['CB', 0, (-0.519, -0.773, -1.212)],
['O', 3, (0.625, 1.062, -0.000)],
['SG', 4, (0.728, 1.653, 0.000)],
],
'GLN': [
['N', 0, (-0.526, 1.361, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.525, -0.779, -1.207)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.615, 1.393, 0.000)],
['CD', 5, (0.587, 1.399, -0.000)],
['NE2', 6, (0.593, -1.189, -0.001)],
['OE1', 6, (0.634, 1.060, 0.000)],
],
'GLU': [
['N', 0, (-0.528, 1.361, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, -0.000, -0.000)],
['CB', 0, (-0.526, -0.781, -1.207)],
['O', 3, (0.626, 1.062, 0.000)],
['CG', 4, (0.615, 1.392, 0.000)],
['CD', 5, (0.600, 1.397, 0.000)],
['OE1', 6, (0.607, 1.095, -0.000)],
['OE2', 6, (0.589, -1.104, -0.001)],
],
'GLY': [
['N', 0, (-0.572, 1.337, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.517, -0.000, -0.000)],
['O', 3, (0.626, 1.062, -0.000)],
],
'HIS': [
['N', 0, (-0.527, 1.360, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.525, -0.778, -1.208)],
['O', 3, (0.625, 1.063, 0.000)],
['CG', 4, (0.600, 1.370, -0.000)],
['CD2', 5, (0.889, -1.021, 0.003)],
['ND1', 5, (0.744, 1.160, -0.000)],
['CE1', 5, (2.030, 0.851, 0.002)],
['NE2', 5, (2.145, -0.466, 0.004)],
],
'ILE': [
['N', 0, (-0.493, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.536, -0.793, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.534, 1.437, -0.000)],
['CG2', 4, (0.540, -0.785, -1.199)],
['CD1', 5, (0.619, 1.391, 0.000)],
],
'LEU': [
['N', 0, (-0.520, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.522, -0.773, -1.214)],
['O', 3, (0.625, 1.063, -0.000)],
['CG', 4, (0.678, 1.371, 0.000)],
['CD1', 5, (0.530, 1.430, -0.000)],
['CD2', 5, (0.535, -0.774, 1.200)],
],
'LYS': [
['N', 0, (-0.526, 1.362, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, 0.000)],
['CB', 0, (-0.524, -0.778, -1.208)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.619, 1.390, 0.000)],
['CD', 5, (0.559, 1.417, 0.000)],
['CE', 6, (0.560, 1.416, 0.000)],
['NZ', 7, (0.554, 1.387, 0.000)],
],
'MET': [
['N', 0, (-0.521, 1.364, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, 0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.210)],
['O', 3, (0.625, 1.062, -0.000)],
['CG', 4, (0.613, 1.391, -0.000)],
['SD', 5, (0.703, 1.695, 0.000)],
['CE', 6, (0.320, 1.786, -0.000)],
],
'PHE': [
['N', 0, (-0.518, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, 0.000, -0.000)],
['CB', 0, (-0.525, -0.776, -1.212)],
['O', 3, (0.626, 1.062, -0.000)],
['CG', 4, (0.607, 1.377, 0.000)],
['CD1', 5, (0.709, 1.195, -0.000)],
['CD2', 5, (0.706, -1.196, 0.000)],
['CE1', 5, (2.102, 1.198, -0.000)],
['CE2', 5, (2.098, -1.201, -0.000)],
['CZ', 5, (2.794, -0.003, -0.001)],
],
'PRO': [
['N', 0, (-0.566, 1.351, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, 0.000)],
['CB', 0, (-0.546, -0.611, -1.293)],
['O', 3, (0.621, 1.066, 0.000)],
['CG', 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
'SER': [
['N', 0, (-0.529, 1.360, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, -0.000)],
['CB', 0, (-0.518, -0.777, -1.211)],
['O', 3, (0.626, 1.062, -0.000)],
['OG', 4, (0.503, 1.325, 0.000)],
],
'THR': [
['N', 0, (-0.517, 1.364, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.526, 0.000, -0.000)],
['CB', 0, (-0.516, -0.793, -1.215)],
['O', 3, (0.626, 1.062, 0.000)],
['CG2', 4, (0.550, -0.718, -1.228)],
['OG1', 4, (0.472, 1.353, 0.000)],
],
'TRP': [
['N', 0, (-0.521, 1.363, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.525, -0.000, 0.000)],
['CB', 0, (-0.523, -0.776, -1.212)],
['O', 3, (0.627, 1.062, 0.000)],
['CG', 4, (0.609, 1.370, -0.000)],
['CD1', 5, (0.824, 1.091, 0.000)],
['CD2', 5, (0.854, -1.148, -0.005)],
['CE2', 5, (2.186, -0.678, -0.007)],
['CE3', 5, (0.622, -2.530, -0.007)],
['NE1', 5, (2.140, 0.690, -0.004)],
['CH2', 5, (3.028, -2.890, -0.013)],
['CZ2', 5, (3.283, -1.543, -0.011)],
['CZ3', 5, (1.715, -3.389, -0.011)],
],
'TYR': [
['N', 0, (-0.522, 1.362, 0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.524, -0.000, -0.000)],
['CB', 0, (-0.522, -0.776, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG', 4, (0.607, 1.382, -0.000)],
['CD1', 5, (0.716, 1.195, -0.000)],
['CD2', 5, (0.713, -1.194, -0.001)],
['CE1', 5, (2.107, 1.200, -0.002)],
['CE2', 5, (2.104, -1.201, -0.003)],
['OH', 5, (4.168, -0.002, -0.005)],
['CZ', 5, (2.791, -0.001, -0.003)],
],
'VAL': [
['N', 0, (-0.494, 1.373, -0.000)],
['CA', 0, (0.000, 0.000, 0.000)],
['C', 0, (1.527, -0.000, -0.000)],
['CB', 0, (-0.533, -0.795, -1.213)],
['O', 3, (0.627, 1.062, -0.000)],
['CG1', 4, (0.540, 1.429, -0.000)],
['CG2', 4, (0.533, -0.776, 1.203)],
],
}

# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
'ALA': ['C', 'CA', 'CB', 'N', 'O'],
'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
'GLY': ['C', 'CA', 'N', 'O'],
'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
'CH2', 'N', 'NE1', 'O'],
'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
'OH'],
'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
}

# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
'C': 1.7,
'N': 1.55,
'O': 1.52,
'S': 1.8,
}

Bond = collections.namedtuple(
'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
BondAngle = collections.namedtuple(
'BondAngle',
['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])


@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
Mapping[str, List[Bond]],
Mapping[str, List[BondAngle]]]:
"""Load stereo_chemical_props.txt into a nice structure.

Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").

Returns:
residue_bonds: dict that maps resname --> list of Bond tuples
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
"""
stereo_chemical_props_path = (
'alphafold/common/stereo_chemical_props.txt')
with open(stereo_chemical_props_path, 'rt') as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split('-')
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(
Bond(atom1, atom2, float(length), float(stddev)))
residue_bonds['UNK'] = []

# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == '-':
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split('-')
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(atom1, atom2, atom3,
float(angle_degree) / 180. * np.pi,
float(stddev_degree) / 180. * np.pi))
residue_bond_angles['UNK'] = []

def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return '-'.join(sorted([atom1_name, atom2_name]))

# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(bond1.length**2 + bond2.length**2
- 2 * bond1.length * bond2.length * np.cos(gamma))

# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length *
np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length *
np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length *
np.cos(gamma)) * dl_outer
stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
(dl_db1 * bond1.stddev)**2 +
(dl_db2 * bond2.stddev)**2)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev))

return (residue_bonds,
residue_virtual_bonds,
residue_bond_angles)


# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
'CZ3', 'NZ', 'OXT'
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.

# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],

}

# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
'S', 'T', 'W', 'Y', 'V'
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.

restypes_with_x = restypes + ['X']
restype_order_with_x = {
restype: i for i,
restype in enumerate(restypes_with_x)}


def sequence_to_onehot(
sequence: str,
mapping: Mapping[str, int],
map_unknown_to_x: bool = False) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.

Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.

Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.

Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1

if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
'The mapping must have values from 0 to num_unique_aas-1 '
'without any gaps. Got: %s' %
sorted(
mapping.values()))

one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping['X'])
else:
raise ValueError(
f'Invalid character in the sequence: {aa_type}')
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1

return one_hot_arr
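
# Example (sketch), using the 21-letter mapping defined above:
#
#   sequence_to_onehot('ACDZ', restype_order_with_x, map_unknown_to_x=True)
#
# returns an int32 array of shape (4, 21); 'Z' is absent from the mapping and
# is folded onto 'X' because map_unknown_to_x is set.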


restype_1to3 = {
'A': 'ALA',
'R': 'ARG',
'N': 'ASN',
'D': 'ASP',
'C': 'CYS',
'Q': 'GLN',
'E': 'GLU',
'G': 'GLY',
'H': 'HIS',
'I': 'ILE',
'L': 'LEU',
'K': 'LYS',
'M': 'MET',
'F': 'PHE',
'P': 'PRO',
'S': 'SER',
'T': 'THR',
'W': 'TRP',
'Y': 'TYR',
'V': 'VAL',
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = 'UNK'

resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}


# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
'A': 0,
'B': 2,
'C': 1,
'D': 2,
'E': 3,
'F': 4,
'G': 5,
'H': 6,
'I': 7,
'J': 20,
'K': 8,
'L': 9,
'M': 10,
'N': 11,
'O': 20,
'P': 12,
'Q': 13,
'R': 14,
'S': 15,
'T': 16,
'U': 1,
'V': 17,
'W': 18,
'X': 20,
'Y': 19,
'Z': 3,
'-': 21,
}

# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: 'A',
1: 'C', # Also U.
2: 'D', # Also B.
3: 'E', # Also Z.
4: 'F',
5: 'G',
6: 'H',
7: 'I',
8: 'K',
9: 'L',
10: 'M',
11: 'N',
12: 'P',
13: 'Q',
14: 'R',
15: 'S',
16: 'T',
17: 'V',
18: 'W',
19: 'Y',
20: 'X', # Includes J and O.
21: '-',
}

restypes_with_x_and_gap = restypes + ['X', '-']
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap)))
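
# Sketch of the remapping: HHblits encodes 'C' as id 1 (see HHBLITS_AA_TO_ID),
# while the alphabetical ordering above puts 'C' at index 4, so
#
#   MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[HHBLITS_AA_TO_ID['C']]   # -> 4
#
# i.e. the tuple converts HHblits aatype ids into this module's aatype ids.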


def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()
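
# For instance (sketch), glycine has only its four backbone atoms, so
#
#   STANDARD_ATOM_MASK[restype_order['G']].sum()   # -> 4 (N, CA, C, O)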


# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []

for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices

for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)

one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])

return one_hot


chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)

# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [[], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 23], [5, 11, 23, 32]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 10]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [], [[0, 1, 3, 5], [1, 3, 5, 14]], [[0, 1, 3, 6], [1, 3, 6, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 19], [5, 11, 19, 35]], [[0, 1, 3, 5], [1, 3, 5, 18], [3, 5, 18, 19]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11]], [[0, 1, 3, 8]], [[0, 1, 3, 9]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 6]]]
chi_angles_atom_indices = np.array([
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices])

# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)


def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)

# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)

# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized,
eznorm, translation]).transpose()
m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
return m


# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)


def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[
resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype,
atomtype, :] = atom_position

atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype,
atom14idx, :] = atom_position

for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {name: np.array(pos) for name, _, pos
in rigid_group_atom_positions[resname]}

# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['N'] - atom_positions['CA'],
ey=np.array([1., 0., 0.]),
translation=atom_positions['N'])
restype_rigid_group_default_frame[restype, 2, :, :] = mat

# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions['C'] - atom_positions['CA'],
ey=atom_positions['CA'] - atom_positions['N'],
translation=atom_positions['C'])
restype_rigid_group_default_frame[restype, 3, :, :] = mat

# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name]
for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2])
restype_rigid_group_default_frame[restype, 4, :, :] = mat

# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1., 0., 0.]),
translation=axis_end_atom_position)
restype_rigid_group_default_frame[restype,
4 + chi_idx, :, :] = mat


_make_rigid_group_constants()
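
# After the call above the module-level arrays are populated. For example
# (sketch):
#
#   ala = restype_order['A']
#   restype_rigid_group_default_frame[ala, 3]   # 4x4 psi-group -> backbone frame
#   restype_atom37_mask[ala].sum()              # -> 5.0 (N, CA, C, CB, O)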

+ 1038
- 0
reproduce/AlphaFold2-Chinese/commons/utils.py
File diff suppressed because it is too large
View File


+ 382
- 0
reproduce/AlphaFold2-Chinese/config/config.py View File

@@ -0,0 +1,382 @@
"""Model config."""

import copy
import ml_collections


NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'


def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""

if name not in CONFIG_DIFFS:
raise ValueError(f'Invalid model name {name}.')
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
return cfg


CONFIG_DIFFS = {
'model_1': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_2': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True
},
'model_3': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
'data.common.max_extra_msa': 5120,
},
'model_4': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
'data.common.max_extra_msa': 5120,
},
'model_5': {
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
},

# The following models are fine-tuned from the corresponding models above
# with an additional predicted_aligned_error head that can produce
# predicted TM-score (pTM) and predicted aligned errors.
'model_1_ptm': {
'data.common.max_extra_msa': 5120,
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_2_ptm': {
'data.common.reduce_msa_clusters_by_max_templates': True,
'data.common.use_templates': True,
'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
'model.embeddings_and_evoformer.template.enabled': True,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_3_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_4_ptm': {
'data.common.max_extra_msa': 5120,
'model.heads.predicted_aligned_error.weight': 0.1
},
'model_5_ptm': {
'model.heads.predicted_aligned_error.weight': 0.1
}
}

CONFIG = ml_collections.ConfigDict({
'data': {
'common': {
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': 3,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_masks',
'template_domain_names'
],
'unsupervised_features': [
'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
'num_alignments', 'seq_length', 'between_segment_residues',
'deletion_matrix'
],
'use_templates': False,
},
'eval': {
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
'all_atom_positions': [NUM_RES, None, None],
'alt_chi_angles': [NUM_RES, None],
'atom14_alt_gt_exists': [NUM_RES, None],
'atom14_alt_gt_positions': [NUM_RES, None, None],
'atom14_atom_exists': [NUM_RES, None],
'atom14_atom_is_ambiguous': [NUM_RES, None],
'atom14_gt_exists': [NUM_RES, None],
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
'backbone_affine_tensor': [NUM_RES, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_row_mask': [NUM_EXTRA_SEQ],
'is_distillation': [],
'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
'msa_mask': [NUM_MSA_SEQ, NUM_RES],
'msa_row_mask': [NUM_MSA_SEQ],
'pseudo_beta': [NUM_RES, None],
'pseudo_beta_mask': [NUM_RES],
'random_crop_to_size_seed': [None],
'residue_index': [NUM_RES],
'residx_atom14_to_atom37': [NUM_RES, None],
'residx_atom37_to_atom14': [NUM_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
'rigidgroups_group_exists': [NUM_RES, None],
'rigidgroups_group_is_ambiguous': [NUM_RES, None],
'rigidgroups_gt_exists': [NUM_RES, None],
'rigidgroups_gt_frames': [NUM_RES, None, None],
'seq_length': [],
'seq_mask': [NUM_RES],
'target_feat': [NUM_RES, None],
'template_aatype': [NUM_TEMPLATES, NUM_RES],
'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions': [
NUM_TEMPLATES, NUM_RES, None, None],
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
'template_backbone_affine_tensor': [
NUM_TEMPLATES, NUM_RES, None],
'template_mask': [NUM_TEMPLATES],
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
'template_sum_probs': [NUM_TEMPLATES, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES]
},
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
},
},
'model': {
'embeddings_and_evoformer': {
'evoformer_num_block': 48,
'evoformer': {
'msa_row_attention_with_pair_bias': {
'dropout_rate': 0.15,
'gating': True,
'num_head': 8,
'orientation': 'per_row',
'shared_dropout': True
},
'msa_column_attention': {
'dropout_rate': 0.0,
'gating': True,
'num_head': 8,
'orientation': 'per_column',
'shared_dropout': True
},
'msa_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'outer_product_mean': {
'chunk_size': 128,
'dropout_rate': 0.0,
'num_outer_channel': 32,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 128,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 4,
'orientation': 'per_row',
'shared_dropout': True
}
},
'extra_msa_channel': 64,
'extra_msa_stack_num_block': 4,
'max_relative_feature': 32,
'msa_channel': 256,
'pair_channel': 128,
'prev_pos': {
'min_bin': 3.25,
'max_bin': 20.75,
'num_bins': 15
},
'recycle_features': True,
'recycle_pos': True,
'seq_channel': 384,
'template': {
'attention': {
'gating': False,
'key_dim': 64,
'num_head': 4,
'value_dim': 64
},
'dgram_features': {
'min_bin': 3.25,
'max_bin': 50.75,
'num_bins': 39
},
'embed_torsion_angles': False,
'enabled': False,
'template_pair_stack': {
'num_block': 2,
'triangle_attention_starting_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_row',
'shared_dropout': True,
'value_dim': 64
},
'triangle_attention_ending_node': {
'dropout_rate': 0.25,
'gating': True,
'key_dim': 64,
'num_head': 4,
'orientation': 'per_column',
'shared_dropout': True,
'value_dim': 64
},
'triangle_multiplication_outgoing': {
'dropout_rate': 0.25,
'equation': 'ikc,jkc->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
},
'triangle_multiplication_incoming': {
'dropout_rate': 0.25,
'equation': 'kjc,kic->ijc',
'num_intermediate_channel': 64,
'orientation': 'per_row',
'shared_dropout': True
},
'pair_transition': {
'dropout_rate': 0.0,
'num_intermediate_factor': 2,
'orientation': 'per_row',
'shared_dropout': True
}
},
'max_templates': 4,
'subbatch_size': 128,
'use_template_unit_vector': False,
}
},
'heads': {
'distogram': {
'first_break': 2.3125,
'last_break': 21.6875,
'num_bins': 64,
'weight': 0.3
},
'predicted_aligned_error': {
# `num_bins - 1` bins uniformly span the [0, max_error_bin A] range;
# the final bin covers [max_error_bin A, +infty).
# max_error_bin = 31 A gives bins of ~0.5 A width.
'max_error_bin': 31.,
'num_bins': 64,
'num_channels': 128,
'filter_by_resolution': True,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.0,
},
'experimentally_resolved': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'weight': 0.01
},
'structure_module': {
'num_layer': 8,
'fape': {
'clamp_distance': 10.0,
'clamp_type': 'relu',
'loss_unit_distance': 10.0
},
'angle_norm_weight': 0.01,
'chi_weight': 0.5,
'clash_overlap_tolerance': 1.5,
'compute_in_graph_metrics': True,
'dropout': 0.1,
'num_channel': 384,
'num_head': 12,
'num_layer_in_transition': 3,
'num_point_qk': 4,
'num_point_v': 8,
'num_scalar_qk': 16,
'num_scalar_v': 16,
'position_scale': 10.0,
'sidechain': {
'atom_clamp_distance': 10.0,
'num_channel': 128,
'num_residual_block': 2,
'weight_frac': 0.5,
'length_scale': 10.,
},
'structural_violation_loss_weight': 1.0,
'violation_tolerance_factor': 12.0,
'weight': 1.0
},
'predicted_lddt': {
'filter_by_resolution': True,
'max_resolution': 3.0,
'min_resolution': 0.1,
'num_bins': 50,
'num_channels': 128,
'weight': 0.01
},
'masked_msa': {
'num_output': 23,
'weight': 2.0
},
},
'num_recycle': 3,
'resample_msa_in_recycling': True
},
})
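
# A minimal usage sketch (assumes this module imports as config.config; the
# valid names are exactly the CONFIG_DIFFS keys above):
demo_cfg = model_config('model_1')                  # CASP14 model 1 preset
assert demo_cfg.data.common.max_extra_msa == 5120   # overridden by CONFIG_DIFFS
demo_cfg.model.heads.distogram.weight = 0.5         # ConfigDict allows in-place overrides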

+ 341  - 0  reproduce/AlphaFold2-Chinese/config/global_config.py

@@ -0,0 +1,341 @@
"""Model config."""

import copy
import ml_collections

def global_config(length: int) -> ml_collections.ConfigDict:
"""Get the global config."""
if str(length) not in GLOBAL_CONFIG:
raise ValueError(f'Invalid padding sequence length {length}.')
cfg = copy.deepcopy(GLOBAL_CONFIG[str(length)])
return cfg

GLOBAL_CONFIG = ml_collections.ConfigDict({
"256": {
'zero_init': True,
'seq_length': 256,
'extra_msa_length': 5120,
'template_embedding': {
'slice_num': 0,
},
'template_pair_stack': {
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
'extra_msa_stack': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 0,
},
'msa_column_global_attention': {
'slice_num': 0,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
'evoformer_iteration': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 0,
},
'msa_column_attention': {
'slice_num': 0,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
},
"512": {
'zero_init': True,
'seq_length': 512,
'extra_msa_length': 5120,
'template_embedding': {
'slice_num': 0,
},
'template_pair_stack': {
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
'extra_msa_stack': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 4,
},
'msa_column_global_attention': {
'slice_num': 0,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
'evoformer_iteration': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 0,
},
'msa_column_attention': {
'slice_num': 0,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 0,
},
'triangle_attention_ending_node': {
'slice_num': 0,
},
'pair_transition': {
'slice_num': 0,
},
},
},
"1024": {
'zero_init': True,
'seq_length': 1024,
'extra_msa_length': 5120,
'template_embedding': {
'slice_num': 4,
},
'template_pair_stack': {
'triangle_attention_starting_node': {
'slice_num': 4,
},
'triangle_attention_ending_node': {
'slice_num': 4,
},
'pair_transition': {
'slice_num': 0,
},
},
'extra_msa_stack': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 16,
},
'msa_column_global_attention': {
'slice_num': 4,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 4,
},
'triangle_attention_ending_node': {
'slice_num': 4,
},
'pair_transition': {
'slice_num': 0,
},
},
'evoformer_iteration': {
'msa_transition': {
'slice_num': 0,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 4,
},
'msa_column_attention': {
'slice_num': 4,
},
'outer_product_mean': {
'slice_num': 0,
},
'triangle_attention_starting_node': {
'slice_num': 4,
},
'triangle_attention_ending_node': {
'slice_num': 4,
},
'pair_transition': {
'slice_num': 0,
},
},
},
"2048": {
'zero_init': True,
'seq_length': 2048,
'extra_msa_length': 5120,
'template_embedding': {
'slice_num': 32,
},
'template_pair_stack': {
'triangle_attention_starting_node': {
'slice_num': 32,
},
'triangle_attention_ending_node': {
'slice_num': 32,
},
'pair_transition': {
'slice_num': 16,
},
},

'extra_msa_stack': {
'msa_transition': {
'slice_num': 16,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 128,
},
'msa_column_global_attention': {
'slice_num': 32,
},
'outer_product_mean': {
'slice_num': 16,
},
'triangle_attention_starting_node': {
'slice_num': 32,
},
'triangle_attention_ending_node': {
'slice_num': 32,
},
'pair_transition': {
'slice_num': 16,
},
},
'evoformer_iteration': {
'msa_transition': {
'slice_num': 16,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 32,
},
'msa_column_attention': {
'slice_num': 32,
},
'outer_product_mean': {
'slice_num': 16,
},
'triangle_attention_starting_node': {
'slice_num': 32,
},
'triangle_attention_ending_node': {
'slice_num': 32,
},
'pair_transition': {
'slice_num': 16,
},
},
},
"2304": {
'zero_init': True,
'seq_length': 2304,
'extra_msa_length': 5120,
'template_embedding': {
'slice_num': 64,
},
'template_pair_stack': {
'triangle_attention_starting_node': {
'slice_num': 64,
},
'triangle_attention_ending_node': {
'slice_num': 64,
},
'pair_transition': {
'slice_num': 2,
},
},

'extra_msa_stack': {
'msa_transition': {
'slice_num': 2,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 64,
},
'msa_column_global_attention': {
'slice_num': 64,
},
'outer_product_mean': {
'slice_num': 8,
},
'triangle_attention_starting_node': {
'slice_num': 64,
},
'triangle_attention_ending_node': {
'slice_num': 64,
},
'pair_transition': {
'slice_num': 4,
},
},
'evoformer_iteration': {
'msa_transition': {
'slice_num': 2,
},
'msa_row_attention_with_pair_bias': {
'slice_num': 64,
},
'msa_column_attention': {
'slice_num': 64,
},
'outer_product_mean': {
'slice_num': 8,
},
'triangle_attention_starting_node': {
'slice_num': 64,
},
'triangle_attention_ending_node': {
'slice_num': 64,
},
'pair_transition': {
'slice_num': 4,
},
},
},
})
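
# A minimal selection sketch: global_config only accepts the bucket keys
# above (256/512/1024/2048/2304); other lengths raise ValueError. The
# slice_num values appear to control how finely each sub-module's input is
# sliced to bound peak memory (0 = no slicing), with larger padding buckets
# slicing more aggressively.
demo_gcfg = global_config(512)
assert demo_gcfg.seq_length == 512 and demo_gcfg.extra_msa_length == 5120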

+ 517  - 0  reproduce/AlphaFold2-Chinese/data/feature/data_transforms.py

@@ -0,0 +1,517 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data transforms"""
import numpy as np

from commons import residue_constants

NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
MS_MIN32 = -2147483648
MS_MAX32 = 2147483647
_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa']


class SeedMaker:
"""Return unique seeds."""

def __init__(self, initial_seed=0):
self.next_seed = initial_seed

def __call__(self):
i = self.next_seed
self.next_seed += 1
return i


seed_maker = SeedMaker()


def one_hot(depth, indices):
res = np.eye(depth)[indices.reshape(-1)]
return res.reshape(list(indices.shape) + [depth])


def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32):
np.random.seed(seed_maker_t)
return np.random.uniform(size=size, low=low, high=high)


def curry1(f):
"""Supply all arguments except the first."""

def fc(*args, **kwargs):
return lambda x: f(x, *args, **kwargs)

return fc


@curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
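
# A small sketch of the pattern these two helpers enable (scale_msa is
# illustrative, not part of the pipeline): a @curry1 transform is called with
# everything except the protein dict, yielding a one-argument stage that
# compose chains together.
@curry1
def scale_msa(protein, factor):
    protein['msa'] = protein['msa'] * factor
    return protein

demo_pipeline = compose([scale_msa(2), scale_msa(3)])  # compose(fs) -> fn(protein)
demo_protein = demo_pipeline({'msa': np.ones((4, 8), np.int32)})
assert (demo_protein['msa'] == 6).all()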


@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a proportion of the MSA with 'X'."""
msa_mask = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_proportion
x_idx = 20
gap_idx = 21
msa_mask = np.logical_and(msa_mask, protein['msa'] != gap_idx)
protein['msa'] = np.where(msa_mask, np.ones_like(protein['msa']) * x_idx, protein['msa'])

aatype_mask = np.random.uniform(size=shape_list(protein['aatype']), low=0, high=1) < replace_proportion
protein['aatype'] = np.where(aatype_mask, np.ones_like(protein['aatype']) * x_idx, protein['aatype'])

return protein


@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`."""
num_seq = protein['msa'].shape[0]

shuffled = list(range(1, num_seq))
np.random.shuffle(shuffled)
shuffled.insert(0, 0)
index_order = np.array(shuffled, np.int32)
num_sel = min(max_seq, num_seq)

sel_seq = index_order[:num_sel]
not_sel_seq = index_order[num_sel:]
is_sel = num_seq - num_sel

for k in _MSA_FEATURE_NAMES:
if k in protein:
if keep_extra and not is_sel:
new_shape = list(protein[k].shape)
new_shape[0] = 1
protein['extra_' + k] = np.zeros(new_shape)
elif keep_extra and is_sel:
protein['extra_' + k] = protein[k][not_sel_seq]
if k == 'msa':
protein['extra_msa'] = protein['extra_msa'].astype(np.int32)
protein[k] = protein[k][sel_seq]

return protein


def shaped_categorical(probs):
ds = shape_list(probs)
num_classes = ds[-1]
probs = np.reshape(probs, (-1, num_classes))
nums = list(range(num_classes))
counts = []
for prob in probs:
counts.append(np.random.choice(nums, p=prob))
return np.reshape(np.array(counts, np.int32), ds[:-1])


@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
random_aa = np.array([0.05] * 20 + [0., 0.], dtype=np.float32)

categorical_probs = config.uniform_prob * random_aa + config.profile_prob * protein['hhblits_profile'] + \
config.same_prob * one_hot(22, protein['msa'])

pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
pad_shapes[-1][1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
categorical_probs = np.pad(categorical_probs, pad_shapes, constant_values=(mask_prob,))

mask_position = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_fraction

bert_msa = shaped_categorical(categorical_probs)
bert_msa = np.where(mask_position, bert_msa, protein['msa'])

protein['bert_mask'] = mask_position.astype(np.int32)
protein['true_msa'] = protein['msa']
protein['msa'] = bert_msa

return protein
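
# With the masked_msa defaults in config.py (profile_prob = same_prob =
# uniform_prob = 0.1), the appended mask-token probability is 1 - 0.3 = 0.7:
# a position selected by replace_fraction becomes the mask token 70% of the
# time and is otherwise resampled from the profile/identity/uniform mixture.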


@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0)

sample_one_hot = protein['msa_mask'][:, :, None] * one_hot(23, protein['msa'])
num_seq, num_res, _ = shape_list(sample_one_hot)

array_extra_msa_mask = protein['extra_msa_mask']
if array_extra_msa_mask.any():
extra_one_hot = protein['extra_msa_mask'][:, :, None] * one_hot(23, protein['extra_msa'])
extra_num_seq, _, _ = shape_list(extra_one_hot)

agreement = np.matmul(
np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T)
protein['extra_cluster_assignment'] = np.argmax(agreement, axis=1)
else:
protein['extra_cluster_assignment'] = np.array([])

return protein


@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = shape_list(protein['msa'])[0]

def csum(x):
result = []
for i in range(num_seq):
result.append(np.sum(x[np.where(protein['extra_cluster_assignment'] == i)], axis=0))
return np.array(result)

mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center

msa_sum = csum(mask[:, :, None] * one_hot(23, protein['extra_msa']))
msa_sum += one_hot(23, protein['msa']) # Original sequence
protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]

del msa_sum

del_sum = csum(mask * protein['extra_deletion_matrix'])
del_sum += protein['deletion_matrix'] # Original sequence
protein['cluster_deletion_mean'] = del_sum / mask_counts
del del_sum

return protein


@curry1
def crop_extra_msa(protein, max_extra_msa):
"""MSA features are cropped so only `max_extra_msa` sequences are kept."""
if protein['extra_msa'].any():
num_seq = protein['extra_msa'].shape[0]
num_sel = np.minimum(max_extra_msa, num_seq)
shuffled = list(range(num_seq))
np.random.shuffle(shuffled)
select_indices = shuffled[:num_sel]
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_' + k] = protein['extra_' + k][select_indices]

return protein


def delete_extra_msa(protein):
for k in _MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
del protein['extra_' + k]
return protein


@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
has_break = np.clip(protein['between_segment_residues'].astype(np.float32), np.array(0), np.array(1))
aatype_1hot = one_hot(21, protein['aatype'])

target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot]

msa_1hot = one_hot(23, protein['msa'])
has_deletion = np.clip(protein['deletion_matrix'], np.array(0), np.array(1))
deletion_value = np.arctan(protein['deletion_matrix'] / 3.) * (2. / np.pi)

msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)]

if 'cluster_profile' in protein:
deletion_mean_value = (np.arctan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
msa_feat.extend([protein['cluster_profile'], np.expand_dims(deletion_mean_value, axis=-1)])

if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = np.clip(protein['extra_deletion_matrix'], np.array(0), np.array(1))
protein['extra_deletion_value'] = np.arctan(protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)

protein['msa_feat'] = np.concatenate(msa_feat, axis=-1)
protein['target_feat'] = np.concatenate(target_feat, axis=-1)
return protein
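
# Channel bookkeeping: target_feat carries 1 + 21 = 22 channels (has_break
# plus the aatype one-hot), and msa_feat carries 23 + 1 + 1 + 23 + 1 = 49
# channels once cluster profiles are present, matching the standard
# AlphaFold feature sizes.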


@curry1
def select_feat(protein, feature_list):
return {k: v for k, v in protein.items() if k in feature_list}


@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
subsample_templates=False):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein['seq_length']
seq_length_int = int(seq_length)
if 'template_mask' in protein:
num_templates = np.array(shape_list(protein['template_mask'])[0], np.int32)
else:
num_templates = np.array(0, np.int32)
num_res_crop_size = np.minimum(seq_length, crop_size)
num_res_crop_size_int = int(num_res_crop_size)

if subsample_templates:
templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed_maker(), low=0, high=num_templates + 1))
else:
templates_crop_start = 0

num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates)
num_templates_crop_size_int = int(num_templates_crop_size)

num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed_maker(), low=0,
high=seq_length_int - num_res_crop_size_int + 1))

templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed_maker()))

for k, v in protein.items():
if k not in shape_schema or ('template' not in k and NUM_RES not in shape_schema[k]):
continue

if k.startswith('template') and subsample_templates:
v = v[templates_select_indices]

crop_sizes = []
crop_starts = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], shape_list(v))):
is_num_res = (dim_size == NUM_RES)
if i == 0 and k.startswith('template'):
crop_size = num_templates_crop_size_int
crop_start = templates_crop_start
else:
crop_start = num_res_crop_start if is_num_res else 0
crop_size = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim))
crop_sizes.append(crop_size)
crop_starts.append(crop_start)
if len(v.shape) == 1:
protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0]]
elif len(v.shape) == 2:
protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1]]
elif len(v.shape) == 3:
protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1],
crop_starts[2]:crop_starts[2] + crop_sizes[2]]
else:
protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1],
crop_starts[2]:crop_starts[2] + crop_sizes[2], crop_starts[3]:crop_starts[3] + crop_sizes[3]]

protein['seq_length'] = num_res_crop_size
return protein


@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
num_res, num_templates=0):
"""Guess at the MSA and sequence dimensions to make fixed size."""

pad_size_map = {
NUM_RES: num_res,
NUM_MSA_SEQ: msa_cluster_size,
NUM_EXTRA_SEQ: extra_msa_size,
NUM_TEMPLATES: num_templates,
}

for k, v in protein.items():
if k == 'extra_cluster_assignment':
continue
shape = list(v.shape)
schema = shape_schema[k]
assert len(shape) == len(schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}'
pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)]
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
if padding:
protein[k] = np.pad(v, padding)
protein[k] = protein[k].reshape(pad_size)

return protein


@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
if k.startswith('template_'):
protein[k] = v[:max_templates]
return protein


def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = np.array(new_order_list, dtype=protein['msa'].dtype)
protein['msa'] = new_order[protein['msa']]

return protein


@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = np.array(float(distillation), dtype=np.float32)
return protein


def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein['aatype'] = np.argmax(protein['aatype'], axis=-1)
for k in ['msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix',
'resolution', 'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
if k in protein:
final_dim = shape_list(protein[k])[-1]
if isinstance(final_dim, int) and final_dim == 1:
protein[k] = np.squeeze(protein[k], axis=-1)

for k in ['seq_length', 'num_alignments']:
if k in protein:
protein[k] = protein[k][0] # Remove fake sequence dimension
return protein


def cast_64bit_ints(protein):
for k, v in protein.items():
if v.dtype == np.int64:
protein[k] = v.astype(np.int32)
return protein


def make_seq_mask(protein):
protein['seq_mask'] = np.ones(shape_list(protein['aatype']), dtype=np.float32)
return protein


def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein['msa_mask'] = np.ones(shape_list(protein['msa']), dtype=np.float32)
protein['msa_row_mask'] = np.ones(shape_list(protein['msa'])[0], dtype=np.float32)
return protein


def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if 'hhblits_profile' in protein:
return protein

protein['hhblits_profile'] = np.mean(one_hot(22, protein['msa']), axis=0)
return protein


def make_random_crop_to_size_seed(protein):
"""Random seed for cropping residues and templates."""
protein['random_crop_to_size_seed'] = np.array(make_random_seed([2], seed_maker_t=seed_maker()), np.int32)
return protein


def fix_templates_aatype(protein):
"""Fixes aatype encoding of templates."""
protein['template_aatype'] = np.argmax(protein['template_aatype'], axis=-1).astype(np.int32)
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = np.array(new_order_list, np.int32)
protein['template_aatype'] = new_order[protein['template_aatype']]
return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
"""Create pseudo beta features."""
is_gly = np.equal(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
pseudo_beta = np.where(
np.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
pseudo_beta_mask = pseudo_beta_mask.astype(np.float32)
return pseudo_beta, pseudo_beta_mask
return pseudo_beta


@curry1
def make_pseudo_beta(protein, prefix=''):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'all_atom_aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
return protein


def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []

for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]]

restype_atom14_to_atom37.append([(residue_constants.atom_order[name] if name else 0) for name in atom_names])

atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types])

restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])

restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)

restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, np.float32)

residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']]
residx_atom14_mask = restype_atom14_mask[protein['aatype']]

protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37

residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']]
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14

restype_atom37_mask = np.zeros([21, 37], np.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1

residx_atom37_mask = restype_atom37_mask[protein['aatype']]
protein['atom37_atom_exists'] = residx_atom37_mask

return protein


def shape_list(x):
    """Return the list of dimensions of an array."""
    return list(np.array(x).shape)

+ 294  - 0  reproduce/AlphaFold2-Chinese/data/feature/feature_extraction.py

@@ -0,0 +1,294 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""feature extraction"""
import copy

import numpy as np

from commons import residue_constants
from data.feature import data_transforms

NUM_RES = "num residues placeholder"
NUM_SEQ = "length msa placeholder"
NUM_TEMPLATES = "num templates placeholder"

FEATURES = {
"aatype": (np.float32, [NUM_RES, 21]),
"between_segment_residues": (np.int64, [NUM_RES, 1]),
"deletion_matrix": (np.float32, [NUM_SEQ, NUM_RES, 1]),
"msa": (np.int64, [NUM_SEQ, NUM_RES, 1]),
"num_alignments": (np.int64, [NUM_RES, 1]),
"residue_index": (np.int64, [NUM_RES, 1]),
"seq_length": (np.int64, [NUM_RES, 1]),
"all_atom_positions": (np.float32, [NUM_RES, residue_constants.atom_type_num, 3]),
"all_atom_mask": (np.int64, [NUM_RES, residue_constants.atom_type_num]),
"resolution": (np.float32, [1]),
"template_domain_names": (str, [NUM_TEMPLATES]),
"template_sum_probs": (np.float32, [NUM_TEMPLATES, 1]),
"template_aatype": (np.float32, [NUM_TEMPLATES, NUM_RES, 22]),
"template_all_atom_positions": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3]),
"template_all_atom_masks": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1]),
}


def nonensembled_map_fns(data_config):
"""Input pipeline functions which are not ensembled."""
common_cfg = data_config.common

map_fns = [
data_transforms.correct_msa_restypes,
data_transforms.add_distillation_flag(False),
data_transforms.cast_64bit_ints,
data_transforms.squeeze_features,
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
data_transforms.make_msa_mask,
data_transforms.make_hhblits_profile,
data_transforms.make_random_crop_to_size_seed,
]
if common_cfg.use_templates:
map_fns.extend([data_transforms.fix_templates_aatype, data_transforms.make_pseudo_beta('template_')])
map_fns.extend([data_transforms.make_atom14_masks,])

return map_fns


def ensembled_map_fns(data_config):
"""Input pipeline functions that can be ensembled and averaged."""
common_cfg = data_config.common
eval_cfg = data_config.eval

map_fns = []

if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
else:
pad_msa_clusters = eval_cfg.max_msa_clusters

max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa

map_fns.append(data_transforms.sample_msa(max_msa_clusters, keep_extra=True))

if 'masked_msa' in common_cfg:
map_fns.append(data_transforms.make_masked_msa(common_cfg.masked_msa, eval_cfg.masked_msa_replace_fraction))

if common_cfg.msa_cluster_features:
map_fns.append(data_transforms.nearest_neighbor_clusters())
map_fns.append(data_transforms.summarize_clusters())

if max_extra_msa:
map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
else:
map_fns.append(data_transforms.delete_extra_msa)

map_fns.append(data_transforms.make_msa_feat())

crop_feats = dict(eval_cfg.feat)

if eval_cfg.fixed_size:
map_fns.append(data_transforms.select_feat(list(crop_feats)))
map_fns.append(data_transforms.random_crop_to_size(
eval_cfg.crop_size,
eval_cfg.max_templates,
crop_feats,
eval_cfg.subsample_templates))
map_fns.append(data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
eval_cfg.crop_size,
eval_cfg.max_templates))
else:
map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))

return map_fns


def process_arrays_from_config(arrays, data_config):
"""Apply filters and maps to an existing dataset, based on the config."""

def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_map_fns(data_config)
fn = data_transforms.compose(fns)
d['ensemble_index'] = i
return fn(d)

eval_cfg = data_config.eval
arrays = data_transforms.compose(nonensembled_map_fns(data_config))(arrays)
arrays_0 = wrap_ensemble_fn(arrays, np.array(0, np.int32))
num_ensemble = eval_cfg.num_ensemble
if data_config.common.resample_msa_in_recycling:
num_ensemble *= data_config.common.num_recycle + 1

result_array = {x: () for x in arrays_0.keys()}
if num_ensemble > 1:
for i in range(num_ensemble):
arrays_t = wrap_ensemble_fn(arrays, np.array(i, np.int32))
for key in arrays_0.keys():
result_array[key] += (arrays_t[key][None],)
for key in arrays_0.keys():
result_array[key] = np.concatenate(result_array[key], axis=0)
else:
result_array = {key: arrays_0[key][None] for key in arrays_0.keys()}
return result_array
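
# With the defaults above (num_ensemble = 1, resample_msa_in_recycling = True,
# num_recycle = 3) the effective ensemble count is 1 * (3 + 1) = 4, so every
# returned array gains a leading ensemble dimension of size 4.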


def feature_shape(feature_name,
num_residues,
msa_length,
num_templates,
features=None):
"""Get the shape for the given feature name."""
features = features or FEATURES
if feature_name.endswith("_unnormalized"):
feature_name = feature_name[:-13]

unused_dtype, raw_sizes = features[feature_name]
replacements = {NUM_RES: num_residues,
NUM_SEQ: msa_length}

if num_templates is not None:
replacements[NUM_TEMPLATES] = num_templates

sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
for dimension in sizes:
if isinstance(dimension, str):
raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
feature_name, raw_sizes, replacements))
size_r = [int(x) for x in sizes]
return size_r
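
# Shape-resolution sketch (illustrative values): the 'msa' entry of FEATURES
# above has schema (np.int64, [NUM_SEQ, NUM_RES, 1]).
demo_shape = feature_shape('msa', num_residues=120, msa_length=512, num_templates=4)
assert demo_shape == [512, 120, 1]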


def parse_reshape_logic(parsed_features, features, num_template, key=None):
"""Transforms parsed serial features to the correct shape."""
num_residues = np.reshape(parsed_features['seq_length'].astype(np.int32), (-1,))[0]

if "num_alignments" in parsed_features:
num_msa = np.reshape(parsed_features["num_alignments"].astype(np.int32), (-1,))[0]
else:
num_msa = 0

if key is not None and "key" in features:
parsed_features["key"] = [key] # Expand dims from () to (1,).

for k, v in parsed_features.items():
new_shape = feature_shape(
feature_name=k,
num_residues=num_residues,
msa_length=num_msa,
num_templates=num_template,
features=features)
new_shape_size = 1
for dim in new_shape:
new_shape_size *= dim

if np.size(v) != new_shape_size:
raise ValueError("the size of feature {} ({}) could not be reshaped into {}"
"".format(k, np.size(v), new_shape))

if "template" not in k:
if np.size(v) <= 0:
raise ValueError("The feature {} is not empty.".format(k))
parsed_features[k] = np.reshape(v, new_shape)

return parsed_features


def _make_features_metadata(feature_names):
"""Makes a feature name to type and shape mapping from a list of names."""
required_features = ["sequence", "domain_name", "template_domain_names"]
feature_names = list(set(feature_names) - set(required_features))

features_metadata = {name: FEATURES[name] for name in feature_names}
return features_metadata


def np_to_array_dict(np_example, features):
"""Creates dict of arrays."""
features_metadata = _make_features_metadata(features)
array_dict = {k: v for k, v in np_example.items() if k in features_metadata}
if "template_domain_names" in np_example:
num_template = len(np_example["template_domain_names"])
else:
num_template = 0

array_dict = parse_reshape_logic(array_dict, features_metadata, num_template)
array_dict['template_mask'] = np.ones([num_template], np.float32)
return array_dict


def make_data_config(config, num_res):
"""Makes a data config for the input pipeline."""
cfg = copy.deepcopy(config.data)

feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features

with cfg.unlocked():
cfg.eval.crop_size = num_res

return cfg, feature_names


def custom_padding(config, arrays, dims):
"""Pad array to fixed size."""
step_size = config.seq_length

res_length = arrays[0].shape[dims[0]]
padding_size = step_size - res_length
for i, arr in enumerate(arrays):
if dims[i] == -1:
continue
extra_array_shape = list(arr.shape)
extra_array_shape[dims[i]] = padding_size
extra_array = np.zeros(extra_array_shape, dtype=arr.dtype)
arrays[i] = np.concatenate((arr, extra_array), axis=dims[i])
return arrays


def process_features(raw_features, config, global_config):
"""Preprocesses NumPy feature dict using pipeline."""
num_res = int(raw_features['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)

if 'deletion_matrix_int' in raw_features:
raw_features['deletion_matrix'] = (raw_features.pop('deletion_matrix_int').astype(np.float32))

array_dict = np_to_array_dict(np_example=raw_features, features=feature_names)

features = process_arrays_from_config(array_dict, cfg)
features = {k: v for k, v in features.items() if v.dtype != 'O'}

extra_msa_length = global_config.extra_msa_length
ori_res_length = features["target_feat"].shape[1]
aatype = features["aatype"]
residue_index = features["residue_index"]
for key in ["extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask"]:
features[key] = features[key][:, :extra_msa_length]
input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype',
'template_all_atom_masks', 'template_all_atom_positions', 'template_mask',
'template_pseudo_beta_mask', 'template_pseudo_beta', 'template_sum_probs',
'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask',
'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index']
arrays = [features[key] for key in input_keys]
dims = [1, 2, 2, 1, 1, 2, 2, 2, -1, 2, 2, -1, 2, 2, 2, 2, 1, 1, 1]
arrays = custom_padding(global_config, arrays, dims)
arrays = [array.astype(np.float16) if array.dtype == "float64" else array for array in arrays]
arrays = [array.astype(np.float16) if array.dtype == "float32" else array for array in arrays]
return arrays, aatype, residue_index, ori_res_length
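
# An end-to-end sketch of how this module is driven; raw_features is assumed
# to come from DataPipeline.process in data/tools/data_process.py:
from config.config import model_config
from config.global_config import global_config

raw_features = ...  # dict produced by DataPipeline.process
cfg = model_config('model_1')
gcfg = global_config(256)  # pad to the 256-residue bucket
arrays, aatype, residue_index, ori_len = process_features(raw_features, cfg, gcfg)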

+ 205  - 0  reproduce/AlphaFold2-Chinese/data/tools/data_process.py

@@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''data process'''
import os
import hashlib
import re
import numpy as np
from commons import residue_constants
from data.tools.parsers import parse_fasta, parse_hhr, parse_a3m
from data.tools.templates import TemplateHitFeaturizer
from data.tools.data_tools import HHSearch

def get_hash(x):
return hashlib.sha1(x.encode()).hexdigest()

def run_mmseqs2(x, path, use_env=False):
'''run mmseqs2'''
a3m_files = [f"{path}/uniref.a3m"]
if use_env: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

# gather a3m lines
a3m_lines = {}
for a3m_file in a3m_files:
update_m, m = True, None
for line in open(a3m_file, "r"):
if line:
if "\x00" in line:
line = line.replace("\x00", "")
update_m = True
if line.startswith(">") and update_m:
m = int(line[2:6].rstrip())
update_m = False
if m not in a3m_lines: a3m_lines[m] = []
a3m_lines[m].append(line)

# return results
a3m_lines = ["".join(a3m_lines[key]) for key in a3m_lines]

if isinstance(x, str):
return a3m_lines[0]
return a3m_lines
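
# Note that despite its name, run_mmseqs2 launches no search of its own: it
# only collates the uniref.a3m (and, with use_env, the
# bfd.mgnify30.metaeuk30.smag30.a3m) files that msa_search.sh has already
# written under `path`, grouping lines by the integer query index in each
# ">" header.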


def make_sequence_features(
sequence: str, description: str, num_res: int):
"""Constructs a feature dict of sequence features."""
features = {'aatype': residue_constants.sequence_to_onehot(sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True),
'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
'domain_name': np.array([description.encode('utf-8')], dtype=np.object_),
'residue_index': np.array(range(num_res), dtype=np.int32),
'seq_length': np.array([num_res] * num_res, dtype=np.int32),
'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_)}
return features


def make_msa_features(
msas,
deletion_matrices):
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')

int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])

num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32),
'msa': np.array(int_msa, dtype=np.int32),
'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32)}
return features


class DataPipeline:
"""Runs the alignment tools and assembles the input features."""

def __init__(self,
hhsearch_binary_path: str,
pdb70_database_path: str,
template_featurizer: TemplateHitFeaturizer,
result_path,
use_env=False):
"""Constructs a feature dict for a given FASTA file."""

self.hhsearch_pdb70_runner = HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path])
self.template_featurizer = template_featurizer
self.result_path = result_path
self.use_env = use_env

def process(self, input_fasta_path):
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]

num_res = len(input_sequence)

# mmseq2
sequence = input_sequence
sequence = re.sub("[^A-Z:/]", "", sequence.upper())
sequence = re.sub(":+", ":", sequence)
sequence = re.sub("/+", "/", sequence)
sequence = re.sub("^[:/]+", "", sequence)
sequence = re.sub("[:/]+$", "", sequence)
ori_sequence = sequence
seqs = ori_sequence.replace("/", "").split(":")

a3m_lines = run_mmseqs2(seqs, path=self.result_path, use_env=self.use_env)

hhsearch_result = self.hhsearch_pdb70_runner.query(a3m_lines[0])
hhsearch_hits = parse_hhr(hhsearch_result)

msas, deletion_matrices = parse_a3m(a3m_lines[0])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hhr_hits=hhsearch_hits)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res)
msa_features = make_msa_features(
msas=(msas,),
deletion_matrices=(deletion_matrices,))
return {**sequence_features, **msa_features, **templates_result.features}


def data_process(seq_name, args):
"""data_process"""

fasta_path = os.path.join(args.input_fasta_path, seq_name + '.fasta')
result_path = os.path.join(args.msa_result_path, "result_" + str(seq_name))
if args.database_envdb_dir:
use_env = True
command = "sh ./data/tools/msa_search.sh mmseqs " + fasta_path + " " + result_path + " " + \
args.database_dir + " " + "\"\"" + " " + args.database_envdb_dir + " \"1\" \"0\" \"1\""
else:
use_env = False
command = "sh ./data/tools/msa_search.sh mmseqs " + fasta_path + " " + result_path + " " + \
args.database_dir + " " + "\"\"" + " \"\"" + " \"0\" \"0\" \"1\""
print('start mmseqs2 MSA')
print('command: ', command)
os.system(command)
print('mmseqs2 MSA successful')
print('use_env: ', use_env)
hhsearch_binary_path = args.hhsearch_binary_path

pdb70_database_path = args.pdb70_database_path
template_mmcif_dir = args.template_mmcif_dir
max_template_date = args.max_template_date
kalign_binary_path = args.kalign_binary_path
obsolete_pdbs_path = args.obsolete_pdbs_path

template_featurizer = TemplateHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=20,
kalign_binary_path=kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=obsolete_pdbs_path)

data_pipeline = DataPipeline(
hhsearch_binary_path=hhsearch_binary_path,
pdb70_database_path=pdb70_database_path,
template_featurizer=template_featurizer,
result_path=result_path,
use_env=use_env)

feature_dict = data_pipeline.process(fasta_path)
return feature_dict
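
# A hedged usage sketch; every path below is a placeholder, and the field
# names mirror the args consumed by data_process above:
from types import SimpleNamespace

demo_args = SimpleNamespace(
    input_fasta_path='./fasta',
    msa_result_path='./msa',
    database_dir='/data/uniref30_db',
    database_envdb_dir='',              # empty -> use_env=False branch
    hhsearch_binary_path='hhsearch',
    pdb70_database_path='/data/pdb70/pdb70',
    template_mmcif_dir='/data/pdb_mmcif',
    max_template_date='2021-07-14',
    kalign_binary_path='kalign',
    obsolete_pdbs_path=None)
features = data_process('demo_seq', demo_args)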

+ 428  - 0  reproduce/AlphaFold2-Chinese/data/tools/data_tools.py

@@ -0,0 +1,428 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''data tools'''
import glob
import os
import subprocess
import contextlib
import shutil
import tempfile
import time

from typing import Any, Mapping, Optional, Sequence

from absl import logging

_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500


def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(u'>' + name + u'\n')
a3m.append(sequence + u'\n')
return ''.join(a3m)


class Kalign:
"""Python wrapper of the Kalign binary."""

def __init__(self, *, binary_path: str):
"""Initializes the Python Kalign wrapper.

Args:
binary_path: The path to the Kalign binary.
"""
self.binary_path = binary_path

def align(self, sequences: Sequence[str]) -> str:
"""Aligns the sequences and returns the alignment in A3M string.

Args:
sequences: A list of query sequence strings. The sequences have to be at
least 6 residues long (Kalign requires this). Note that the order in
which you give the sequences might alter the output slightly, as a
different alignment tree might be constructed.

Returns:
A string with the alignment in a3m format.

Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info('Aligning %d sequences', len(sequences))

for s in sequences:
if len(s) < 6:
raise ValueError('Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))

with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

with open(input_fasta_path, 'w') as f:
f.write(_to_a3m(sequences))

cmd = [self.binary_path, '-i', input_fasta_path, '-o', output_a3m_path, '-format', 'fasta',]

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

with timing('Kalign query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', stdout.decode('utf-8'), stderr.decode('utf-8'))

if retcode:
raise RuntimeError(
'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (stdout.decode('utf-8'), stderr.decode('utf-8')))

with open(output_a3m_path) as f:
a3m = f.read()

return a3m


@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)


class HHBlits:
"""Python wrapper of the HHblits binary."""

def __init__(self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 4,
n_iter: int = 3,
e_value: float = 0.001,
maxseq: int = 1_000_000,
realign_max: int = 100_000,
maxfilt: int = 100_000,
min_prefilter_hits: int = 1000,
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z):
"""Initializes the Python HHblits wrapper.

Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases

for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHBlits database %s', database_path)
raise ValueError(f'Could not find HHBlits database {database_path}')

self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.maxseq = maxseq
self.realign_max = realign_max
self.maxfilt = maxfilt
self.min_prefilter_hits = min_prefilter_hits
self.all_seqs = all_seqs
self.alt = alt
self.p = p
self.z = z

def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using HHblits."""
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-cpu', str(self.n_cpu),
'-oa3m', a3m_path,
'-o', '/dev/null',
'-n', str(self.n_iter),
'-e', str(self.e_value),
'-maxseq', str(self.maxseq),
'-realign_max', str(self.realign_max),
'-maxfilt', str(self.maxfilt),
'-min_prefilter_hits', str(self.min_prefilter_hits)]
if self.all_seqs:
cmd += ['-all']
if self.alt:
cmd += ['-alt', str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ['-p', str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ['-Z', str(self.z)]
cmd += db_cmd

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

with timing('HHblits query'):
stdout, stderr = process.communicate()
retcode = process.wait()

if retcode:
# Logs have a 15k character limit, so log HHblits error line by
# line.
logging.error('HHblits failed. HHblits stderr begin:')
for error_line in stderr.decode('utf-8').splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error('HHblits stderr end')
raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))

with open(a3m_path) as f:
a3m = f.read()

raw_output = dict(
a3m=a3m,
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
return raw_output


class HHSearch:
"""Python wrapper of the HHsearch binary."""

def __init__(self,
*,
binary_path: str,
databases: Sequence[str],
maxseq: int = 1_000_000):
"""Initializes the Python HHsearch wrapper.

Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.

Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
self.maxseq = maxseq

for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error(
'Could not find HHsearch database %s',
database_path)
raise ValueError(
f'Could not find HHsearch database {database_path}')

def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
f.write(a3m)

db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq),
'-cpu', '8',] + db_cmd

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with timing('HHsearch query'):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError('HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
with open(hhr_path) as f:
hhr = f.read()
return hhr


class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""

def __init__(self,
*,
binary_path: str,
database_path: str,
n_cpu: int = 8,
n_iter: int = 1,
e_value: float = 0.0001,
z_value: Optional[int] = None,
get_tblout: bool = False,
filter_f1: float = 0.0005,
filter_f2: float = 0.00005,
filter_f3: float = 0.0000005,
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None):
"""Initializes the Python Jackhmmer wrapper.

Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value, see Jackhmmer docs for more details.
get_tblout: Whether to save tblout string.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
round.
dom_e: Domain e-value criteria for inclusion in tblout.
"""
self.binary_path = binary_path
self.database_path = database_path

if not os.path.exists(self.database_path):
logging.error(
'Could not find Jackhmmer database %s',
database_path)
raise ValueError(
f'Could not find Jackhmmer database {database_path}')

self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.z_value = z_value
self.filter_f1 = filter_f1
self.filter_f2 = filter_f2
self.filter_f3 = filter_f3
self.incdom_e = incdom_e
self.dom_e = dom_e
self.get_tblout = get_tblout

def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using Jackhmmer."""
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')

# F1/F2/F3 are the expected proportions of sequences that pass each of
# the filtering stages (which get progressively more expensive); reducing
# these speeds up the pipeline at the expense of sensitivity. They are
# currently set very low to make querying MGnify run in a reasonable
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
'-o', '/dev/null',
'-A', sto_path,
'--noali',
'--F1', str(self.filter_f1),
'--F2', str(self.filter_f2),
'--F3', str(self.filter_f3),
'--incE', str(self.e_value),
# Report only sequences with E-values <= x in per-sequence
# output.
'-E', str(self.e_value),
'--cpu', str(self.n_cpu),
'-N', str(self.n_iter)
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
cmd_flags.extend(['--tblout', tblout_path])

if self.z_value:
cmd_flags.extend(['-Z', str(self.z_value)])

if self.dom_e is not None:
cmd_flags.extend(['--domE', str(self.dom_e)])

if self.incdom_e is not None:
cmd_flags.extend(['--incdomE', str(self.incdom_e)])

cmd = [self.binary_path] + cmd_flags + [input_fasta_path, self.database_path]

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with timing(f'Jackhmmer ({os.path.basename(self.database_path)}) query'):
_, stderr = process.communicate()
retcode = process.wait()

if retcode:
raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))

# Get e-values for each target name
tbl = ''
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()

with open(sto_path) as f:
sto = f.read()

raw_output = dict(sto=sto, tbl=tbl, stderr=stderr, n_iter=self.n_iter, e_value=self.e_value)
return raw_output
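
A minimal sketch of how these three wrappers might be chained for MSA and template search, assuming hypothetical binary and database paths (the repo's actual entry point is data/tools/data_process.py). HHsearch consumes an a3m alignment, so Jackhmmer's Stockholm output is converted first with convert_stockholm_to_a3m from parsers.py below.

# Hypothetical usage sketch; every path here is a placeholder, not a shipped default.
from data.tools.parsers import convert_stockholm_to_a3m

jackhmmer_runner = Jackhmmer(
    binary_path='/usr/bin/jackhmmer',
    database_path='/data/uniref90/uniref90.fasta')
hhsearch_runner = HHSearch(
    binary_path='/usr/bin/hhsearch',
    databases=['/data/pdb70/pdb70'])

jackhmmer_result = jackhmmer_runner.query('/tmp/query.fasta')
a3m = convert_stockholm_to_a3m(jackhmmer_result['sto'])
hhr = hhsearch_runner.query(a3m)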

+393 -0 reproduce/AlphaFold2-Chinese/data/tools/mmcif_parsing.py

@@ -0,0 +1,393 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parses the mmCIF file format."""
import collections
import io
import dataclasses
from typing import Any, Mapping, Optional, Sequence, Tuple

from absl import logging
from Bio import PDB
from Bio.Data import SCOPData


# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PDBSTRUCTURE = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]


@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int


# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int


# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str


@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str


@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.

Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PDBSTRUCTURE
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any


@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.

Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]


class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""


def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.

Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.

Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)

assert all([len(xs) == len(data[0]) for xs in data]), (
'mmCIF error: Not all loops are the same length: %s' % cols)

return [dict(zip(cols, xs)) for xs in zip(*data)]


def mmcif_loop_to_dict(prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.

Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.

Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}
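
# Illustrative example (comments only): given a parsed_info fragment such as
#   {'_entity_poly_seq.entity_id': ['1', '1'],
#    '_entity_poly_seq.num':       ['1', '2'],
#    '_entity_poly_seq.mon_id':    ['MET', 'ALA']}
# mmcif_loop_to_list('_entity_poly_seq.', parsed_info) yields one dict per row:
#   [{'_entity_poly_seq.entity_id': '1', '_entity_poly_seq.num': '1',
#     '_entity_poly_seq.mon_id': 'MET'}, ...]
# while mmcif_loop_to_dict('_entity_poly_seq.', '_entity_poly_seq.num', parsed_info)
# keys those rows by residue number: {'1': {...}, '2': {...}}.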


def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
"""Entry point, parses an mmcif_string.

Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.

Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access

# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]

header = _get_header(parsed_info)

# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()}

# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to the chain ids
# used by the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
# We only process the first model at the moment.
continue

mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
else:
hetflag = 'H_' + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(
chain_id=atom.author_chain_id, residue_number=int(
atom.author_seq_num), insertion_code=insertion_code)
seq_idx = int(atom.mmcif_seq_num) - \
seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(
atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag)
seq_to_structure_mappings[atom.author_chain_id] = current

# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')

author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
author_chain_to_sequence[author_chain] = seq

mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info)

return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)


def _get_first_model(structure: PDBSTRUCTURE) -> PDBSTRUCTURE:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())


_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21


def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
return min(revision_dates)


def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}

experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join([
experiment['_exptl.method'].lower() for experiment in experiments])

# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])

header['resolution'] = 0.00
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', '_reflns.d_resolution_high'):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
except ValueError:
logging.warning(
'Invalid resolution format: %s',
parsed_info[res_key])

return header


def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)]


def _get_protein_chains(*,
parsed_info: Mapping[str,
Any]) -> Mapping[ChainId,
Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.

Args:
parsed_info: _mmcif_dict produced by the Biopython parser.

Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)

polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num'])))

# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict(
'_chem_comp.', '_chem_comp.id', parsed_info)

# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)

entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
entity_to_mmcif_chains[entity_id].append(chain_id)

# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]

# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
for monomer in seq_info]):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains


def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
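
A minimal usage sketch for this parser, assuming a local mmCIF file; file_id is an arbitrary caller-chosen label and '1abc.cif' is a placeholder path.

with open('1abc.cif') as f:
    result = parse(file_id='1abc', mmcif_string=f.read())
if result.mmcif_object is None:
    print(result.errors)
else:
    for chain_id, seqres in result.mmcif_object.chain_to_seqres.items():
        print(chain_id, seqres[:60])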

+61 -0 reproduce/AlphaFold2-Chinese/data/tools/msa_search.sh

@@ -0,0 +1,61 @@
#!/bin/bash -e
MMSEQS="$1" #mmseqs
QUERY="$2" #"/path/QUERY.fasta"
BASE="$3" #"./result/"
DB1="$4" #"uniref30_2103_db"
DB2="$5" #""
DB3="$6" #"colabfold_envdb_202108_db"
USE_ENV="$7" #1
USE_TEMPLATES="$8" #0
FILTER="${9}" #1
EXPAND_EVAL=inf
ALIGN_EVAL=10
DIFF=3000
QSC=-20.0
MAX_ACCEPT=1000000
time=$(date)
echo "${time}"
if [ "${FILTER}" = "1" ]; then
# 0.1 was not used in benchmarks due to POSIX shell bug in line above
# EXPAND_EVAL=0.1
ALIGN_EVAL=10
QSC=0.8
MAX_ACCEPT=100000
fi
export MMSEQS_CALL_DEPTH=1
SEARCH_PARAM="--num-iterations 3 --db-load-mode 2 -a -s 8 -e 0.1 --max-seqs 10000"
FILTER_PARAM="--filter-msa ${FILTER} --filter-min-enable 1000 --diff ${DIFF} --qid 0.0,0.2,0.4,0.6,0.8,1.0 --qsc 0 --max-seq-id 0.95"
EXPAND_PARAM="--expansion-mode 0 -e ${EXPAND_EVAL} --expand-filter-clusters ${FILTER} --max-seq-id 0.95"
mkdir -p "${BASE}"
"${MMSEQS}" createdb "${QUERY}" "${BASE}/qdb"
"${MMSEQS}" search "${BASE}/qdb" "${DB1}" "${BASE}/res" "${BASE}/tmp" $SEARCH_PARAM
"${MMSEQS}" expandaln "${BASE}/qdb" "${DB1}.idx" "${BASE}/res" "${DB1}.idx" "${BASE}/res_exp" --db-load-mode 2 ${EXPAND_PARAM}

"${MMSEQS}" mvdb "${BASE}/tmp/latest/profile_1" "${BASE}/prof_res"
"${MMSEQS}" lndb "${BASE}/qdb_h" "${BASE}/prof_res_h"
"${MMSEQS}" align "${BASE}/prof_res" "${DB1}.idx" "${BASE}/res_exp" "${BASE}/res_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a
"${MMSEQS}" filterresult "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign" "${BASE}/res_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100
"${MMSEQS}" result2msa "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign_filter" "${BASE}/uniref.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM}
"${MMSEQS}" rmdb "${BASE}/res_exp_realign"
"${MMSEQS}" rmdb "${BASE}/res_exp"
"${MMSEQS}" rmdb "${BASE}/res"
"${MMSEQS}" rmdb "${BASE}/res_exp_realign_filter"
if [ "${USE_TEMPLATES}" = "1" ]; then
"${MMSEQS}" search "${BASE}/prof_res" "${DB2}" "${BASE}/res_pdb" "${BASE}/tmp" --db-load-mode 2 -s 7.5 -a -e 0.1
echo "-----------------------in here"
echo "${BASE}/${DB2}.m8"
"${MMSEQS}" convertalis "${BASE}/prof_res" "${DB2}.idx" "${BASE}/res_pdb" "${BASE}/${DB2}.m8" --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar --db-load-mode 2
"${MMSEQS}" rmdb "${BASE}/res_pdb"
fi
if [ "${USE_ENV}" = "1" ]; then
"${MMSEQS}" search "${BASE}/prof_res" "${DB3}" "${BASE}/res_env" "${BASE}/tmp" $SEARCH_PARAM
"${MMSEQS}" expandaln "${BASE}/prof_res" "${DB3}.idx" "${BASE}/res_env" "${DB3}.idx" "${BASE}/res_env_exp" -e ${EXPAND_EVAL} --expansion-mode 0 --db-load-mode 2
"${MMSEQS}" align "${BASE}/tmp/latest/profile_1" "${DB3}.idx" "${BASE}/res_env_exp" "${BASE}/res_env_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a
"${MMSEQS}" filterresult "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign" "${BASE}/res_env_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100
"${MMSEQS}" result2msa "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign_filter" "${BASE}/bfd.mgnify30.metaeuk30.smag30.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM}
"${MMSEQS}" rmdb "${BASE}/res_env_exp_realign_filter"
"${MMSEQS}" rmdb "${BASE}/res_env_exp_realign"
fi

time=$(date)
echo "${time}"

+389 -0 reproduce/AlphaFold2-Chinese/data/tools/parsers.py

@@ -0,0 +1,389 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Parsers for FASTA, Stockholm, A3M and HHR alignment formats.'''
import collections
import re
import string
from typing import Iterable, List, Optional, Sequence, Tuple

import dataclasses

DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class HhrHit:
"""Class representing a hit in an hhr file."""
index: int
name: str
prob_true: float
e_value: float
score: float
aligned_cols: int
identity: float
similarity: float
sum_probs: float
neff: float
query: str
hit_sequence: str
hit_dssp: str
column_score_code: str
confidence_scores: str
indices_query: List[int]
indices_hit: List[int]


def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.

Arguments:
fasta_string: The string contents of a FASTA file.

Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line

return sequences, descriptions


def parse_stockholm(
stockholm_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
"""Parses sequences and deletion matrix from stockholm format alignment.

Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.

Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] += sequence

msa = []
deletion_matrix = []

query = ''
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']

# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])

msa.append(aligned_sequence)

# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)

return msa, deletion_matrix
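
# Worked example (comments only): for a Stockholm block
#   query  MK-V
#   hitA   MKLV
# the query has a gap in column 2, so keep_columns == [0, 1, 3] and the
# returned MSA is ['MKV', 'MKV']; hitA's deletion vector is [0, 0, 1],
# recording that one residue (L) was deleted relative to the query just
# before hitA's third aligned position.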


def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
"""Parses sequences and deletion matrix from a3m format alignment.

Args:
a3m_string: The string contents of an a3m file. The first sequence in the
file should be the query sequence.

Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
sequences, _ = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)

# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
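
# Worked example (comments only): in a3m, lowercase letters mark insertions
# relative to the query, so for '>query\nMKV\n>hitA\nMKlV\n' this returns
# aligned sequences ['MKV', 'MKV'] and deletion matrix [[0, 0, 0], [0, 0, 1]]:
# the lowercase 'l' is counted against the third aligned residue, then stripped.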


def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
yield sequence_res.lower()


def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False

for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences

if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] += aligned_seq

for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break

# Convert sto format to a3m line by line
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))

fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.


def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
return match.groups()


def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1


def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.

This works on .hhr files generated from both HHBlits and HHSearch.

Args:
detailed_lines: A list of lines from a single comparison section between two
sequences (each of which has its own HMM).

Returns:
An HhrHit with the information from that detailed comparison section.

Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]

# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(prob_true, e_value, score, aligned_cols, identity, similarity, sum_probs,
neff) = [float(x) for x in match.groups()]

# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that
# block.
query = ''
hit_sequence = ''
hit_dssp = ''
column_score_code = ''
confidence_scores = ''
indices_query = []
indices_hit = []
length_block = None

for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and not line.startswith('Q ss_pred') \
and not line.startswith('Q Consensus')):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])

# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
length_block = end - start + num_insertions
assert length_block == len(delta_query)

# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)

elif line.startswith('T '):
# Parse the hit dssp line.
if line.startswith('T ss_dssp'):
# T ss_dssp hit_dssp
patt = r'T ss_dssp[\t ]*([A-Z-]*)'
groups = _get_hhr_line_regex_groups(patt, line)
assert len(groups[0]) == length_block
hit_dssp += groups[0]

# Parse the hit sequence.
elif (not line.startswith('T ss_pred') and
not line.startswith('T Consensus')):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)

# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit)

# Parse the column score line.
elif line.startswith(' ' * 22):
assert length_block
column_score_code += line[22:length_block + 22]

# Update confidence score.
elif line.startswith('Confidence'):
assert length_block
confidence_scores += line[22:length_block + 22]

return HhrHit(
index=number_of_hit,
name=name_hit,
prob_true=prob_true,
e_value=e_value,
score=score,
aligned_cols=int(aligned_cols),
identity=identity,
similarity=similarity,
sum_probs=sum_probs,
neff=neff,
query=query,
hit_sequence=hit_sequence,
hit_dssp=hit_dssp,
column_score_code=column_score_code,
confidence_scores=confidence_scores,
indices_query=indices_query,
indices_hit=indices_hit,
)


def parse_hhr(hhr_string: str) -> Sequence[HhrHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()

# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.

block_starts = [i for i, line in enumerate(
lines) if line.startswith('No ')]

hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(
lines[block_starts[i]:block_starts[i + 1]]))
return hits
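
A minimal usage sketch tying this parser back to the HHsearch wrapper above; 'output.hhr' is a placeholder for a file produced by HHSearch.query().

with open('output.hhr') as f:
    hits = parse_hhr(f.read())
for hit in hits:
    print(hit.index, hit.name[:40], hit.e_value, hit.aligned_cols)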

+999 -0 reproduce/AlphaFold2-Chinese/data/tools/templates.py (diff suppressed: file too large)


BIN reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture.jpg (1169 × 796, 212 kB)
BIN reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture_en.jpg (701 × 281, 14 kB)
BIN reproduce/AlphaFold2-Chinese/docs/all_experiment_data.jpg (571 × 530, 36 kB)
BIN reproduce/AlphaFold2-Chinese/docs/seq_21.jpg (741 × 627, 131 kB)
BIN reproduce/AlphaFold2-Chinese/docs/seq_64.gif (640 × 248, 778 kB)

+112 -0 reproduce/AlphaFold2-Chinese/main.py

@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""run script"""

import time
import os
import json
import argparse
import numpy as np

import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint

from data.feature.feature_extraction import process_features
from data.tools.data_process import data_process
from commons.generate_pdb import to_pdb, from_prediction
from commons.utils import compute_confidence
from model import AlphaFold
from config import config, global_config

parser = argparse.ArgumentParser(description='Inputs for main.py')
parser.add_argument('--seq_length', help='padding sequence length')
parser.add_argument('--input_fasta_path', help='Path of FASTA files folder directory to be predicted.')
parser.add_argument('--msa_result_path', help='Path to save msa result.')
parser.add_argument('--database_dir', help='Path of data to generate msa.')
parser.add_argument('--database_envdb_dir', help='Path of expandable data to generate msa.')
parser.add_argument('--hhsearch_binary_path', help='Path of hhsearch executable.')
parser.add_argument('--pdb70_database_path', help='Path to pdb70.')
parser.add_argument('--template_mmcif_dir', help='Path of template mmcif.')
parser.add_argument('--max_template_date', help='Maximum template release date.')
parser.add_argument('--kalign_binary_path', help='Path to kalign executable.')
parser.add_argument('--obsolete_pdbs_path', help='Path to obsolete pdbs path.')
parser.add_argument('--checkpoint_path', help='Path of the checkpoint.')
parser.add_argument('--device_id', default=0, type=int, help='Device id to be used.')
args = parser.parse_args()

if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
variable_memory_max_size="31GB",
device_id=args.device_id,
save_graphs=False)
model_name = "model_1"
model_config = config.model_config(model_name)
num_recycle = model_config.model.num_recycle
global_config = global_config.global_config(args.seq_length)
extra_msa_length = global_config.extra_msa_length
fold_net = AlphaFold(model_config, global_config)

load_checkpoint(args.checkpoint_path, fold_net)

seq_files = os.listdir(args.input_fasta_path)

for seq_file in seq_files:
t1 = time.time()
seq_name = seq_file.split('.')[0]
input_features = data_process(seq_name, args)
tensors, aatype, residue_index, ori_res_length = process_features(
raw_features=input_features, config=model_config, global_config=global_config)
prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
"""
:param::@sequence_length
"""
t2 = time.time()
for i in range(num_recycle+1):
tensors_i = [tensor[i] for tensor in tensors]
input_feats = [Tensor(tensor) for tensor in tensors_i]
final_atom_positions, final_atom_mask, predicted_lddt_logits,\
prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
prev_pos,
prev_msa_first_row,
prev_pair)

t3 = time.time()

final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]

confidence = compute_confidence(predicted_lddt_logits)
unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
pdb_file = to_pdb(unrelaxed_protein)

seq_length = aatype.shape[-1]
os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)

with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
f.write(pdb_file)
t4 = time.time()
timings = {"pre_process_time": round(t2 - t1, 2),
"model_time": round(t3 - t2, 2),
"pos_process_time": round(t4 - t3, 2),
"all_time": round(t4 - t1, 2),
"confidence": confidence}
print(timings)
with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
f.write(json.dumps(timings))
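
The recycling loop above is the script's core control flow: each pass feeds the predicted atom positions, the first MSA row and the pair representation back into the network. A numpy-only sketch of that pattern, with a stub in place of the real fold_net (shapes follow the zero-initialised tensors above; the stub itself is an assumption):

import numpy as np

def stub_fold_step(prev_pos, prev_msa_first_row, prev_pair):
    # Stand-in for fold_net: returns recycled states with unchanged shapes.
    return prev_pos + 0.1, prev_msa_first_row, prev_pair

num_recycle, seq_length = 3, 8
prev_pos = np.zeros([seq_length, 37, 3], np.float16)          # atom37 coordinates
prev_msa_first_row = np.zeros([seq_length, 256], np.float16)  # MSA first-row embedding
prev_pair = np.zeros([seq_length, seq_length, 128], np.float16)  # pair representation
for _ in range(num_recycle + 1):
    prev_pos, prev_msa_first_row, prev_pair = stub_fold_step(
        prev_pos, prev_msa_first_row, prev_pair)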

+936 -0 reproduce/AlphaFold2-Chinese/module/basic_module.py

@@ -0,0 +1,936 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""basic module"""

import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.numpy as mnp
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore import Parameter
import numpy as np
from commons.utils import mask_mean


class Attention(nn.Cell):
'''attention module'''
def __init__(self, config, q_data_dim, m_data_dim, output_dim, batch_size=None):
super(Attention, self).__init__()
self.config = config
self.q_data_dim = q_data_dim
self.m_data_dim = m_data_dim
self.output_dim = output_dim
self.num_head = self.config.num_head
self.gating = self.config.gating
self.key_dim = self.config.get('key_dim', int(q_data_dim))
self.value_dim = self.config.get('value_dim', int(m_data_dim))
self.key_dim = self.key_dim // self.num_head
self.value_dim = self.value_dim // self.num_head
self.batch_size = batch_size
self.matmul = P.MatMul(transpose_b=True)
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax()
self.sigmoid = nn.Sigmoid()
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
if self.batch_size:
self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim,
self.q_data_dim]), mstype.float32))
self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim,
self.m_data_dim]), mstype.float32))
self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim,
self.m_data_dim]), mstype.float32))
self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim,
self.num_head * self.value_dim]), mstype.float32))
self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), mstype.float32))
if self.gating:
self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim,
self.q_data_dim]), mstype.float32))
self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_head, self.value_dim)),
mstype.float32), name="gating_b")
else:
self.linear_q_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.q_data_dim]),
mstype.float32))
self.linear_k_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.m_data_dim]),
mstype.float32))
self.linear_v_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim, self.m_data_dim]),
mstype.float32))
self.linear_output_weights = Parameter(Tensor(np.zeros([self.output_dim, self.num_head * self.value_dim]),
mstype.float32))
self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32))
if self.gating:
self.linear_gating_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim,
self.q_data_dim]), mstype.float32))
self.gating_biases = Parameter(Tensor(np.zeros((self.num_head, self.value_dim)), mstype.float32),
name="gating_b")

def construct(self, q_data, m_data, bias, index=None, nonbatched_bias=None):
'''construct'''
if self.batch_size:
linear_q_weight = P.Gather()(self.linear_q_weights, index, 0)
linear_k_weight = P.Gather()(self.linear_k_weights, index, 0)
linear_v_weight = P.Gather()(self.linear_v_weights, index, 0)
linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
o_bias = P.Gather()(self.o_biases, index, 0)
linear_gating_weight = 0
gating_bias = 0
if self.gating:
linear_gating_weight = P.Gather()(self.linear_gating_weights, index, 0)
gating_bias = P.Gather()(self.gating_biases, index, 0)
else:
linear_q_weight = self.linear_q_weights
linear_k_weight = self.linear_k_weights
linear_v_weight = self.linear_v_weights
linear_output_weight = self.linear_output_weights
o_bias = self.o_biases
linear_gating_weight = 0
gating_bias = 0
if self.gating:
linear_gating_weight = self.linear_gating_weights
gating_bias = self.gating_biases

q_data, m_data, bias = q_data.astype(mstype.float32), m_data.astype(mstype.float32), bias.astype(mstype.float32)
dim_b, dim_q, dim_a = q_data.shape
_, dim_k, dim_c = m_data.shape
dim_h = self.num_head

q_data = P.Reshape()(q_data, (-1, dim_a))
m_data = P.Reshape()(m_data, (-1, dim_c))

q = self.matmul(q_data.astype(mstype.float16), linear_q_weight.astype(mstype.float16)). \
astype(mstype.float32) * self.key_dim ** (-0.5)
k = self.matmul(m_data.astype(mstype.float16), linear_k_weight.astype(mstype.float16))
v = self.matmul(m_data.astype(mstype.float16), linear_v_weight.astype(mstype.float16))

q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1))
k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1))
v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1))

tmp_q = P.Reshape()(P.Transpose()(q.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_q, -1))
tmp_k = P.Reshape()(P.Transpose()(k.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_k, -1))
logits = P.Reshape()(self.batch_matmul_trans_b(tmp_q.astype(mstype.float16),
tmp_k.astype(mstype.float16)),
(dim_b, dim_h, dim_q, dim_k)) + bias.astype(mstype.float16)

if nonbatched_bias is not None:
logits += mnp.expand_dims(nonbatched_bias, axis=0)
weights = self.softmax(logits.astype(mstype.float32))
tmp_v = P.Reshape()(P.Transpose()(v, (0, 2, 3, 1)), (dim_b * dim_h, -1, dim_k))
tmp_weights = P.Reshape()(weights, (dim_b * dim_h, dim_q, -1))
weighted_avg = P.Transpose()(P.Reshape()(self.batch_matmul_trans_b(tmp_weights.astype(mstype.float16),
tmp_v.astype(mstype.float16)),
(dim_b, dim_h, dim_q, -1)),
(0, 2, 1, 3)).astype(mstype.float32)

if self.gating:
gate_values = P.Reshape()(
self.matmul(q_data.astype(mstype.float16), linear_gating_weight.astype(mstype.float16)),
(dim_b, dim_q, dim_h, -1)) + gating_bias[None, None, ...].astype(mstype.float16)
gate_values = self.sigmoid(gate_values.astype(mstype.float32))
weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1))

weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1))
output = P.Reshape()(
self.matmul(weighted_avg.astype(mstype.float16), linear_output_weight.astype(mstype.float16)),
(dim_b, dim_q, -1)) + o_bias[None, ...].astype(mstype.float16)
return output
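
# Reference sketch (numpy) of the gated attention computed in construct() above,
# omitting the fp16 casts, per-layer weight gathering, mask bias and
# nonbatched_bias. The [in_dim, out_dim] weight shapes here are an assumption
# for readability; the module itself stores transposed weights and multiplies
# with transpose_b=True.
def _np_gated_attention_sketch(q_data, m_data, wq, wk, wv, wg, wo, num_head):
    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)
    b, q_len, _ = q_data.shape
    _, k_len, _ = m_data.shape
    key_dim = wq.shape[-1] // num_head
    value_dim = wv.shape[-1] // num_head
    q = (q_data @ wq).reshape(b, q_len, num_head, key_dim) * key_dim ** -0.5
    k = (m_data @ wk).reshape(b, k_len, num_head, key_dim)
    v = (m_data @ wv).reshape(b, k_len, num_head, value_dim)
    weights = softmax(np.einsum('bqhc,bkhc->bhqk', q, k))  # attention weights
    out = np.einsum('bhqk,bkhc->bqhc', weights, v)         # weighted average
    gate = 1.0 / (1.0 + np.exp(-(q_data @ wg)))            # sigmoid gate
    out = out * gate.reshape(b, q_len, num_head, value_dim)
    return out.reshape(b, q_len, num_head * value_dim) @ wo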


class MSARowAttentionWithPairBias(nn.Cell):
'''MSA row attention'''
def __init__(self, config, msa_act_dim, pair_act_dim, batch_size=None, slice_num=0):
super(MSARowAttentionWithPairBias, self).__init__()
self.config = config
self.num_head = self.config.num_head
self.batch_size = batch_size
self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.matmul = P.MatMul(transpose_b=True)
self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
self.msa_act_dim = msa_act_dim
self.pair_act_dim = pair_act_dim
self.slice_num = slice_num
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32))
self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32))
self.feat_2d_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32))
self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32))
self.feat_2d_weights = Parameter(
Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32))

def construct(self, msa_act, msa_mask, pair_act, index):
'''construct'''
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0)
feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0)
feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)

q, k, _ = pair_act.shape
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
msa_act, _, _ = self.norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32),
query_norm_beta.astype(mstype.float32))
pair_act, _, _ = self.norm(pair_act.astype(mstype.float32), feat_2d_norm_gamma.astype(mstype.float32),
feat_2d_norm_beta.astype(mstype.float32))
pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1]))
nonbatched_bias = P.Transpose()(
P.Reshape()(self.matmul(pair_act.astype(mstype.float16), feat_2d_weight.astype(mstype.float16)),
(q, k, self.num_head)), (2, 0, 1))

if self.slice_num:
msa_act_ori_shape = P.Shape()(msa_act)
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
bias_shape = P.Shape()(bias)
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
slice_idx = 0
slice_idx_tensor = self.idx
msa_act_tuple = ()

msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

while slice_idx < self.slice_num:
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

msa_act = P.Concat()(msa_act_tuple)
msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
return msa_act

msa_act = self.attn_mod(msa_act, msa_act, bias, index, nonbatched_bias)
return msa_act


class MSAColumnAttention(nn.Cell):
'''MSA column attention'''
def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0):
super(MSAColumnAttention, self).__init__()
self.config = config
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
self.batch_size = batch_size
self.slice_num = slice_num
self.msa_act_dim = msa_act_dim
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))

def construct(self, msa_act, msa_mask, index):
'''construct'''
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
msa_act = mnp.swapaxes(msa_act, -2, -3)
msa_mask = mnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32),
query_norm_beta.astype(mstype.float32))
if self.slice_num:
msa_act_ori_shape = P.Shape()(msa_act)
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
bias_shape = P.Shape()(bias)
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])

slice_idx = 0
slice_idx_tensor = self.idx
msa_act_tuple = ()

msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

while slice_idx < self.slice_num:
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

msa_act = P.Concat()(msa_act_tuple)
msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
msa_act = mnp.swapaxes(msa_act, -2, -3)
return msa_act

msa_act = self.attn_mod(msa_act, msa_act, bias, index)
msa_act = mnp.swapaxes(msa_act, -2, -3)
return msa_act


class GlobalAttention(nn.Cell):
'''global attention'''
def __init__(self, config, key_dim, value_dim, output_dim, batch_size=None):
super(GlobalAttention, self).__init__()
self.config = config
self.key_dim = key_dim
self.ori_key_dim = key_dim
self.value_dim = value_dim
self.ori_value_dim = value_dim
self.num_head = self.config.num_head
self.key_dim = self.key_dim // self.num_head
self.value_dim = self.value_dim // self.num_head
self.output_dim = output_dim
self.matmul_trans_b = P.MatMul(transpose_b=True)
self.batch_matmul = P.BatchMatMul()
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.matmul = P.MatMul()
self.softmax = nn.Softmax()
self.sigmoid = nn.Sigmoid()
self.gating = self.config.gating

self.batch_size = batch_size
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.linear_q_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.ori_key_dim, self.num_head, self.key_dim)), mstype.float32))
self.linear_k_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.key_dim)), mstype.float32))
self.linear_v_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.value_dim)), mstype.float32))
self.linear_output_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.output_dim, self.num_head * self.value_dim)), mstype.float32))
self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), mstype.float32))
if self.gating:
self.linear_gating_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.num_head * self.value_dim, self.ori_key_dim)), mstype.float32))
self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.ori_key_dim)), mstype.float32))

def construct(self, q_data, m_data, q_mask, bias, index):
'''construct'''
q_weights = P.Gather()(self.linear_q_weights, index, 0)
k_weights = P.Gather()(self.linear_k_weights, index, 0)
v_weights = P.Gather()(self.linear_v_weights, index, 0)
output_weights = P.Gather()(self.linear_output_weights, index, 0)
output_bias = P.Gather()(self.o_biases, index, 0)
gating_weights = 0
gating_bias = 0
if self.gating:
gating_weights = P.Gather()(self.linear_gating_weights, index, 0)
gating_bias = P.Gather()(self.gating_biases, index, 0)
b, _, _ = m_data.shape
v_weights = v_weights[None, ...]
v_weights = mnp.broadcast_to(v_weights, (b, self.value_dim * self.num_head, self.value_dim))
v = self.batch_matmul(m_data.astype(mstype.float16), v_weights.astype(mstype.float16))
q_avg = mask_mean(q_mask, q_data, axis=1)
q_weights = P.Reshape()(q_weights, (-1, self.num_head * self.key_dim))
q = P.Reshape()(self.matmul(q_avg.astype(mstype.float16), q_weights.astype(mstype.float16)),
(-1, self.num_head, self.key_dim)) * (self.key_dim ** (-0.5))

k_weights = k_weights[None, ...]
k_weights = mnp.broadcast_to(k_weights, (b, self.value_dim * self.num_head, self.key_dim))

k = self.batch_matmul(m_data.astype(mstype.float16), k_weights.astype(mstype.float16))

# Note: bias is recomputed from q_mask here, so the `bias` argument passed
# into construct() is effectively unused.
bias = (1e9 * (q_mask[:, None, :, 0] - 1.))

logits = self.batch_matmul_trans_b(q.astype(mstype.float16), k.astype(mstype.float16)) + bias.astype(
mstype.float16)
weights = self.softmax(logits.astype(mstype.float32))
weighted_avg = self.batch_matmul(weights.astype(mstype.float16), v.astype(mstype.float16))

if self.gating:
# gate_values = self.linear_gating(q_data).astype(mstype.float32)
q_data_shape = P.Shape()(q_data)
if len(q_data_shape) != 2:
q_data = P.Reshape()(q_data, (-1, q_data_shape[-1]))
out_shape = q_data_shape[:-1] + (-1,)
gate_values = P.Reshape()(self.matmul_trans_b(q_data.astype(mstype.float16),
gating_weights.astype(mstype.float16)) +
gating_bias.astype(mstype.float16), out_shape)

gate_values = P.Reshape()(self.sigmoid(gate_values.astype(mstype.float32)),
(b, -1, self.num_head, self.value_dim))
weighted_avg = P.Reshape()(weighted_avg[:, None] * gate_values, (-1, self.num_head * self.value_dim))
weighted_avg_shape = P.Shape()(weighted_avg)
if len(weighted_avg_shape) != 2:
weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16),
output_weights.astype(mstype.float16))
+ output_bias.astype(mstype.float16), (b, -1, self.output_dim))

else:
weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.value_dim))
# output = self.linear_gating(weighted_avg)
weighted_avg_shape = P.Shape()(weighted_avg)
if len(weighted_avg_shape) != 2:
weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
out_shape = weighted_avg_shape[:-1] + (-1,)
output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16),
output_weights.astype(mstype.float16)) +
output_bias.astype(mstype.float16), out_shape)
output = output[:, None]
return output


class MSAColumnGlobalAttention(nn.Cell):
'''MSA column global attention'''
def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0):
super(MSAColumnGlobalAttention, self).__init__()
self.config = config
self.attn_mod = GlobalAttention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.batch_size = batch_size
self.slice_num = slice_num
self.msa_act_dim = msa_act_dim
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))
self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))

def construct(self, msa_act, msa_mask, index):
'''construct'''
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
msa_act = mnp.swapaxes(msa_act, -2, -3)
msa_mask = mnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32),
query_norm_gamma.astype(mstype.float32),
query_norm_beta.astype(mstype.float32))
msa_mask = mnp.expand_dims(msa_mask, axis=-1)

if self.slice_num:
msa_act_ori_shape = P.Shape()(msa_act)
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
bias_shape = P.Shape()(bias)
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
msa_mask_shape = P.Shape()(msa_mask)
msa_mask = P.Reshape()(msa_mask, slice_shape[:2] + msa_mask_shape[1:])

slice_idx = 0
slice_idx_tensor = self.idx
msa_act_tuple = ()

msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0)
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

while slice_idx < self.slice_num:
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0)
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)

msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index)
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
msa_act_tuple = msa_act_tuple + (msa_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

msa_act = P.Concat()(msa_act_tuple)
msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
msa_act = mnp.swapaxes(msa_act, -2, -3)
return msa_act

msa_act = self.attn_mod(msa_act, msa_act, msa_mask, bias, index)
msa_act = mnp.swapaxes(msa_act, -2, -3)
return msa_act

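# --- Editorial sketch (illustration, not part of the repository code) ---
# MSAColumnGlobalAttention reuses a row-attention kernel by swapping the
# sequence and residue axes, and it masks padded sequences with the additive
# bias `1e9 * (msa_mask - 1.)` so their logits become very negative. NumPy
# sketch with hypothetical shapes: msa (S, R, C), mask (S, R).
import numpy as np

msa = np.random.default_rng(1).normal(size=(4, 6, 8))
mask = np.ones((4, 6)); mask[3, :] = 0.           # last sequence is padding
msa_t = np.swapaxes(msa, -2, -3)                  # (R, S, C): attend per column
bias = 1e9 * (np.swapaxes(mask, -1, -2) - 1.)     # 0 for real, -1e9 for padded
logits = np.einsum('rsc,rtc->rst', msa_t, msa_t) + bias[:, None, :]
attn = np.exp(logits - logits.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)               # padded entries get ~0 weight
print(attn[0].sum(-1))                            # each row sums to 1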

class Transition(nn.Cell):
'''transition'''
def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0):
super(Transition, self).__init__()
self.config = config
self.input_layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.matmul = P.MatMul(transpose_b=True)
self.layer_norm_dim = layer_norm_dim
self.num_intermediate = int(layer_norm_dim * self.config.num_intermediate_factor)
self.batch_size = batch_size
self.slice_num = slice_num
self.relu = nn.ReLU()
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.input_layer_norm_gammas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.input_layer_norm_betas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.transition1_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.num_intermediate, self.layer_norm_dim)), mstype.float32))
self.transition1_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32))
self.transition2_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.num_intermediate)), mstype.float32))
self.transition2_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))

def construct(self, act, index):
'''construct'''
input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0)
input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0)
transition1_weight = P.Gather()(self.transition1_weights, index, 0)
transition1_bias = P.Gather()(self.transition1_biases, index, 0)
transition2_weight = P.Gather()(self.transition2_weights, index, 0)
transition2_bias = P.Gather()(self.transition2_biases, index, 0)
act, _, _ = self.input_layer_norm(act.astype(mstype.float32), input_layer_norm_gamma.astype(mstype.float32),
input_layer_norm_beta.astype(mstype.float32))
if self.slice_num:
act_ori_shape = P.Shape()(act)
slice_shape = (self.slice_num, -1) + act_ori_shape[1:]
act = P.Reshape()(act, slice_shape).astype(mstype.float16)

slice_idx = 0
slice_idx_tensor = self.idx
act_tuple = ()

act_slice = P.Gather()(act, slice_idx_tensor, 0)
act_shape = P.Shape()(act_slice)
if len(act_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_shape[-1]))
act_slice = self.relu(
P.BiasAdd()(self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
transition1_bias.astype(mstype.float16)).astype(mstype.float32))
act_slice = P.BiasAdd()(
self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
transition2_bias.astype(mstype.float16))
act_slice = P.Reshape()(act_slice, act_shape)
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
act_tuple = act_tuple + (act_slice,)
slice_idx += 1
slice_idx_tensor += 1

while slice_idx < self.slice_num:
act_slice = P.Gather()(act, slice_idx_tensor, 0)
act_slice = F.depend(act_slice, act_tuple[-1])
act_shape = P.Shape()(act_slice)
if len(act_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_shape[-1]))
act_slice = self.relu(P.BiasAdd()(
self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
transition1_bias.astype(mstype.float16)).astype(mstype.float32))
act_slice = P.BiasAdd()(
self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
transition2_bias.astype(mstype.float16))
act_slice = P.Reshape()(act_slice, act_shape)
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
act_tuple = act_tuple + (act_slice,)
slice_idx += 1
slice_idx_tensor += 1

act = P.Concat()(act_tuple)
act = P.Reshape()(act, act_ori_shape)
return act

act_shape = P.Shape()(act)
if len(act_shape) != 2:
act = P.Reshape()(act, (-1, act_shape[-1]))
act = self.relu(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
transition1_bias.astype(mstype.float16)).astype(mstype.float32))
act = P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
transition2_bias.astype(mstype.float16))
act = P.Reshape()(act, act_shape)
return act

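# --- Editorial sketch (illustration, not part of the repository code) ---
# Transition is a position-wise two-layer MLP: widen the channel dimension by
# num_intermediate_factor, apply ReLU, and project back, matching the two
# matmul/BiasAdd calls above (weights stored transposed). Hypothetical shapes.
import numpy as np

def transition(x, w1, b1, w2, b2):
    hidden = np.maximum(x @ w1.T + b1, 0.)  # widen + ReLU
    return hidden @ w2.T + b2               # project back to C channels

c, factor = 8, 4
rng = np.random.default_rng(2)
x = rng.normal(size=(5, 7, c))
y = transition(x, rng.normal(size=(factor * c, c)), np.zeros(factor * c),
               rng.normal(size=(c, factor * c)), np.zeros(c))
print(y.shape)  # (5, 7, 8): same shape in, same shape out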

class OuterProductMean(nn.Cell):
'''outer product mean'''
def __init__(self, config, act_dim, num_output_channel, batch_size=None, slice_num=0):
super(OuterProductMean, self).__init__()
self.num_output_channel = num_output_channel
self.config = config
self.layer_norm_input = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.matmul_trans_b = P.MatMul(transpose_b=True)
self.matmul = P.MatMul()
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.act_dim = act_dim
self.batch_size = batch_size
self.slice_num = slice_num
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.layer_norm_input_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32))
self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32))
self.left_projection_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32))
self.left_projection_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32))
self.right_projection_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32))
self.right_projection_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32))
self.linear_output_weights = Parameter(Tensor(np.zeros(
(self.batch_size, self.num_output_channel, self.config.num_outer_channel *
self.config.num_outer_channel)), mstype.float32))
self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_output_channel)), mstype.float32))

def construct(self, act, mask, index):
'''construct'''
layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0)
layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0)
left_projection_weight = P.Gather()(self.left_projection_weights, index, 0)
left_projection_bias = P.Gather()(self.left_projection_biases, index, 0)
right_projection_weight = P.Gather()(self.right_projection_weights, index, 0)
right_projection_bias = P.Gather()(self.right_projection_biases, index, 0)
linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
linear_output_bias = P.Gather()(self.o_biases, index, 0)
mask = mask[..., None]
act, _, _ = self.layer_norm_input(act.astype(mstype.float32),
layer_norm_input_gamma.astype(mstype.float32),
layer_norm_input_beta.astype(mstype.float32))
act_shape = P.Shape()(act)
if len(act_shape) != 2:
act = P.Reshape()(act, (-1, act_shape[-1]))
out_shape = act_shape[:-1] + (-1,)
left_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
left_projection_weight.astype(mstype.float16)),
left_projection_bias.astype(mstype.float16)), out_shape)
right_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
right_projection_weight.astype(mstype.float16)),
right_projection_bias.astype(mstype.float16)), out_shape)
_, d, e = right_act.shape
if self.slice_num:
left_act_shape = P.Shape()(left_act)
slice_shape = (left_act_shape[0],) + (self.slice_num, -1) + (left_act_shape[-1],)
left_act = P.Reshape()(left_act, slice_shape)
slice_idx = 0
slice_idx_tensor = self.idx
act_tuple = ()
left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1)
a, b, c = left_act_slice.shape
left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a))
right_act = P.Reshape()(right_act, (a, -1))
act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16),
right_act.astype(mstype.float16)),
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
act_slice_shape = P.Shape()(act_slice)
if len(act_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1]))
act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16),
linear_output_weight.astype(mstype.float16)),
linear_output_bias.astype(mstype.float16)), (d, b, -1))
act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2])
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
act_tuple = act_tuple + (act_slice,)
slice_idx += 1
slice_idx_tensor += 1
while slice_idx < self.slice_num:
left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1)
left_act_slice = F.depend(left_act_slice, act_tuple[-1])
a, b, c = left_act_slice.shape
left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a))
right_act = P.Reshape()(right_act, (a, -1))
act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16),
right_act.astype(mstype.float16)),
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
act_slice_shape = P.Shape()(act_slice)
if len(act_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1]))
act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16),
linear_output_weight.astype(mstype.float16)),
linear_output_bias.astype(mstype.float16)), (d, b, -1))
act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2])
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
act_tuple = act_tuple + (act_slice,)
slice_idx += 1
slice_idx_tensor += 1
act = P.Concat()(act_tuple)
act_shape = P.Shape()(act)
act = P.Reshape()(act, (-1, act_shape[-2], act_shape[-1]))
epsilon = 1e-3
tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0))
norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16),
tmp_mask.astype(mstype.float16)),
(1, 2, 0)).astype(mstype.float32)
act /= epsilon + norm
return act

a, b, c = left_act.shape
left_act = P.Reshape()(mnp.transpose(left_act.astype(mstype.float16), [2, 1, 0]), (-1, a))
right_act = P.Reshape()(right_act, (a, -1))
act = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act.astype(mstype.float16),
right_act.astype(mstype.float16)),
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
act_shape = P.Shape()(act)
if len(act_shape) != 2:
act = P.Reshape()(act, (-1, act_shape[-1]))
act = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
linear_output_weight.astype(mstype.float16)),
linear_output_bias.astype(mstype.float16)), (d, b, -1))
act = mnp.transpose(act.astype(mstype.float16), [1, 0, 2])
epsilon = 1e-3
tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0))
norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16),
tmp_mask.astype(mstype.float16)),
(1, 2, 0)).astype(mstype.float32)
act /= epsilon + norm
return act

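# --- Editorial sketch (illustration, not part of the repository code) ---
# OuterProductMean turns MSA activations into a pair update: mask, take an
# outer product over the sequence axis, project the flattened channel pairs,
# and divide by the pairwise count of valid sequences (the `epsilon + norm`
# step above). NumPy sketch with hypothetical shapes: act (S, R, Co).
import numpy as np

rng = np.random.default_rng(3)
s, r, co, cz = 4, 6, 3, 5
left = rng.normal(size=(s, r, co))
right = rng.normal(size=(s, r, co))
mask = np.ones((s, r))
outer = np.einsum('sic,sjd->ijcd', left * mask[..., None],
                  right * mask[..., None]).reshape(r, r, co * co)
w_out = rng.normal(size=(cz, co * co))
norm = np.einsum('si,sj->ij', mask, mask)        # valid-sequence pair counts
pair_update = (outer @ w_out.T) / (1e-3 + norm[..., None])
print(pair_update.shape)  # (6, 6, 5)
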
class TriangleMultiplication(nn.Cell):
'''triangle multiplication'''
def __init__(self, config, layer_norm_dim, batch_size):
super(TriangleMultiplication, self).__init__()
self.config = config
self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.matmul = P.MatMul(transpose_b=True)
self.sigmoid = nn.Sigmoid()
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
equation = ["ikc,jkc->ijc", "kjc,kic->ijc"]
if self.config.equation not in equation:
print("TriangleMultiplication Not Suppl")
if self.config.equation == "ikc,jkc->ijc":
self.equation = True
elif self.config.equation == "kjc,kic->ijc":
self.equation = False
else:
self.equation = None
self.batch_size = batch_size
self.layer_norm_dim = layer_norm_dim
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.layer_norm_input_gammas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.layer_norm_input_betas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.left_projection_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
mstype.float32))
self.left_projection_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
self.right_projection_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
mstype.float32))
self.right_projection_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
self.left_gate_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
mstype.float32))
self.left_gate_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
self.right_gate_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
mstype.float32))
self.right_gate_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
self.center_layer_norm_gammas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.center_layer_norm_betas = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.output_projection_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32))
self.output_projection_biases = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.gating_linear_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32))
self.gating_linear_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))

def construct(self, act, mask, index):
'''construct'''
layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0)
layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0)
left_projection_weight = P.Gather()(self.left_projection_weights, index, 0)
left_projection_bias = P.Gather()(self.left_projection_biases, index, 0)
right_projection_weight = P.Gather()(self.right_projection_weights, index, 0)
right_projection_bias = P.Gather()(self.right_projection_biases, index, 0)
left_gate_weight = P.Gather()(self.left_gate_weights, index, 0)
left_gate_bias = P.Gather()(self.left_gate_biases, index, 0)
right_gate_weight = P.Gather()(self.right_gate_weights, index, 0)
right_gate_bias = P.Gather()(self.right_gate_biases, index, 0)
center_layer_norm_gamma = P.Gather()(self.center_layer_norm_gammas, index, 0)
center_layer_norm_beta = P.Gather()(self.center_layer_norm_betas, index, 0)
output_projection_weight = P.Gather()(self.output_projection_weights, index, 0)
output_projection_bias = P.Gather()(self.output_projection_biases, index, 0)
gating_linear_weight = P.Gather()(self.gating_linear_weights, index, 0)
gating_linear_bias = P.Gather()(self.gating_linear_biases, index, 0)

mask = mask[..., None]
act, _, _ = self.layer_norm(act.astype(mstype.float32),
layer_norm_input_gamma.astype(mstype.float32),
layer_norm_input_beta.astype(mstype.float32))
input_act = act
act_shape = P.Shape()(act)
if len(act_shape) != 2:
act = P.Reshape()(act, (-1, act_shape[-1]))
out_shape = act_shape[:-1] + (-1,)
left_projection = mask * P.Reshape()(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_projection_weight.astype(mstype.float16)),
left_projection_bias.astype(mstype.float16)), out_shape)
left_projection = left_projection.astype(mstype.float16)
act = F.depend(act, left_projection)

left_gate_values = self.sigmoid(P.Reshape()(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_gate_weight.astype(mstype.float16)),
left_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
left_proj_act = left_projection * left_gate_values
act = F.depend(act, left_proj_act)

right_projection = mask * P.Reshape()(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_projection_weight.astype(mstype.float16)),
right_projection_bias.astype(mstype.float16)), out_shape)
right_projection = right_projection.astype(mstype.float16)
act = F.depend(act, right_projection)

right_gate_values = self.sigmoid(P.Reshape()(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_gate_weight.astype(mstype.float16)),
right_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
right_proj_act = right_projection * right_gate_values
left_proj_act = F.depend(left_proj_act, right_proj_act)

if self.equation is not None:
if self.equation:
left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 0, 1))
right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 0, 1))
act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp)
act = P.Transpose()(act, (1, 2, 0)).astype(mstype.float32)
else:
left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 1, 0))
right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 1, 0))
act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp)
act = P.Transpose()(act, (2, 1, 0)).astype(mstype.float32)
act, _, _ = self.layer_norm(act.astype(mstype.float32),
center_layer_norm_gamma.astype(mstype.float32),
center_layer_norm_beta.astype(mstype.float32))
act_shape = P.Shape()(act)
if len(act_shape) != 2:
act = P.Reshape()(act, (-1, act_shape[-1]))
out_shape = act_shape[:-1] + (-1,)
act = P.Reshape()(
P.BiasAdd()(self.matmul(act.astype(mstype.float16), output_projection_weight.astype(mstype.float16)),
output_projection_bias.astype(mstype.float16)), out_shape)
input_act_shape = P.Shape()(input_act)
if len(input_act_shape) != 2:
input_act = P.Reshape()(input_act, (-1, input_act_shape[-1]))
out_shape = input_act_shape[:-1] + (-1,)
gate_values = self.sigmoid(P.Reshape()(
P.BiasAdd()(self.matmul(input_act.astype(mstype.float16), gating_linear_weight.astype(mstype.float16)),
gating_linear_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
act = act * gate_values
return act

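# --- Editorial sketch (illustration, not part of the repository code) ---
# The two `equation` branches above are the outgoing ("ikc,jkc->ijc") and
# incoming ("kjc,kic->ijc") triangle updates: edge (i, j) is refreshed from
# the two edges that close a triangle through k, after sigmoid-gating the
# left/right projections. NumPy sketch with hypothetical shapes.
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

rng = np.random.default_rng(4)
r, c = 5, 3
left = rng.normal(size=(r, r, c)) * sigmoid(rng.normal(size=(r, r, c)))
right = rng.normal(size=(r, r, c)) * sigmoid(rng.normal(size=(r, r, c)))
outgoing = np.einsum('ikc,jkc->ijc', left, right)  # edges leaving i and j
incoming = np.einsum('kjc,kic->ijc', left, right)  # edges arriving at i and j
print(outgoing.shape, incoming.shape)  # (5, 5, 3) (5, 5, 3)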

class TriangleAttention(nn.Cell):
'''triangle attention'''
def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0):
super(TriangleAttention, self).__init__()
self.config = config
self.orientation_is_per_column = (self.config.orientation == 'per_column')
self.init_factor = Tensor(1. / np.sqrt(layer_norm_dim), mstype.float32)
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
self.matmul = P.MatMul(transpose_b=True)
self.attn_mod = Attention(self.config, layer_norm_dim, layer_norm_dim, layer_norm_dim, batch_size)
self.batch_size = batch_size
self.slice_num = slice_num
self.layer_norm_dim = layer_norm_dim
self.idx = Tensor(0, mstype.int32)
self._init_parameter()

def _init_parameter(self):
'''init parameter'''
self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
self.feat_2d_weights = Parameter(
Tensor(np.zeros((self.batch_size, self.config.num_head, self.layer_norm_dim)), mstype.float32))

def construct(self, pair_act, pair_mask, index):
'''construct'''
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
if self.orientation_is_per_column:
pair_act = mnp.swapaxes(pair_act, -2, -3)
pair_mask = mnp.swapaxes(pair_mask, -1, -2)
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
pair_act, _, _ = self.query_norm(pair_act.astype(mstype.float32),
query_norm_gamma.astype(mstype.float32),
query_norm_beta.astype(mstype.float32))
q, k, _ = pair_act.shape
nonbatched_bias = self.matmul(P.Reshape()(pair_act.astype(mstype.float16), (-1, pair_act.shape[-1])),
feat_2d_weight.astype(mstype.float16))
nonbatched_bias = P.Transpose()(P.Reshape()(nonbatched_bias, (q, k, -1)), (2, 0, 1))
if self.slice_num:
pair_act_ori_shape = P.Shape()(pair_act)
slice_shape = (self.slice_num, -1) + pair_act_ori_shape[1:]
pair_act = P.Reshape()(pair_act, slice_shape).astype(mstype.float16)
bias_shape = P.Shape()(bias)
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])

slice_idx = 0
slice_idx_tensor = self.idx
pair_act_tuple = ()

pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0)
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias)
pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice)))
pair_act_tuple = pair_act_tuple + (pair_act_slice,)
slice_idx += 1
slice_idx_tensor += 1

while slice_idx < self.slice_num:
pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0)
pair_act_slice = F.depend(pair_act_slice, pair_act_tuple[-1])
bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias)
pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice)))
pair_act_tuple = pair_act_tuple + (pair_act_slice,)
slice_idx += 1
slice_idx_tensor += 1
pair_act = P.Concat()(pair_act_tuple)
pair_act = P.Reshape()(pair_act, pair_act_ori_shape)

if self.orientation_is_per_column:
pair_act = mnp.swapaxes(pair_act, -2, -3)
return pair_act

pair_act = self.attn_mod(pair_act, pair_act, bias, index, nonbatched_bias)
if self.orientation_is_per_column:
pair_act = mnp.swapaxes(pair_act, -2, -3)
return pair_act

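# --- Editorial sketch (illustration, not part of the repository code) ---
# TriangleAttention is self-attention within each row (or, after the
# swapaxes for 'per_column', within each column) of the pair tensor, plus a
# per-head additive bias computed from the pair activations via
# feat_2d_weights. NumPy sketch with hypothetical shapes: pair (R, R, C).
import numpy as np

rng = np.random.default_rng(5)
r, c, h = 4, 6, 2
pair = rng.normal(size=(r, r, c))
w_bias = rng.normal(size=(c, h))
nonbatched_bias = np.transpose(pair @ w_bias, (2, 0, 1))  # (H, R, R)
logits = np.einsum('ric,rjc->rij', pair, pair)[:, None] + nonbatched_bias[None]
attn = np.exp(logits - logits.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
print(attn.shape)  # (4, 2, 4, 4): per starting row, per head
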
+ 304  - 0   reproduce/AlphaFold2-Chinese/module/evoformer_module.py

@@ -0,0 +1,304 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evoformer module"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.numpy as mnp
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore import Parameter

from commons import residue_constants
from commons.utils import dgram_from_positions, batch_make_transform_from_reference, batch_quat_affine, \
batch_invert_point, batch_rot_to_quat
from module.basic_module import Attention, MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention, \
Transition, OuterProductMean, TriangleMultiplication, TriangleAttention


class EvoformerIteration(nn.Cell):
'''evoformer iteration'''
def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size, global_config):
super(EvoformerIteration, self).__init__()
self.config = config
self.is_extra_msa = is_extra_msa
if not self.is_extra_msa:
self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias(
self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size,
global_config.evoformer_iteration.msa_row_attention_with_pair_bias.slice_num)
self.attn_mod = MSAColumnAttention(self.config.msa_column_attention, msa_act_dim, batch_size,
global_config.evoformer_iteration.msa_column_attention.slice_num)
self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size,
global_config.evoformer_iteration.msa_transition.slice_num)
self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim,
batch_size,
global_config.evoformer_iteration.outer_product_mean.slice_num)
self.triangle_attention_starting_node = \
TriangleAttention(self.config.triangle_attention_starting_node,
pair_act_dim, batch_size,
global_config.evoformer_iteration.triangle_attention_starting_node.slice_num)
self.triangle_attention_ending_node = \
TriangleAttention(self.config.triangle_attention_ending_node,
pair_act_dim, batch_size,
global_config.evoformer_iteration.triangle_attention_ending_node.slice_num)
self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size,
global_config.evoformer_iteration.pair_transition.slice_num)
else:
self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias(
self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size,
global_config.extra_msa_stack.msa_row_attention_with_pair_bias.slice_num)
self.attn_mod = \
MSAColumnGlobalAttention(self.config.msa_column_attention, msa_act_dim, batch_size,
global_config.extra_msa_stack.msa_column_global_attention.slice_num)
self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size,
global_config.extra_msa_stack.msa_transition.slice_num)
self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim,
batch_size,
global_config.extra_msa_stack.outer_product_mean.slice_num)
self.triangle_attention_starting_node = \
TriangleAttention(self.config.triangle_attention_starting_node,
pair_act_dim, batch_size,
global_config.extra_msa_stack.triangle_attention_starting_node.slice_num)
self.triangle_attention_ending_node = \
TriangleAttention(self.config.triangle_attention_ending_node,
pair_act_dim, batch_size,
global_config.extra_msa_stack.triangle_attention_ending_node.slice_num)
self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size,
global_config.extra_msa_stack.pair_transition.slice_num)

self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing,
pair_act_dim, batch_size)
self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming,
pair_act_dim, batch_size)

def construct(self, msa_act, pair_act, msa_mask, pair_mask, index):
'''construct'''
msa_act = msa_act + self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index)
msa_act = msa_act + self.attn_mod(msa_act, msa_mask, index)
msa_act = msa_act + self.msa_transition(msa_act, index)
pair_act = pair_act + self.outer_product_mean(msa_act, msa_mask, index)
pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index)
pair_act = pair_act + self.pair_transition(pair_act, index)
return msa_act.astype(mstype.float16), pair_act.astype(mstype.float16)

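# --- Editorial sketch (illustration, not part of the repository code) ---
# Every sub-module above keeps its weights stacked along a leading
# `batch_size` (= number of blocks) axis and Gathers the slice for the
# current `index`, so one Cell instance serves all iterations of the
# while-loops in AlphaFold.construct. The same pattern in plain NumPy:
import numpy as np

num_block, c = 3, 4
stacked_w = np.stack([np.eye(c) * (i + 1) for i in range(num_block)])
x = np.ones(c)
for index in range(num_block):   # mirrors `while idx < num_block`
    w = stacked_w[index]         # mirrors P.Gather()(self.weights, index, 0)
    x = x @ w.T
print(x)  # [6. 6. 6. 6.]: each block applied its own weight slice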

class TemplatePairStack(nn.Cell):
'''template pair stack'''
def __init__(self, config, global_config=None):
super(TemplatePairStack, self).__init__()
self.config = config
self.global_config = global_config
self.num_block = self.config.num_block
# self.seq_length = global_config.step_size

self.triangle_attention_starting_node = \
TriangleAttention(self.config.triangle_attention_starting_node,
64, self.num_block,
global_config.template_pair_stack.triangle_attention_starting_node.slice_num)
self.triangle_attention_ending_node = \
TriangleAttention(self.config.triangle_attention_ending_node,
64, self.num_block,
global_config.template_pair_stack.triangle_attention_ending_node.slice_num)
# Hard-coded: the template pair stack uses 64 channels.
self.pair_transition = Transition(self.config.pair_transition, 64, self.num_block,
global_config.template_pair_stack.pair_transition.slice_num)
self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing,
layer_norm_dim=64, batch_size=self.num_block)
self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming,
layer_norm_dim=64, batch_size=self.num_block)

def construct(self, pair_act, pair_mask, index):
if not self.num_block:
return pair_act

pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index)
pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index)
pair_act = pair_act + self.pair_transition(pair_act, index)
return pair_act.astype(mstype.float16)


class SingleTemplateEmbedding(nn.Cell):
'''single template embedding'''
def __init__(self, config, global_config=None):
super(SingleTemplateEmbedding, self).__init__()
self.config = config
# self.seq_length = global_config.step_size
self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim)
self.embedding2d = nn.Dense(88, self.num_channels).to_float(mstype.float16)
self.template_pair_stack = TemplatePairStack(self.config.template_pair_stack, global_config)
self.num_bins = self.config.dgram_features.num_bins
self.min_bin = self.config.dgram_features.min_bin
self.max_bin = self.config.dgram_features.max_bin

self.one_hot = nn.OneHot(depth=22, axis=-1)
self.n, self.ca, self.c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]

self.use_template_unit_vector = self.config.use_template_unit_vector
layer_norm_dim = 64
self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5)

self.idx_num_block = Parameter(Tensor(0, mstype.int32), requires_grad=False)
self.idx_batch_loop = Parameter(Tensor(0, mstype.int32), requires_grad=False)
# self.num_block = Tensor(self.template_pair_stack.num_block, mstype.int32)
# self.batch_block = Tensor(4, mstype.int32)
self.num_block = self.template_pair_stack.num_block
self.batch_block = 4
self._act = Parameter(
Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 64]).astype(np.float16)),
requires_grad=False)

def construct(self, query_embedding, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions,
template_pseudo_beta_mask, template_pseudo_beta):
'''construct'''
num_res = template_aatype[0, ...].shape[0]
template_mask_2d_temp = P.Cast()(template_pseudo_beta_mask[:, :, None] * template_pseudo_beta_mask[:, None, :],
query_embedding.dtype)
template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin, self.max_bin)
template_dgram_temp = P.Cast()(template_dgram_temp, query_embedding.dtype)

to_concat_temp = (template_dgram_temp, template_mask_2d_temp[:, :, :, None])
aatype_temp = self.one_hot(template_aatype)
to_concat_temp = to_concat_temp + (mnp.tile(aatype_temp[:, None, :, :], (1, num_res, 1, 1)),
mnp.tile(aatype_temp[:, :, None, :], (1, 1, num_res, 1)))
rot_temp, trans_temp = batch_make_transform_from_reference(template_all_atom_positions[:, :, self.n],
template_all_atom_positions[:, :, self.ca],
template_all_atom_positions[:, :, self.c])

_, rotation_tmp, translation_tmp = batch_quat_affine(
batch_rot_to_quat(rot_temp, unstack_inputs=True), translation=trans_temp, rotation=rot_temp,
unstack_inputs=True)
points_tmp = mnp.expand_dims(translation_tmp, axis=-2)
affine_vec_tmp = batch_invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1)
inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + mnp.sum(mnp.square(affine_vec_tmp), axis=1))
template_mask_tmp = (template_all_atom_masks[:, :, self.n] *
template_all_atom_masks[:, :, self.ca] *
template_all_atom_masks[:, :, self.c])
template_mask_2d_tmp = template_mask_tmp[:, :, None] * template_mask_tmp[:, None, :]

inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp.astype(inv_distance_scalar_tmp.dtype)
unit_vector_tmp = P.Transpose()((affine_vec_tmp * inv_distance_scalar_tmp[:, None, ...]), (0, 2, 3, 1))
template_mask_2d_tmp = P.Cast()(template_mask_2d_tmp, query_embedding.dtype)
if not self.use_template_unit_vector:
unit_vector_tmp = mnp.zeros_like(unit_vector_tmp)
to_concat_temp = to_concat_temp + (unit_vector_tmp, template_mask_2d_tmp[..., None],)
act_tmp = mnp.concatenate(to_concat_temp, axis=-1)
act_tmp = act_tmp * template_mask_2d_tmp[..., None]
act_tmp = self.embedding2d(act_tmp)

idx_batch_loop = self.idx_batch_loop
output = []
idx_batch_loop_int = 0
self.idx_num_block = 0

while idx_batch_loop_int < self.batch_block:
self.idx_num_block = 0
idx_num_block_int = 0
self._act = P.Gather()(act_tmp, idx_batch_loop, 0)
while idx_num_block_int < self.num_block:
self._act = self.template_pair_stack(self._act, mask_2d, self.idx_num_block)
self.idx_num_block += 1
idx_num_block_int += 1
temp_act = P.Reshape()(self._act, ((1,) + P.Shape()(self._act)))
output.append(temp_act)
idx_batch_loop += 1
idx_batch_loop_int += 1

act_tmp_loop = P.Concat()(output)
act_tmp = self.output_layer_norm(act_tmp_loop.astype(mstype.float32))
return act_tmp


class TemplateEmbedding(nn.Cell):
'''template embedding'''
def __init__(self, config, slice_num, global_config=None):
super(TemplateEmbedding, self).__init__()
self.config = config
self.global_config = global_config
self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim)
self.template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
self.template_pointwise_attention = Attention(self.config.attention, q_data_dim=128, m_data_dim=64,
output_dim=128)
self.slice_num = slice_num
if slice_num == 0:
slice_num = 1
self._flat_query_slice = Parameter(
Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 1, 128)),
dtype=mstype.float16), requires_grad=False)
self._flat_templates_slice = Parameter(
Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 4, 64)),
dtype=mstype.float16), requires_grad=False)

def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions,
template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d):
'''construct'''
num_templates = template_mask.shape[0]
num_channels = self.num_channels
num_res = query_embedding.shape[0]
query_num_channels = query_embedding.shape[-1]
template_mask = P.Cast()(template_mask, query_embedding.dtype)

mask_2d = F.depend(mask_2d, query_embedding)
template_pair_representation = self.template_embedder(query_embedding, mask_2d, template_aatype,
template_all_atom_masks, template_all_atom_positions,
template_pseudo_beta_mask,
template_pseudo_beta)

template_pair_representation = template_pair_representation.astype(mstype.float32)
flat_query = mnp.reshape(query_embedding, [num_res * num_res, 1, query_num_channels])
flat_templates = mnp.reshape(
mnp.transpose(template_pair_representation.astype(mstype.float16), [1, 2, 0, 3]),
[num_res * num_res, num_templates, num_channels]).astype(mstype.float32)
bias = (1e9 * (template_mask[None, None, None, :] - 1.))
flat_query, flat_templates, bias = flat_query.astype(mstype.float32), flat_templates.astype(
mstype.float32), bias.astype(mstype.float32)

if self.slice_num:
slice_shape = (self.slice_num, -1)
flat_query_shape = P.Shape()(flat_query)
flat_query = P.Reshape()(flat_query, slice_shape + flat_query_shape[1:]).astype(mstype.float16)
flat_templates_shape = P.Shape()(flat_templates)
flat_templates = P.Reshape()(flat_templates, slice_shape + flat_templates_shape[1:]).astype(mstype.float16)
slice_idx = 0
embedding_tuple = ()
while slice_idx < self.slice_num:
self._flat_query_slice = flat_query[slice_idx]
self._flat_templates_slice = flat_templates[slice_idx]
embedding_slice = self.template_pointwise_attention(self._flat_query_slice, self._flat_templates_slice,
bias, index=None, nonbatched_bias=None)
embedding_slice = P.Reshape()(embedding_slice, ((1,) + P.Shape()(embedding_slice)))
embedding_tuple = embedding_tuple + (embedding_slice,)
slice_idx += 1
embedding = P.Concat()(embedding_tuple)

embedding = embedding.astype(mstype.float32)
embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels])
# No gradients if no templates.
template_mask = template_mask.astype(embedding.dtype)
embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype)
return embedding

embedding = self.template_pointwise_attention(flat_query, flat_templates, bias, index=None,
nonbatched_bias=None)
embedding = embedding.astype(mstype.float32)
embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels])
# No gradients if no templates.
template_mask = template_mask.astype(embedding.dtype)
embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype)
return embedding

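# --- Editorial sketch (illustration, not part of the repository code) ---
# The pointwise attention above treats every (i, j) pair position as an
# independent length-1 query attending over the templates axis, which is why
# the tensors are flattened to (num_res*num_res, 1, C) and
# (num_res*num_res, num_templates, C). NumPy sketch, hypothetical sizes.
import numpy as np

rng = np.random.default_rng(6)
r, t, c = 3, 4, 5
flat_query = rng.normal(size=(r * r, 1, c))
flat_templates = rng.normal(size=(r * r, t, c))
template_mask = np.array([1., 1., 0., 0.])   # two real templates, two empty
bias = 1e9 * (template_mask - 1.)            # -1e9 on the empty slots
logits = np.einsum('pqc,ptc->pqt', flat_query, flat_templates) + bias
attn = np.exp(logits - logits.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
embedding = np.einsum('pqt,ptc->pqc', attn, flat_templates).reshape(r, r, c)
print(embedding.shape)  # (3, 3, 5)
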
+ 235  - 0   reproduce/AlphaFold2-Chinese/module/model.py

@@ -0,0 +1,235 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AlphaFold Model"""

import numpy as np

import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.numpy as mnp
from mindspore.common.tensor import Tensor
from mindspore import Parameter
from mindspore.ops import functional as F

from commons import residue_constants
from commons.utils import get_chi_atom_indices, pseudo_beta_fn, dgram_from_positions, atom37_to_torsion_angles
from module.evoformer_module import TemplateEmbedding, EvoformerIteration
from module.structure_module import StructureModule, PredictedLDDTHead


class AlphaFold(nn.Cell):
"""AlphaFold Model"""
def __init__(self, config, global_config):
super(AlphaFold, self).__init__()
self.config = config.model.embeddings_and_evoformer
self.preprocess_1d = nn.Dense(22, self.config.msa_channel).to_float(mstype.float16)
self.preprocess_msa = nn.Dense(49, self.config.msa_channel).to_float(mstype.float16)
self.left_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16)
self.right_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16)
self.prev_pos_linear = nn.Dense(15, self.config.pair_channel).to_float(mstype.float16)
self.pair_activations = nn.Dense(65, self.config.pair_channel).to_float(mstype.float16)
self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5)
self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5)
self.one_hot = nn.OneHot(depth=self.config.max_relative_feature * 2 + 1, axis=-1)
self.extra_msa_activations = nn.Dense(25, self.config.extra_msa_channel).to_float(mstype.float16)
self.template_single_embedding = nn.Dense(57, self.config.msa_channel).to_float(mstype.float16)
self.template_projection = nn.Dense(self.config.msa_channel, self.config.msa_channel).to_float(mstype.float16)
self.single_activations = nn.Dense(self.config.msa_channel, self.config.seq_channel).to_float(mstype.float16)
self.relu = nn.ReLU()
self.recycle_pos = self.config.recycle_pos
self.recycle_features = self.config.recycle_features
self.template_enable = self.config.template.enabled
self.max_relative_feature = self.config.max_relative_feature
self.template_enabled = self.config.template.enabled
self.template_embed_torsion_angles = self.config.template.embed_torsion_angles
self.num_bins = self.config.prev_pos.num_bins
self.min_bin = self.config.prev_pos.min_bin
self.max_bin = self.config.prev_pos.max_bin
self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1)
self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1)
self.template_embedding = TemplateEmbedding(self.config.template,
global_config.template_embedding.slice_num,
global_config=global_config)
self.extra_msa_stack_iteration = EvoformerIteration(self.config.evoformer,
msa_act_dim=64,
pair_act_dim=128,
is_extra_msa=True,
batch_size=self.config.extra_msa_stack_num_block,
global_config=global_config)

self.evoformer_iteration = EvoformerIteration(self.config.evoformer,
msa_act_dim=256,
pair_act_dim=128,
is_extra_msa=False,
batch_size=self.config.evoformer_num_block,
global_config=global_config)

self.structure_module = StructureModule(config.model.heads.structure_module,
self.config.seq_channel,
self.config.pair_channel,
global_config=global_config)

self.module_lddt = PredictedLDDTHead(config.model.heads.predicted_lddt,
global_config,
self.config.seq_channel)
self._init_tensor(global_config)

def _init_tensor(self, global_config):
"initialization of tensors and parameters"
self.chi_atom_indices = Tensor(get_chi_atom_indices(), mstype.int32)
chi_angles_mask = list(residue_constants.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
self.chi_angles_mask = Tensor(chi_angles_mask, mstype.float32)
self.mirror_psi_mask = Tensor(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None], mstype.float32)
self.chi_pi_periodic = Tensor(residue_constants.chi_pi_periodic, mstype.float32)

indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int64") # 4 batch
indices0 = indices0.repeat(global_config.seq_length, axis=1) # seq_length sequence length
indices0 = indices0.repeat(4, axis=2) # 4 chis
self.indices0 = Tensor(indices0.repeat(4, axis=3)) # 4 atoms

indices1 = np.arange(global_config.seq_length).reshape((1, -1, 1, 1, 1)).astype("int64")
indices1 = indices1.repeat(4, axis=0)
indices1 = indices1.repeat(4, axis=2)
self.indices1 = Tensor(indices1.repeat(4, axis=3))

self.idx_extra_msa_stack = Parameter(Tensor(0, mstype.int32), requires_grad=False)
self.extra_msa_stack_num_block = self.config.extra_msa_stack_num_block

self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False)
self.evoformer_num_block = Tensor(self.config.evoformer_num_block, mstype.int32)

def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype,
template_aatype, template_all_atom_masks, template_all_atom_positions,
template_mask, template_pseudo_beta_mask, template_pseudo_beta,
_, extra_msa, extra_has_deletion,
extra_deletion_value, extra_msa_mask,
atom14_atom_exists, atom37_atom_exists, residue_index,
prev_pos, prev_msa_first_row, prev_pair):
"""construct"""

preprocess_1d = self.preprocess_1d(target_feat)
preprocess_msa = self.preprocess_msa(msa_feat)
msa_activations1 = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

left_single = self.left_single(target_feat)
right_single = self.right_single(target_feat)

pair_activations = left_single[:, None] + right_single[None]
mask_2d = seq_mask[:, None] * seq_mask[None, :]

if self.recycle_pos:
prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None)
dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin)
pair_activations += self.prev_pos_linear(dgram)
# return pair_activations, msa_activations1
prev_msa_first_row = F.depend(prev_msa_first_row, pair_activations)
if self.recycle_features:
prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row)
msa_activations1 = mnp.concatenate(
(mnp.expand_dims(prev_msa_first_row + msa_activations1[0, ...], 0),
msa_activations1[1:, ...]), 0)
pair_activations += self.prev_pair_norm(prev_pair.astype(mstype.float32))

if self.max_relative_feature:
offset = residue_index[:, None] - residue_index[None, :]
rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature))
pair_activations += self.pair_activations(rel_pos)

template_pair_representation = 0
if self.template_enable:
template_pair_representation = self.template_embedding(pair_activations, template_aatype,
template_all_atom_masks, template_all_atom_positions,
template_mask, template_pseudo_beta_mask,
template_pseudo_beta, mask_2d)
pair_activations += template_pair_representation

msa_1hot = self.extra_msa_one_hot(extra_msa)
extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]),
axis=-1)
extra_msa_activations = self.extra_msa_activations(extra_msa_feat)
msa_act = extra_msa_activations
pair_act = pair_activations

msa_act = msa_act.astype(mstype.float32)
pair_act = pair_act.astype(mstype.float32)
extra_msa_mask = extra_msa_mask.astype(mstype.float32)
mask_2d = mask_2d.astype(mstype.float32)

self.idx_extra_msa_stack = 0
idx_extra_msa_stack_int = 0
while idx_extra_msa_stack_int < self.extra_msa_stack_num_block:
msa_act, pair_act = \
self.extra_msa_stack_iteration(msa_act, pair_act, extra_msa_mask, mask_2d, self.idx_extra_msa_stack)
self.idx_extra_msa_stack += 1
idx_extra_msa_stack_int += 1
msa_act = F.depend(msa_act, self.idx_extra_msa_stack)
pair_act = F.depend(pair_act, self.idx_extra_msa_stack)

msa_activations2 = None
if self.template_enabled and self.template_embed_torsion_angles:
num_templ, num_res = template_aatype.shape
aatype_one_hot = self.template_aatype_one_hot(template_aatype)
torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles(
template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices,
self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1)
template_features = mnp.concatenate([aatype_one_hot,
mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]),
mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]),
torsion_angles_mask], axis=-1)
template_activations = self.template_single_embedding(template_features)
template_activations = self.relu(template_activations.astype(mstype.float32))
template_activations = self.template_projection(template_activations)
msa_activations2 = mnp.concatenate([msa_activations1, template_activations], axis=0)
torsion_angle_mask = torsion_angles_mask[:, :, 2]
torsion_angle_mask = torsion_angle_mask.astype(msa_mask.dtype)
msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0)

msa_activations2 = msa_activations2.astype(mstype.float16)
pair_activations = pair_act.astype(mstype.float16)
msa_mask = msa_mask.astype(mstype.float16)
mask_2d = mask_2d.astype(mstype.float16)
# return msa_activations2, pair_activations, msa_mask, mask_2d
self.idx_evoformer_block = self.idx_evoformer_block * 0
while self.idx_evoformer_block < self.evoformer_num_block:
msa_activations2, pair_activations = \
self.evoformer_iteration(msa_activations2,
pair_activations,
msa_mask,
mask_2d,
self.idx_evoformer_block)
self.idx_evoformer_block += 1

single_activations = self.single_activations(msa_activations2[0])
msa_first_row = msa_activations2[0]

# return single_activations, msa, msa_first_row
final_atom_positions, final_atom_mask, rp_structure_module = \
self.structure_module(single_activations,
pair_activations,
seq_mask,
aatype,
atom14_atom_exists,
atom37_atom_exists)

predicted_lddt_logits = self.module_lddt(rp_structure_module)

prev_pos = final_atom_positions.astype(mstype.float16)
prev_msa_first_row = msa_first_row.astype(mstype.float16)
prev_pair = pair_activations.astype(mstype.float16)

final_atom_positions = final_atom_positions.astype(mstype.float16)
final_atom_mask = final_atom_mask.astype(mstype.float16)

return final_atom_positions, final_atom_mask, predicted_lddt_logits, prev_pos, prev_msa_first_row, prev_pair

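# --- Editorial sketch (illustration, not part of the repository code) ---
# The three `prev_*` outputs exist so a caller can recycle: run the network
# repeatedly, feeding each pass's final positions, first MSA row and pair
# activations back in as the next pass's inputs. A toy driver with a stub in
# place of the model (channel sizes 256/128 match the LayerNorms above; the
# loop count is an assumption):
import numpy as np

def fold_step(prev_pos, prev_msa_first_row, prev_pair):
    # Stand-in for AlphaFold.construct; the real call also returns the atom
    # mask and pLDDT logits. Here we only thread the recycling state.
    return prev_pos + 1., prev_msa_first_row, prev_pair

num_res = 8
prev_pos = np.zeros((num_res, 37, 3), np.float32)          # atom37 positions
prev_msa_first_row = np.zeros((num_res, 256), np.float32)  # msa_channel
prev_pair = np.zeros((num_res, num_res, 128), np.float32)  # pair_channel
for _ in range(4):  # e.g. 3 recycles = 4 forward passes
    prev_pos, prev_msa_first_row, prev_pair = fold_step(
        prev_pos, prev_msa_first_row, prev_pair)
print(prev_pos[0, 0])  # [4. 4. 4.]
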
+ 443  - 0   reproduce/AlphaFold2-Chinese/module/structure_module.py

@@ -0,0 +1,443 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""structure module"""
import numpy as np
import mindspore.ops as ops
import mindspore.common.dtype as mstype
import mindspore.numpy as mnp
from mindspore import Parameter, ms_function, Tensor
from mindspore import nn
from commons import residue_constants
from commons.utils import generate_new_affine, to_tensor, from_tensor, vecs_to_tensor, atom14_to_atom37, \
get_exp_atom_pos, get_exp_frames, pre_compose, scale_translation, to_tensor_new, l2_normalize, \
torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, apply_to_point, _invert_point


class InvariantPointAttention(nn.Cell):
"""Invariant Point attention module."""

def __init__(self, config, global_config, pair_dim):
"""Initialize.

Args:
config: Structure Module Config
global_config: Global Config of Model.
pair_dim: pair representation dimension.
"""

super().__init__()

self._dist_epsilon = 1e-8
self.config = config
self.num_head = config.num_head
self.num_scalar_qk = config.num_scalar_qk
self.num_scalar_v = config.num_scalar_v
self.num_point_v = config.num_point_v
self.num_point_qk = config.num_point_qk
self.num_channel = config.num_channel
self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 +\
self.num_head * pair_dim

self.global_config = global_config
self.q_scalar = nn.Dense(config.num_channel, self.num_head*self.num_scalar_qk).to_float(mstype.float16)
self.kv_scalar = nn.Dense(config.num_channel, self.num_head*(self.num_scalar_qk + self.num_scalar_v)
).to_float(mstype.float16)
self.q_point_local = nn.Dense(config.num_channel, self.num_head * 3 * self.num_point_qk
).to_float(mstype.float16)
self.kv_point_local = nn.Dense(config.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v)
).to_float(mstype.float16)
self.soft_max = nn.Softmax()
self.soft_plus = ops.Softplus()
self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights")
self.attention_2d = nn.Dense(pair_dim, self.num_head).to_float(mstype.float16)
self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros'
).to_float(mstype.float16)
self.scalar_weights = np.sqrt(1.0 / (3 * 16))
self.point_weights = np.sqrt(1.0 / (3 * 18))
self.attention_2d_weights = np.sqrt(1.0 / 3)

def construct(self, inputs_1d, inputs_2d, mask, rotation, translation):
"""Compute geometry-aware attention.

Args:
inputs_1d: (N, C) 1D input embedding that is the basis for the
scalar queries.
inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
mask: (N, 1) mask to indicate which elements of inputs_1d participate
in the attention.
rotation: describe the orientation of every element in inputs_1d
translation: describe the position of every element in inputs_1d

Returns:
Transformation of the input embedding.
"""

num_residues, _ = inputs_1d.shape

# Improve readability by removing a large number of 'self's.
num_head = self.num_head
num_scalar_qk = self.num_scalar_qk
num_point_qk = self.num_point_qk
num_scalar_v = self.num_scalar_v
num_point_v = self.num_point_v

# Construct scalar queries of shape:
# [num_query_residues, num_head, num_points]
q_scalar = self.q_scalar(inputs_1d)
q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])

# Construct scalar keys/values of shape:
# [num_target_residues, num_head, num_points]
kv_scalar = self.kv_scalar(inputs_1d)
kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1)

# Construct query points of shape:
# [num_residues, num_head, num_point_qk]
# First construct query points in local frame.
q_point_local = self.q_point_local(inputs_1d)
q_point_local = mnp.stack(mnp.split(q_point_local, 3, axis=-1), axis=0)

# Project query points into global frame.
q_point_global = apply_to_point(rotation, translation, q_point_local)

# Reshape query point for later use.
q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk))
q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk))
q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk))

# Construct key and value points.
# Key points have shape [num_residues, num_head, num_point_qk]
# Value points have shape [num_residues, num_head, num_point_v]

# Construct key and value points in local frame.
kv_point_local = self.kv_point_local(inputs_1d)

kv_point_local = mnp.split(kv_point_local, 3, axis=-1)
# Project key and value points into global frame.
kv_point_global = apply_to_point(rotation, translation, kv_point_local)

kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v)))
kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v)))
kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v)))

# Split key and value points.
k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk,], axis=-1)
k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk,], axis=-1)
k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk,], axis=-1)

trainable_point_weights = self.soft_plus(self.trainable_point_weights)
point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1)

v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)]
q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)]
k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)]

dist2 = mnp.square(q_point[0][:, :, None, :] - k_point[0][:, None, :, :]) + \
mnp.square(q_point[1][:, :, None, :] - k_point[1][:, None, :, :]) + \
mnp.square(q_point[2][:, :, None, :] - k_point[2][:, None, :, :])

attn_qk_point = -0.5 * mnp.sum(
point_weights[:, None, None, :] * dist2, axis=-1)

v = mnp.swapaxes(v_scalar, -2, -3)
q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3)
k = mnp.swapaxes(k_scalar, -2, -3)
attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1))
attn_logits = attn_qk_scalar + attn_qk_point

attention_2d = self.attention_2d(inputs_2d)
attention_2d = mnp.transpose(attention_2d, [2, 0, 1])
attention_2d = self.attention_2d_weights * attention_2d

attn_logits += attention_2d

mask_2d = mask * mnp.swapaxes(mask, -1, -2)
attn_logits -= 1e5 * (1. - mask_2d)

# [num_head, num_query_residues, num_target_residues]
attn = self.soft_max(attn_logits)

# [num_head, num_query_residues, num_head * num_scalar_v]
result_scalar = ops.matmul(attn, v)

result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3),
mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3),
mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3)
]

result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]),
mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]),
mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])]
result_scalar = mnp.swapaxes(result_scalar, -2, -3)

result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v])

result_point_local = _invert_point(result_point_global, rotation, translation)

output_feature1 = result_scalar
output_feature20 = result_point_local[0]
output_feature21 = result_point_local[1]
output_feature22 = result_point_local[2]

output_feature3 = mnp.sqrt(self._dist_epsilon +
mnp.square(result_point_local[0]) +
mnp.square(result_point_local[1]) +
mnp.square(result_point_local[2]))

result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d)
num_out = num_head * result_attention_over_2d.shape[-1]
output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out])

final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21,
output_feature22, output_feature3, output_feature4], axis=-1)
final_result = self.output_projection(final_act)
return final_result

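# --- Editorial sketch (illustration, not part of the repository code) ---
# The geometric half of the attention logits above: query and key points are
# predicted in each residue's local frame, mapped to the global frame, and
# scored by negative squared distance with softplus-learned per-head
# weights. NumPy sketch with hypothetical shapes: points (H, N, P, 3).
import numpy as np

rng = np.random.default_rng(7)
h, n, p = 2, 5, 4
q_point = rng.normal(size=(h, n, p, 3))   # global-frame query points
k_point = rng.normal(size=(h, n, p, 3))   # global-frame key points
point_weights = np.ones((h, p))           # stands in for learned weights
dist2 = np.square(q_point[:, :, None] - k_point[:, None]).sum(-1)  # (H,N,N,P)
attn_qk_point = -0.5 * (point_weights[:, None, None, :] * dist2).sum(-1)
print(attn_qk_point.shape)  # (2, 5, 5): close point clouds score higher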

class MultiRigidSidechain(nn.Cell):
"""Class to make side chain atoms."""

def __init__(self, config, global_config, single_repr_dim):
super().__init__()
self.config = config
self.global_config = global_config
self.input_projection = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.input_projection_1 = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.relu = nn.ReLU()
self.resblock1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
self.resblock2 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros').to_float(mstype.float16)
self.resblock1_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.resblock2_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros'
).to_float(mstype.float16)
self.unnormalized_angles = nn.Dense(config.num_channel, 14, weight_init='normal').to_float(mstype.float16)
self.print = ops.Print()
self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)

def construct(self, rotation, translation, act, initial_act, aatype):
"""Predict side chains using rotation and translation representations.

Args:
rotation: The rotation matrices.
translation: The translation vectors.
act: updated single-representation activations from the structure module
initial_act: initial single-representation activations (the structure module input)
aatype: amino acid type indices

Returns:
angles, positions and new frames
"""

act1 = self.input_projection(self.relu(act.astype(mstype.float32)))
init_act1 = self.input_projection_1(self.relu(initial_act.astype(mstype.float32)))
# Sum the activation list (equivalent to concat then Linear); see the illustrative sketch after this class.
act = act1 + init_act1

# Mapping with residual blocks; the loop over config.num_residual_block
# is unrolled below.
# resblock1
old_act = act
act = self.resblock1(self.relu(act.astype(mstype.float32)))
act = self.resblock2(self.relu(act.astype(mstype.float32)))
act += old_act
# resblock2
old_act = act
act = self.resblock1_1(self.relu(act.astype(mstype.float32)))
act = self.resblock2_1(self.relu(act.astype(mstype.float32)))
act += old_act

# Map activations to torsion angles. Shape: (num_res, 14).
num_res = act.shape[0]
unnormalized_angles = self.unnormalized_angles(self.relu(act.astype(mstype.float32)))

unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])

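# Each of the 7 torsion angles is predicted as an unnormalized (sin, cos)
# pair; l2 normalization projects each pair onto the unit circle.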
angles = l2_normalize(unnormalized_angles, axis=-1)

backb_to_global = [rotation[0][0], rotation[0][1], rotation[0][2],
rotation[1][0], rotation[1][1], rotation[1][2],
rotation[2][0], rotation[2][1], rotation[2][2],
translation[0], translation[1], translation[2]]

all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
self.restype_rigid_group_default_frame)

pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
self.restype_atom14_to_rigid_group,
self.restype_atom14_rigid_group_positions,
self.restype_atom14_mask)

atom_pos = pred_positions
frames = all_frames_to_global

return angles, unnormalized_angles, atom_pos, frames
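
# Added commentary, not original code: a minimal NumPy check of the
# "sum equals concat then Linear" equivalence noted in the comment above.
def _demo_sum_equals_concat_linear():
    import numpy as np
    a, b = np.random.randn(4, 8), np.random.randn(4, 8)
    w1, w2 = np.random.randn(8, 16), np.random.randn(8, 16)
    summed = a @ w1 + b @ w2
    # Concatenating the inputs and stacking the weights gives the same result.
    concat = np.concatenate([a, b], axis=-1) @ np.concatenate([w1, w2], axis=0)
    assert np.allclose(summed, concat)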


class FoldIteration(nn.Cell):
"""A single iteration of the main structure module loop."""

def __init__(self, config, global_config, pair_dim, single_repr_dim):
super().__init__()
self.config = config
self.global_config = global_config
self.drop_out = nn.Dropout(keep_prob=0.9)
self.attention_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)
self.transition_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)
self.transition = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.transition_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.transition_2 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
).to_float(mstype.float16)
self.relu = nn.ReLU()
self.affine_update = nn.Dense(config.num_channel, 6, weight_init='zeros').to_float(mstype.float16)
self.attention_module = InvariantPointAttention(self.config, self.global_config, pair_dim)
self.mu_side_chain = MultiRigidSidechain(config.sidechain, global_config, single_repr_dim)
self.print = ops.Print()

def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype):
'''construct'''
# Attention
attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation)
act += attn
act = self.drop_out(act)
act = self.attention_layer_norm(act.astype(mstype.float32))
# Transition
input_act = act
act = self.transition(act)
act = self.relu(act.astype(mstype.float32))
act = self.transition_1(act)
act = self.relu(act.astype(mstype.float32))
act = self.transition_2(act)

act += input_act
act = self.drop_out(act)
act = self.transition_layer_norm(act.astype(mstype.float32))

# This block corresponds to
# Jumper et al. (2021) Alg. 23 "Backbone update"
# Affine update
affine_update = self.affine_update(act)

quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update)
_, rotation1, translation1 = scale_translation(quaternion, translation, rotation, 10.0)

angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames =\
self.mu_side_chain(rotation1, translation1, act, initial_act, aatype)

affine_output = to_tensor_new(quaternion, translation)

return act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
atom_pos, frames
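
# Added commentary, a hedged sketch of Jumper et al. (2021) Alg. 23 "Backbone
# update": the 6-channel affine_update is conventionally split into a quaternion
# vector part (3) and a translation (3), with the full quaternion (1, b, c, d)
# renormalized. pre_compose in this repo may differ in detail.
def _demo_backbone_update_split():
    import numpy as np
    num_res = 5
    affine_update = np.random.randn(num_res, 6)
    quat_vec, trans = affine_update[:, :3], affine_update[:, 3:]
    quat = np.concatenate([np.ones((num_res, 1)), quat_vec], axis=-1)
    quat /= np.linalg.norm(quat, axis=-1, keepdims=True)  # unit quaternion
    assert quat.shape == (num_res, 4) and trans.shape == (num_res, 3)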


class StructureModule(nn.Cell):
"""StructureModule as a network head."""

def __init__(self, config, single_repr_dim, pair_dim, global_config=None, compute_loss=True):
super(StructureModule, self).__init__()
self.config = config
self.global_config = global_config
self.compute_loss = compute_loss
self.fold_iteration = FoldIteration(self.config, global_config, pair_dim, single_repr_dim)
self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel).to_float(mstype.float16)
self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
self.num_layer = config.num_layer
self.indice0 = Tensor(
np.arange(global_config.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))

@ms_function
def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None):
'''construct'''
sequence_mask = seq_mask[:, None]
act = self.single_layer_norm(single.astype(mstype.float32))
initial_act = act
act = self.initial_projection(act)
quaternion, rotation, translation = generate_new_affine(sequence_mask)
aff_to_tensor = to_tensor(quaternion, mnp.transpose(translation))
act_2d = self.pair_layer_norm(pair.astype(mstype.float32))
# fold iterations
quaternion, rotation, translation = from_tensor(aff_to_tensor)

act_new, atom_pos, _, _, _, _ =\
self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]

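# Map the compact 14-atom representation back to the standard 37-atom
# layout expected by downstream PDB generation.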
atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
residx_atom37_to_atom14,
atom37_atom_exists,
self.indice0)

final_atom_positions = atom37_pred_positions
final_atom_mask = atom37_atom_exists
rp_structure_module = act_new
return final_atom_positions, final_atom_mask, rp_structure_module

def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
aatype):
'''iteration operation'''
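# The structure module loop is unrolled: each FoldIteration refines the
# backbone frames, and per-iteration outputs are stacked on a new leading axis.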
affine_init = ()
angles_sin_cos_init = ()
um_angles_sin_cos_init = ()
atom_pos = ()
frames = ()

for _ in range(self.num_layer):
act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
atom_pos, frames = \
self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
affine_init = affine_init + (affine_output[None, ...],)
angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
atom_pos = get_exp_atom_pos(atom_pos)
frames = get_exp_frames(frames)
affine_output_new = mnp.concatenate(affine_init, axis=0)
angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)

return act, atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames


class PredictedLDDTHead(nn.Cell):
"""Head to predict the per-residue LDDT to be used as a confidence measure."""
def __init__(self, config, global_config, seq_channel):
super().__init__()
self.config = config
self.global_config = global_config
self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
self.act_0 = nn.Dense(seq_channel, self.config.num_channels, weight_init='zeros'
).to_float(mstype.float16)
self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels, weight_init='zeros'
).to_float(mstype.float16)
self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
).to_float(mstype.float16)
self.relu = nn.ReLU()

def construct(self, rp_structure_module):
"""Builds ExperimentallyResolvedHead module."""
act = rp_structure_module
act = self.input_layer_norm(act.astype(mstype.float32))
act = self.act_0(act)
act = self.relu(act.astype(mstype.float32))
act = self.act_1(act)
act = self.relu(act.astype(mstype.float32))
logits = self.logits(act)
return logits
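
# Added commentary, a hedged sketch of how per-residue LDDT bin logits are
# commonly reduced to a scalar confidence: softmax over bins, then an
# expectation over bin centers. The compute_confidence actually used during
# serving lives in commons/utils.py and may differ in detail.
def _demo_plddt_confidence():
    import numpy as np
    num_res, num_bins = 8, 50
    logits = np.random.randn(num_res, num_bins)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)  # softmax over bins
    bin_centers = (np.arange(num_bins) + 0.5) / num_bins * 100.0
    per_res_plddt = probs @ bin_centers  # [num_res], values in [0, 100]
    confidence = float(per_res_plddt.mean())
    assert 0.0 <= confidence <= 100.0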

+ 5
- 0
reproduce/AlphaFold2-Chinese/requirements.txt

@@ -0,0 +1,5 @@
absl-py==0.13.0
biopython==1.79
ml-collections==0.1.0
scipy==1.7.0
numpy==1.19.5

+ 34
- 0
reproduce/AlphaFold2-Chinese/serving/fold_service/config.py

@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config for serving mode"""

import ml_collections

config = ml_collections.ConfigDict({
"seq_length": 256,
"device_id": 0,
"port": 5500,
"ckpt_path": "/CHECKPOINT_PATH",
"input_fasta_path": "INPUT_FASTA_PATH",
"msa_result_path": "MSA_RESULT_PATH",
"database_dir": "DATABASE_DIR",
"database_envdb_dir": "DATABASE_ENVDB_DIR",
"hhsearch_binary_path": "HHSEARCH_BINARY_PATH",
"pdb70_database_path": 'PDB&)_DATABASE_PATH',
"template_mmcif_dir": 'TEMPLATE_MMCIF_DIR',
"max_template_date": "MAX_TEMPLATE_DATE",
"kalign_binary_path": 'KALIGN_BINARY_PATH',
"obsolete_pdbs_path": 'OBSOLETE_PDBS_PATH',
})
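
# A minimal usage sketch (added commentary; paths below are hypothetical): the
# placeholder values above must be replaced before serving, either by editing
# this file or by updating the ConfigDict programmatically, e.g.:
#
#   from fold_service.config import config
#   config.update({"ckpt_path": "/path/to/alphafold.ckpt",
#                  "input_fasta_path": "/path/to/fasta_dir"})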

+ 104
- 0
reproduce/AlphaFold2-Chinese/serving/fold_service/servable_config.py

@@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""serving config for mindspore serving"""

import time
import os
import json
import numpy as np

from mindspore_serving.server import register
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint

from data.feature.feature_extraction import process_features
from data.tools.data_process import data_process
from commons.utils import compute_confidence
from commons.generate_pdb import to_pdb, from_prediction
from model import AlphaFold
from config import config, global_config
from fold_service.config import config as serving_config

context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
variable_memory_max_size="31GB",
device_id=serving_config.device_id,
save_graphs=False)
model_name = "model_1"
model_config = config.model_config(model_name)
num_recycle = model_config.model.num_recycle
global_config = global_config.global_config(serving_config.seq_length)
extra_msa_length = global_config.extra_msa_length

fold_net = AlphaFold(model_config, global_config)
load_checkpoint(serving_config.ckpt_path, fold_net)

def fold_model(input_fasta_path):
"""defining fold model"""

seq_files = os.listdir(input_fasta_path)
for seq_file in seq_files:
print(seq_file)
t1 = time.time()
seq_name = seq_file.split('.')[0]

input_features = data_process(seq_name, serving_config)
tensors, aatype, residue_index, ori_res_length = process_features(
raw_features=input_features, config=model_config, global_config=global_config)
prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))

t2 = time.time()
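# Recycling: run the network num_recycle + 1 times, feeding the previous
# atom positions, first MSA row and pair activations back in as inputs.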
for i in range(num_recycle+1):
tensors_i = [tensor[i] for tensor in tensors]
input_feats = [Tensor(tensor) for tensor in tensors_i]
final_atom_positions, final_atom_mask, predicted_lddt_logits,\
prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
prev_pos,
prev_msa_first_row,
prev_pair)

t3 = time.time()

final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]

confidence = compute_confidence(predicted_lddt_logits)
unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
pdb_file = to_pdb(unrelaxed_protein)

seq_length = aatype.shape[-1]
os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)

with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
f.write(pdb_file)
t4 = time.time()
timings = {"pre_process_time": round(t2 - t1, 2),
"model_time": round(t3 - t2, 2),
"pos_process_time": round(t4 - t3, 2),
"all_time": round(t4 - t1, 2),
"confidence": confidence}
print(timings)
with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
f.write(json.dumps(timings))
return True

@register.register_method(output_names=["res"])
def folding(input_fasta_path):
res = register.add_stage(fold_model, input_fasta_path, outputs_count=1)
return res

+ 32
- 0
reproduce/AlphaFold2-Chinese/serving/serving_client.py

@@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""client script for serving mode"""

import time

from mindspore_serving.client import Client
from fold_service.config import config as serving_config

if __name__ == "__main__":

client = Client("127.0.0.1:" + str(serving_config.port), "fold_service", "folding")
instances = [{"input_fasta_path": serving_config.input_fasta_path}]

print("inferring...")
t1 = time.time()
result = client.infer(instances)
t2 = time.time()
print("finish inferring! Time costed:", t2 - t1)
print(result)

+ 31
- 0
reproduce/AlphaFold2-Chinese/serving/serving_server.py

@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""server script for serving mode"""

import os
from mindspore_serving import server
from fold_service.config import config as serving_config

def start():
servable_dir = os.path.dirname(os.path.realpath(__file__))

servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="fold_service",
device_ids=serving_config.device_id)
server.start_servables(servable_configs=servable_config)

server.start_grpc_server(address="127.0.0.1:" + str(serving_config.port))

if __name__ == "__main__":
start()

+ 14
- 0
reproduce/AlphaFold2-Chinese/tests/st/__init__.py

@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 15
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init"""

+ 85
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_activation.py

@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test Activations """
import pytest
import numpy as np

from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindelec.architecture import get_activation

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.srelu = get_activation("srelu")

def construct(self, x):
return self.srelu(x)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_srelu():
"""test srelu activation"""
net = Net()
input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32))
output = net(input_tensor)
print(input_tensor.asnumpy())
print(output.asnumpy())


class Net1(nn.Cell):
"""net"""
def __init__(self):
super(Net1, self).__init__()
self.sin = get_activation("sin")

def construct(self, x):
return self.sin(x)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_sin():
"""test sin activation"""
net = Net1()
input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32))
output = net(input_tensor)
print(input_tensor.asnumpy())
print(output.asnumpy())

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_activation_type_error():
with pytest.raises(TypeError):
get_activation(1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_get_activation():
activation = get_activation("softshrink")
assert isinstance(activation, nn.Cell)

+ 224
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_block.py

@@ -0,0 +1,224 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test block """
import pytest
import numpy as np

from mindspore import nn, context
from mindspore import Tensor, Parameter

from mindelec.architecture import LinearBlock, ResBlock
from mindelec.architecture import InputScaleNet, FCSequential, MultiScaleFCCell

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
""" Net definition """
def __init__(self,
input_channels,
output_channels,
weight='normal',
bias='zeros',
has_bias=True):
super(Net, self).__init__()
self.fc = LinearBlock(input_channels, output_channels, weight, bias, has_bias)

def construct(self, input_x):
return self.fc(input_x)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_linear():
"""test linear block"""
weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
net = Net(64, 8, weight=weight, bias=bias)
input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_linear_nobias():
"""test linear block with no bias"""
weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
net = Net(64, 8, weight=weight, has_bias=False)
input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


class Net1(nn.Cell):
""" Net definition """
def __init__(self,
input_channels,
output_channels,
weight='normal',
bias='zeros',
has_bias=True,
activation=None):
super(Net1, self).__init__()
self.fc = ResBlock(input_channels, output_channels, weight, bias, has_bias, activation)

def construct(self, input_x):
return self.fc(input_x)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_res():
"""test res block"""
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
net = Net1(8, 8, weight=weight, bias=bias)
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_res_nobias():
"""test res block with no bias"""
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
net = Net1(8, 8, weight=weight, has_bias=False)
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_res_activation():
"""test res block with activation"""
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
net = Net1(8, 8, weight=weight, bias=bias, activation='sin')
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_res_channel_error():
with pytest.raises(ValueError):
ResBlock(3, 6)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_input_scale():
"""test input scale cell"""
inputs = np.random.uniform(size=(16, 3)) + 3.0
inputs = Tensor(inputs.astype(np.float32))
input_scale = [1.0, 2.0, 4.0]
input_center = [3.5, 3.5, 3.5]
net = InputScaleNet(input_scale, input_center)
output = net(inputs).asnumpy()

assert np.all(output[:, 0] <= 0.5) and np.all(output[:, 0] >= -0.5)
assert np.all(output[:, 1] <= 1.0) and np.all(output[:, 1] >= -1.0)
assert np.all(output[:, 2] <= 2.0) and np.all(output[:, 2] >= -2.0)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fc_sequential():
"""test fc sequential cell"""
inputs = np.ones((16, 3))
inputs = Tensor(inputs.astype(np.float32))
net = FCSequential(3, 3, 5, 32, weight_init="ones", bias_init="zeros")
output = net(inputs).asnumpy()
target = np.ones((16, 3)) * -31.998459
assert np.allclose(output, target, rtol=5e-2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mulscale_without_latent():
"""test multi-scale net without latent vector"""
inputs = np.ones((16, 3)) + 3.0
inputs = Tensor(inputs.astype(np.float32))
input_scale = [1.0, 2.0, 4.0]
input_center = [3.5, 3.5, 3.5]
net = MultiScaleFCCell(3, 3, 5, 32,
weight_init="ones", bias_init="zeros",
input_scale=input_scale, input_center=input_center)
output = net(inputs).asnumpy()
target = np.ones((16, 3)) * -61.669254
assert np.allclose(output, target, rtol=5e-2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mulscale_with_latent():
"""test multi-scale net with latent vector and input scale"""
inputs = np.ones((64, 3)) + 3.0
inputs = Tensor(inputs.astype(np.float32))
num_scenarios = 4
latent_size = 16
latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32)
latent_vector = Parameter(Tensor(latent_init), requires_grad=True)
input_scale = [1.0, 2.0, 4.0]
input_center = [3.5, 3.5, 3.5]
net = MultiScaleFCCell(3, 3, 5, 32,
weight_init="ones", bias_init="zeros",
input_scale=input_scale, input_center=input_center, latent_vector=latent_vector)
output = net(inputs).asnumpy()
target = np.ones((64, 3)) * -57.8849
assert np.allclose(output, target, rtol=5e-2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mulscale_with_latent_noscale():
"""test multi-scale net with latent vector"""
inputs = np.ones((64, 3))
inputs = Tensor(inputs.astype(np.float32))
num_scenarios = 4
latent_size = 16
latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32)
latent_vector = Parameter(Tensor(latent_init), requires_grad=True)
net = MultiScaleFCCell(3, 3, 5, 32,
weight_init="ones", bias_init="zeros", latent_vector=latent_vector)
output = net(inputs).asnumpy()
target = np.ones((64, 3)) * -105.62799
assert np.allclose(output, target, rtol=5e-2)

+ 43
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_mlt.py

@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test block """
import pytest
import numpy as np

from mindspore import context
from mindspore import Tensor
from mindelec.architecture import MTLWeightedLossCell

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mtl_weighted_loss():
net = MTLWeightedLossCell(num_losses=2)
input_data = Tensor(np.array([1.0, 1.0]).astype(np.float32))
output = net(input_data)
print(output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mlt_num_losses_error():
with pytest.raises(TypeError):
MTLWeightedLossCell('a')

+ 92
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_lr_scheduler.py

@@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test metrics """
import pytest
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindelec.common import LearningRate, get_poly_lr

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_learning_rate():
"""test LearningRate"""
learning_rate = LearningRate(0.1, 0.001, 0, 10, 0.5)
res = learning_rate(Tensor(10000, mstype.int32))
assert res == 0.001


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_learning_rate_power_value_error():
with pytest.raises(ValueError):
LearningRate(0.1, 0.001, 0, 10, -0.5)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_learning_rate_warmup_steps_type_error():
"""test TypeError cases"""
with pytest.raises(TypeError):
LearningRate(0.1, 0.001, 0.1, 10, 0.5)
with pytest.raises(ValueError):
LearningRate(0.0, 0.001, 0, 10, 0.5)
with pytest.raises(ValueError):
LearningRate(0.1, -0.001, 0, 10, 0.5)
with pytest.raises(ValueError):
LearningRate(0.1, 0.001, 0, 0, 0.5)
with pytest.raises(ValueError):
LearningRate(0.1, 0.001, -1, 10, 0.5)
with pytest.raises(ValueError):
LearningRate(0.1, 0.001, 0, -10, 0.5)
with pytest.raises(ValueError):
LearningRate(0.1, 0.001, 1, -10, 0.5)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_get_poly_lr():
"""test get_poly_lr"""
res = get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, 0.5)
assert res.shape == (9900,)
with pytest.raises(ValueError):
get_poly_lr(-1, 0.001, 0.1, 0.0001, 1000, 10000, 0.5)
with pytest.raises(ValueError):
get_poly_lr(100, 0.0, 0.1, 0.0001, 1000, 10000, 0.5)
with pytest.raises(ValueError):
get_poly_lr(100, 0.001, 0.1, 0.0, 1000, 10000, 0.5)
with pytest.raises(ValueError):
get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 0, 0.5)
with pytest.raises(ValueError):
get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, -0.5)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_get_poly_lr1():
"""test get_poly_lr"""
res = get_poly_lr(100, 0.001, 0.1, 0.0001, 0, 10000, 0.5)
assert res.shape == (9900,)

+ 39
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_metrics.py

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test metrics """
import pytest
import numpy as np

import mindspore
from mindspore import Tensor
from mindspore import context
from mindelec.common import L2

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_l2():
"""test l2"""
x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32)
metric = L2()
metric.clear()
metric.update(x, y)
result = metric.eval()
assert result == 0.09543302997807275

+ 82
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/config.py

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
sampling and dataset settings
"""

from easydict import EasyDict as edict

ds_config = edict({
'train': edict({
'batch_size': 100,
'shuffle': True,
'drop_remainder': True,
}),
'eval': edict({
'batch_size': 100,
'shuffle': False,
'drop_remainder': False,
}),
})

src_sampling_config = edict({
'domain': edict({
'random_sampling': True,
'size': 400,
'sampler': 'uniform'
}),
'IC': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform',
}),
'time': edict({
'random_sampling': True,
'size': 10,
'sampler': 'uniform',
}),
})

no_src_sampling_config = edict({
'domain': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform'
}),
'IC': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform',
}),
'time': edict({
'random_sampling': True,
'size': 10,
'sampler': 'uniform',
}),
})

bc_sampling_config = edict({
'BC': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform',
'with_normal': True
}),
'time': edict({
'random_sampling': True,
'size': 10,
'sampler': 'uniform',
}),
})

+ 110
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_boundary.py

@@ -0,0 +1,110 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: geometry with time cases"""

import copy
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain
from mindelec.data import Boundary, BoundaryBC, BoundaryIC

reset_geom_time_config = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': True,
'size': 10,
'with_normal': True,
}),
'IC': edict({
'random_sampling': True,
'size': 10,
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

reset_geom_time_config2 = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': False,
'size': 10,
'with_normal': False,
}),
'IC': edict({
'random_sampling': False,
'size': 10,
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

def check_rect_with_time_set_config(config):
"""check_rect_with_time_set_config"""
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
time = TimeDomain("time", 0.0, 1.0)
rect_with_time = GeometryWithTime(rect, time)

with pytest.raises(TypeError):
Boundary(0)
with pytest.raises(ValueError):
Boundary(rect_with_time)

config1 = copy.deepcopy(config)
config1.pop('BC')
rect_with_time.set_sampling_config(create_config_from_edict(config1))
with pytest.raises(ValueError):
BoundaryBC(rect_with_time)

config2 = copy.deepcopy(config)
config2.pop('IC')
rect_with_time.set_sampling_config(create_config_from_edict(config2))
with pytest.raises(ValueError):
BoundaryIC(rect_with_time)

rect_with_time.set_sampling_config(create_config_from_edict(config))

bc = BoundaryBC(rect_with_time)
for i in range(20):
print(bc[i])

config3 = copy.deepcopy(config)
if not config3.IC.random_sampling:
config3.IC.size = [4, 4]
rect_with_time.set_sampling_config(create_config_from_edict(config3))
ic = BoundaryIC(rect_with_time)
for i in range(20):
print(ic[i])


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_set_config():
"""test_check_rect_with_time_set_config"""
check_rect_with_time_set_config(reset_geom_time_config)
check_rect_with_time_set_config(reset_geom_time_config2)

+ 195
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_data_base.py

@@ -0,0 +1,195 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test data_base."""
import pytest
import numpy as np
from mindelec.data import ExistedDataConfig
from mindelec.data.data_base import Data


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_name_error():
with pytest.raises(TypeError):
ExistedDataConfig(1, "/home/a.npy", "data")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_data_dir_error():
with pytest.raises(TypeError):
ExistedDataConfig("a", 1, "data")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_column_list_error():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, 1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_constraint_type_error():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", contraint_type="a")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_data_format_typeerror():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", 1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_data_format_valueerror():
with pytest.raises(ValueError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", "csv")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_data_constraint_type_typeerror():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", constraint_type='1')

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_data_constraint_type_typeerror1():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", constraint_type='test')

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_random_merge_error():
with pytest.raises(TypeError):
input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)
ExistedDataConfig("a", input_path, "data", random_merge=1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_name_type_error():
with pytest.raises(TypeError):
Data(1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_columns_list_type_error():
with pytest.raises(TypeError):
Data(columns_list=1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_constraint_type_type_error():
with pytest.raises(TypeError):
Data(constraint_type=1)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_constraint_type_type_error1():
with pytest.raises(TypeError):
Data(constraint_type="labe")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_set_constraint_type_type_error():
with pytest.raises(TypeError):
data = Data()
data.set_constraint_type("test")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_create_dataset_nie_error():
with pytest.raises(NotImplementedError):
data = Data()
data.create_dataset()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_get_item_nie_error():
with pytest.raises(NotImplementedError):
data = Data()
print(data[0])


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_data_len_nie_error():
with pytest.raises(NotImplementedError):
data = Data()
len(data)

+ 112
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_dataset.py

@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset module"""
import pytest
from easydict import EasyDict as edict
from mindelec.data import Dataset
from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
from mindelec.data import BoundaryBC, BoundaryIC
from config import ds_config, src_sampling_config, no_src_sampling_config, bc_sampling_config

ic_bc_config = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': True,
'size': 10,
'with_normal': True,
}),
'IC': edict({
'random_sampling': True,
'size': 10,
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dataset_allnone():
with pytest.raises(ValueError):
Dataset()

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dataset():
"""test dataset"""
disk = Disk("src", (0.0, 0.0), 0.2)
rectangle = Rectangle("rect", (-1, -1), (1, 1))
diff = rectangle - disk
time = TimeDomain("time", 0.0, 1.0)

# check datalist
rect_with_time = GeometryWithTime(rectangle, time)
rect_with_time.set_sampling_config(create_config_from_edict(ic_bc_config))
bc = BoundaryBC(rect_with_time)
ic = BoundaryIC(rect_with_time)
dataset = Dataset(dataset_list=bc)

dataset.set_constraint_type("Equation")

c_type1 = {bc: "Equation", ic: "Equation"}
with pytest.raises(ValueError):
dataset.set_constraint_type(c_type1)

no_src_region = GeometryWithTime(diff, time)
no_src_region.set_name("no_src")
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
src_region = GeometryWithTime(disk, time)
src_region.set_name("src")
src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
boundary = GeometryWithTime(rectangle, time)
boundary.set_name("bc")
boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))

geom_dict = ['1', '2']
with pytest.raises(TypeError):
Dataset(geom_dict)

geom_dict = {src_region: ["test"]}
with pytest.raises(KeyError):
Dataset(geom_dict)

geom_dict = {src_region: ["domain", "IC"],
no_src_region: ["domain", "IC"],
boundary: ["BC"]}
dataset = Dataset(geom_dict)

with pytest.raises(ValueError):
print(dataset[0])

with pytest.raises(ValueError):
len(dataset)

with pytest.raises(ValueError):
dataset.get_columns_list()

with pytest.raises(ValueError):
dataset.create_dataset(batch_size=ds_config.train.batch_size,
shuffle=ds_config.train.shuffle,
prebatched_data=True,
drop_remainder=False)

+ 77
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_equation.py

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: geometry with time cases"""

import copy
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain
from mindelec.data import Equation

reset_geom_time_config = edict({
'domain': edict({
'random_sampling': False,
'size': [5, 2],
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

reset_geom_time_config2 = edict({
'domain': edict({
'random_sampling': True,
'size': 10,
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

def check_rect_with_time_set_config(config):
"""check_rect_with_time_set_config"""
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
time = TimeDomain("time", 0.0, 1.0)
rect_with_time = GeometryWithTime(rect, time)

with pytest.raises(TypeError):
Equation(0)
with pytest.raises(ValueError):
Equation(rect_with_time)

config1 = copy.deepcopy(config)
config1.pop('domain')
rect_with_time.set_sampling_config(create_config_from_edict(config1))
with pytest.raises(KeyError):
Equation(rect_with_time)

rect_with_time.set_sampling_config(create_config_from_edict(config))
eq = Equation(rect_with_time)
for i in range(20):
print(eq[i])


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_set_config():
"""test_check_rect_with_time_set_config"""
check_rect_with_time_set_config(reset_geom_time_config)
check_rect_with_time_set_config(reset_geom_time_config2)

+ 77
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_existed_data.py

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test existed_data."""
import pytest
import numpy as np
from mindelec.data import ExistedDataset, ExistedDataConfig


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_value_error():
with pytest.raises(ValueError):
ExistedDataset()

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_type_error():
with pytest.raises(TypeError):
ExistedDataset(data_config=1)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_config():
"""test various errors"""
with pytest.raises(ValueError):
input_path = "./input_data.npy"
label_path = "./label.npy"
input_data = np.random.randn(10, 3)
output = np.random.randn(10, 3)
np.save(input_path, input_data)
np.save(label_path, output)

config = ExistedDataConfig(name="existed_data",
data_dir=[input_path, label_path],
columns_list=["inputs", "label"],
constraint_type="Equation",
data_format="npz")
dataset = ExistedDataset(data_config=config)

input_path = "./input_data.npy"
input_data = np.random.randn(10, 3)
np.save(input_path, input_data)

dataset = ExistedDataset(name="existed_data",
data_dir=[input_path],
columns_list=["inputs"],
constraint_type="Equation",
data_format="npy")
for i in range(20):
print(dataset[i])

dataset = ExistedDataset(name="existed_data",
data_dir=[input_path],
columns_list=["inputs"],
constraint_type="Equation",
data_format="npy",
random_merge=False)
for i in range(20):
print(dataset[i])

+ 84
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_src_td.py

@@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset module, call CSG ant GeometryWithTime"""
import pytest
import numpy as np
from easydict import EasyDict as edict

from mindelec.data import Dataset
from mindelec.geometry import Rectangle, TimeDomain, GeometryWithTime, create_config_from_edict


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_dataset():
"""check dataset"""
sampling_config = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 10],
}),
'BC': edict({
'random_sampling': False,
'size': 100,
'with_normal': True,
}),
'IC': edict({
'random_sampling': False,
'size': [10, 10],
}),
'time': edict({
'random_sampling': False,
'size': 40,
}),
})

rect_space = Rectangle("rectangle", coord_min=[0, 0], coord_max=[10, 10])
time = TimeDomain("time", 0.0, 20)
grid = GeometryWithTime(rect_space, time)
grid.set_sampling_config(create_config_from_edict(sampling_config))
grid.set_name("grid")
geom_dict = {grid: ["domain", "BC", "IC"]}
dataset = Dataset(geom_dict)

def preprocess_fn(*data):
bc_data = data[1]
bc_normal = np.ones(bc_data.shape)
return data[0], data[1], bc_normal, data[2], data[3]

columns_map = {"grid_domain_points": ["grid_domain_points"],
"grid_BC_points": ["grid_BC_points", "grid_BC_tangent"],
"grid_BC_normal": "grid_BC_normal",
"grid_IC_points": "grid_IC_points"}
train_data = dataset.create_dataset(batch_size=8192, shuffle=False, drop_remainder=False,
preprocess_fn=preprocess_fn,
input_output_columns_map=columns_map)
dataset.set_constraint_type({dataset.all_datasets[0]: "Equation",
dataset.all_datasets[1]: "BC",
dataset.all_datasets[2]: "IC"})
print("get merged data: {}".format(dataset[5]))
for sub_data in dataset.all_datasets:
print("get data: {}".format(sub_data[5]))

dataset_iter = train_data.create_dict_iterator(num_epochs=1)
np.set_printoptions(threshold=np.inf)
for _ in range(1):
for data in dataset_iter:
for k, v in data.items():
print("key: ", k)
print(v)
break

+ 133
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_1d.py

@@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: 1d cases"""
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Interval, TimeDomain

line_config_out = edict({
'domain': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 10,
'sampler': 'uniform',
}),
})


line_config_out2 = edict({
'domain': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': False,
'size': 10,
'sampler': 'uniform',
}),
})

def check_line_interval_case1(line_config):
"""check_line_interval_case1"""
try:
Interval("line", 'test', 1.0, sampling_config=create_config_from_edict(line_config))
except ValueError:
return
line = Interval("line", -1.0, 1.0, sampling_config=create_config_from_edict(line_config))

domain = line.sampling(geom_type="domain")
bc = line.sampling(geom_type="BC")
try:
line.sampling(geom_type="other")
except ValueError:
print("get ValueError when sampling other data")

# uniform sampling
if "BC" in line_config.keys():
line_config.BC = None
if "domain" in line_config.keys():
line_config.domain.random_sampling = False
line.set_sampling_config(create_config_from_edict(line_config))
domain = line.sampling(geom_type="domain")
try:
line.sampling(geom_type="BC")
except KeyError:
print("get ValueError when sampling BC data")

# lhs, halton, sobol
for samplers in ["lhs", "halton", "sobol"]:
if "domain" in line_config.keys():
line_config.domain.random_sampling = True
line_config.domain.sampler = samplers
line.set_sampling_config(create_config_from_edict(line_config))
domain = line.sampling(geom_type="domain")
print(domain, bc)

time_config = edict({
'domain': edict({
'random_sampling': True,
'size': 100,
'sampler': 'lhs'
})
})


def check_time_interval(line_config):
"""check_time_interval"""
try:
create_config_from_edict({"test": "test"})
except ValueError:
return

line = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(line_config))
domain = line.sampling(geom_type="domain")
try:
line.sampling(geom_type="BC")
except KeyError:
print("get ValueError when sampling BC data")
print(domain)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_time_interval():
"""test_check_time_interval"""
check_time_interval(time_config)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_line_interval_case1():
"""test_check_time_interval"""
check_line_interval_case1(line_config_out)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_line_interval_case2():
"""test_check_time_interval"""
check_line_interval_case1(line_config_out2)

+ 247
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_2d.py

@@ -0,0 +1,247 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: 2d cases"""
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Rectangle, Disk

disk_random = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False,
}),
})

disk_mesh = edict({
'domain': edict({
'random_sampling': False,
'size': [100, 180],
}),
'BC': edict({
'random_sampling': False,
'size': [20, 10],
'with_normal': False,
}),
})

disk_mesh_wrong_meshsize = edict({
'domain': edict({
'random_sampling': False,
'size': [100, 180, 200],
}),
'BC': edict({
'random_sampling': False,
'size': 200,
'with_normal': False,
}),
})

disk_mesh_nodomain = edict({
'BC': edict({
'random_sampling': False,
'size': 200,
'with_normal': False,
}),
})

disk_mesh_nobc = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
})


def check_disk_random(disk_config):
"""check_disk_random"""
with pytest.raises(ValueError):
disk = Disk("disk", (-1.0, 0), -2.0)
with pytest.raises(ValueError):
disk = Disk("disk", (-1.0, 0, 3), 2.0)

disk = Disk("disk", (-1.0, 0), 2.0)
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in disk_config.keys():
disk_config.domain.sampler = samplers
if "BC" in disk_config.keys():
disk_config.BC.sampler = samplers

try:
disk.sampling(geom_type="domain")
except ValueError:
return

disk.set_sampling_config(create_config_from_edict(disk_config))
with pytest.raises(ValueError):
disk.sampling(geom_type="test")

domain = disk.sampling(geom_type="domain")

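# with_normal toggles the return arity of BC sampling: False yields only the
# boundary points, True yields a (points, normals) pair, as exercised below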
bc = disk.sampling(geom_type="BC")
disk.sampling_config.bc.with_normal = True
bc, bc_normal = disk.sampling(geom_type="BC")
print(bc, bc_normal, domain)


def check_disk_mesh(disk_config):
"""check_disk_mesh"""
disk = Disk("disk", (-1.0, 0), 2.0, sampling_config=create_config_from_edict(disk_config))
domain = disk.sampling(geom_type="domain")

bc = disk.sampling(geom_type="BC")
disk.sampling_config.bc.with_normal = True
bc, bc_normal = disk.sampling(geom_type="BC")
print(bc, bc_normal, domain)


rectangle_random = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': True,
}),
})

rectangle_mesh = edict({
'domain': edict({
'random_sampling': False,
'size': [50, 25],
}),
'BC': edict({
'random_sampling': False,
'size': 300,
'with_normal': True,
}),
})


def check_rectangle_random(config):
"""check_rectangle_random"""
rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2))
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in config.keys():
config.domain.sampler = samplers
if "BC" in config.keys():
config.BC.sampler = samplers
config.BC.with_normal = True
rectangle.set_sampling_config(create_config_from_edict(config))
domain = rectangle.sampling(geom_type="domain")
bc, bc_normal = rectangle.sampling(geom_type="BC")

if "BC" in config.keys():
config.BC.with_normal = False
rectangle.set_sampling_config(create_config_from_edict(config))
bc = rectangle.sampling(geom_type="BC")
print(bc, bc_normal, domain)


def check_rectangle_mesh(config):
"""check_rectangle_mesh"""
rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2))
if "BC" in config.keys():
config.BC.with_normal = True
rectangle.set_sampling_config(create_config_from_edict(config))
domain = rectangle.sampling(geom_type="domain")
bc, bc_normal = rectangle.sampling(geom_type="BC")

if "BC" in config.keys():
config.BC.with_normal = False
rectangle.set_sampling_config(create_config_from_edict(config))
bc = rectangle.sampling(geom_type="BC")
print(bc, bc_normal, domain)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_disk_random():
"""test_check_disk_random"""
check_disk_random(disk_random)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_disk_mesh():
"""test_check_disk_mesh"""
check_disk_mesh(disk_mesh)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_disk_mesh_wrong_meshsize_error():
"""test_check_disk_mesh"""
with pytest.raises(ValueError):
check_disk_mesh(disk_mesh_wrong_meshsize)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_disk_mesh_nodomain_error():
"""test_check_disk_mesh"""
with pytest.raises(KeyError):
check_disk_mesh(disk_mesh_nodomain)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_disk_mesh_nobc_error():
"""test_check_disk_mesh"""
with pytest.raises(KeyError):
check_disk_mesh(disk_mesh_nobc)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rectangle_random():
"""test_check_rectangle_random"""
check_rectangle_random(rectangle_random)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rectangle_mesh():
"""test_check_rectangle_mesh"""
check_rectangle_mesh(rectangle_mesh)

+ 264
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_base.py

@@ -0,0 +1,264 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: base classes"""
import pytest
import numpy as np
from easydict import EasyDict as edict

from mindelec.geometry import Geometry, PartSamplingConfig, SamplingConfig, create_config_from_edict

def check_create_config_from_edict():
try:
sampling_config = ["geom", "IC", "BC"]
config = create_config_from_edict(sampling_config)
print("check config: {}".format(config))
except TypeError:
print("get sampling config type error")


def check_part_sampling_config(size, random_sampling, sampler, random_merge, with_normal):
try:
config = PartSamplingConfig(size, random_sampling, sampler, random_merge, with_normal)
print("check config: {}".format(config.__dict__))
except TypeError:
print("get TypeError")


temp_config = edict({
'domain': edict({
'random_sampling': True,
'size': 100,
}),
'BC': edict({
'random_sampling': True,
'size': 100,
'sampler': 'uniform',
'random_merge': True,
}),
})


def check_sampling_config_case1():
"""check_sampling_config_case1"""
with pytest.raises(TypeError):
SamplingConfig("test")
with pytest.raises(KeyError):
SamplingConfig({"test": "test"})
with pytest.raises(TypeError):
SamplingConfig({"domain": "test"})
with pytest.raises(TypeError):
part_sampling_config_dict = {"domain": PartSamplingConfig("test", False, True)}
SamplingConfig(part_sampling_config_dict)

part_sampling_config_dict = {"domain": PartSamplingConfig([100, 100], False, "uniform", True, True),
"BC": PartSamplingConfig(100, True, "uniform", True, True)}
sampling_config_tmp = SamplingConfig(part_sampling_config_dict)
for attr, config in sampling_config_tmp.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config.__dict__))


def check_sampling_config_case2(config_in):
"""check_sampling_config_case2"""
sampling_config_tmp = create_config_from_edict(config_in)
for attr, config in sampling_config_tmp.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config.__dict__))


def check_sampling_config_case3(config_in):
"""check_sampling_config_case3"""
config_in.ttime = config_in.domain
sampling_config_tmp = create_config_from_edict(config_in)
for attr, config in sampling_config_tmp.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config.__dict__))


def check_geometry_case1():
"""check_geometry_case1"""
with pytest.raises(ValueError):
Geometry("geom", 2, 0.0, 1.0)
with pytest.raises(TypeError):
Geometry("geom", 2, [0, 0], [1, 1], sampling_config="test")

geom = Geometry("geom", 1, 0.0, 1.0)
with pytest.raises(TypeError):
geom.set_sampling_config('test')

with pytest.raises(NotImplementedError):
geom.sampling()

for attr, config in geom.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config))

try:
geom = Geometry("geom", 1, 1.0, 0.0)
for attr, config in geom.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config))
except ValueError:
print("get ValueError")


sampling_config2 = create_config_from_edict(temp_config)


def check_geometry_case2(sampling_config_tmp):
"""check_geometry_case2"""
geom = Geometry("geom", 1, 0.0, 1.0, sampling_config=sampling_config_tmp)
geom.set_name("geom_name")
for attr, config in geom.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config))


def check_geometry_case3(config_in):
"""check_geometry_case3"""
geom = Geometry("geom", 1, 0.0, 1.0)
for attr, configs in geom.__dict__.items():
if configs is not None:
print("check sampling config: {}: {}".format(attr, configs))

geom.set_name("geom_name")
geom.set_sampling_config(create_config_from_edict(config_in))
for attr, config in geom.__dict__.items():
if attr == "sampling_config" and config is not None:
print("check sampling config after set: {}: {}".format(attr, config.__dict__))
for attrs, configs in config.__dict__.items():
if configs is not None:
print("check sampling config: {}: {}".format(attrs, configs.__dict__))
else:
if config is not None:
print("check sampling config after set: {}: {}".format(attr, config))


def check_geometry_case4():
"""check_geometry_case4"""
try:
geom = Geometry(10, 1, 1.0, 0.0)
for attr, config in geom.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config))
except TypeError:
print("get geom name type error")

geom = Geometry("geom", 1, 0.0, 1.0)
try:
geom.set_name("geom_name")
except TypeError:
print("get set geom name type error")
geom.set_name("geom_name")

try:
geom = Geometry("geom", 1.0, 1.0, 2.0)
for attr, config in geom.__dict__.items():
if config is not None:
print("check sampling config: {}: {}".format(attr, config))
except TypeError:
print("get geom dim type error")

try:
geom = Geometry("geom", 1, 1.0, 0.0)
except ValueError:
print("get geom coord value error")

try:
geom = Geometry("geom", 1, {"min": 0.0}, {"max": 1.0})
except TypeError:
print("get geom coord type error")

try:
geom = Geometry("geom", 1, 0.0, 1.0, dtype=np.uint32)
except TypeError:
print("get geom data type error")

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_part_sampling_config():
"""test_check_part_sampling_config"""
check_part_sampling_config(100, True, "uniform", True, True)
check_part_sampling_config(100, False, "sobol", True, True)
check_part_sampling_config([100, 100], False, "sobol", True, True)
check_part_sampling_config([100, 100], True, "uniform", True, True)
check_part_sampling_config(100, False, "lhs", True, True)
check_part_sampling_config(100, False, "halton", True, True)
check_create_config_from_edict()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_sampling_config_case1():
"""test_check_sampling_config_case1"""
check_sampling_config_case1()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_sampling_config_case2():
"""test_check_sampling_config_case2"""
check_sampling_config_case2(temp_config)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_sampling_config_case3():
"""test_check_sampling_config_case3"""
check_sampling_config_case3(temp_config)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_geometry_case1():
"""test_check_geometry_case1"""
check_geometry_case1()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_geometry_case2():
"""test_check_geometry_case2"""
check_geometry_case2(sampling_config2)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_geometry_case3():
"""test_check_geometry_case3"""
check_geometry_case3(temp_config)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_geometry_case4():
"""test_check_geometry_case4"""
check_geometry_case4()

+ 240
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_csg.py

@@ -0,0 +1,240 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: CSG classes"""
import pytest
from easydict import EasyDict as edict
from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Rectangle, Disk, Interval
from mindelec.geometry import CSGIntersection, CSGDifference, CSGUnion, CSGXOR, CSG

sampling_config_csg = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': True,
}),
})

sampling_config_csg2 = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False,
}),
})

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_csg_union():
"""test check union"""
disk = Disk("disk", (1.2, 0.5), 0.8)
rect = Rectangle("rect", (-1.0, 0), (1, 1))

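# each CSG test builds the region twice on purpose: once through the explicit CSG
# class and once through the overloaded operator (| union, - difference,
# & intersection, ^ XOR), which is shorthand for the same construction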
union = CSGUnion(rect, disk, create_config_from_edict(sampling_config_csg))
union = rect | disk
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in sampling_config_csg.keys():
sampling_config_csg.domain.sampler = samplers
if "BC" in sampling_config_csg.keys():
sampling_config_csg.BC.sampler = samplers

union.set_sampling_config(create_config_from_edict(sampling_config_csg2))
bc = union.sampling(geom_type="BC")

union.set_sampling_config(create_config_from_edict(sampling_config_csg))
domain = union.sampling(geom_type="domain")
bc, bc_normal = union.sampling(geom_type="BC")
print(bc, bc_normal, domain)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_csg_difference():
"""test_check_csg_difference"""
disk = Disk("disk", (1.2, 0.5), 0.8)
rect = Rectangle("rect", (-1.0, 0), (1, 1))

difference = CSGDifference(rect, disk, create_config_from_edict(sampling_config_csg))
difference = rect - disk
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in sampling_config_csg.keys():
sampling_config_csg.domain.sampler = samplers
if "BC" in sampling_config_csg.keys():
sampling_config_csg.BC.sampler = samplers

difference.set_sampling_config(create_config_from_edict(sampling_config_csg2))
bc = difference.sampling(geom_type="BC")

difference.set_sampling_config(create_config_from_edict(sampling_config_csg))
domain = difference.sampling(geom_type="domain")
bc, bc_normal = difference.sampling(geom_type="BC")
print(bc, bc_normal, domain)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_csg_intersection():
"""test_check_csg_intersection"""
disk = Disk("disk", (1.2, 0.5), 0.8)
rect = Rectangle("rect", (-1.0, 0), (1, 1))

intersec = CSGIntersection(rect, disk, create_config_from_edict(sampling_config_csg))
intersec = rect & disk
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in sampling_config_csg.keys():
sampling_config_csg.domain.sampler = samplers
if "BC" in sampling_config_csg.keys():
sampling_config_csg.BC.sampler = samplers

intersec.set_sampling_config(create_config_from_edict(sampling_config_csg2))
bc = intersec.sampling(geom_type="BC")

intersec.set_sampling_config(create_config_from_edict(sampling_config_csg))
domain = intersec.sampling(geom_type="domain")
bc, bc_normal = intersec.sampling(geom_type="BC")
print(bc, bc_normal, domain)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_csg_xor():
"""test_check_csg_xor"""
disk = Disk("disk", (1.2, 0.5), 0.8)
rect = Rectangle("rect", (-1.0, 0), (1, 1))

xor = CSGXOR(rect, disk, create_config_from_edict(sampling_config_csg))
xor = rect ^ disk
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in sampling_config_csg.keys():
sampling_config_csg.domain.sampler = samplers
if "BC" in sampling_config_csg.keys():
sampling_config_csg.BC.sampler = samplers

xor.set_sampling_config(create_config_from_edict(sampling_config_csg2))
bc = xor.sampling(geom_type="BC")

xor.set_sampling_config(create_config_from_edict(sampling_config_csg))
domain = xor.sampling(geom_type="domain")
bc, bc_normal = xor.sampling(geom_type="BC")
print(bc, bc_normal, domain)


no_src_sampling_config = edict({
'domain': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform'
}),
})

no_src_sampling_config1 = edict({
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False
}),
})

no_src_sampling_config2 = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 10]
}),
})

no_src_sampling_config3 = edict({
'BC': edict({
'random_sampling': False,
'size': [10, 10]
}),
})

no_src_sampling_config4 = edict({
'IC': edict({
'random_sampling': False,
'size': [10, 10]
}),
})

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_point_src_csg():
"""test_check_point_src_csg"""
src_region = Disk("src", (0.0, 0.0), 0.2)
rectangle = Rectangle("rect", (-1, -1), (1, 1))
line = Interval("line", -1, 1)

with pytest.raises(TypeError):
no_src_region = CSG("test", rectangle, 2, [0, 0], [1, 1])
with pytest.raises(TypeError):
no_src_region = CSG("test", 2, rectangle, [0, 0], [1, 1])
with pytest.raises(ValueError):
no_src_region = CSG("test", rectangle, line, [0, 0], [1, 1])

no_src_region = rectangle - src_region
no_src_region.set_name("no_src")

with pytest.raises(ValueError):
no_src_region.set_sampling_config(None)

with pytest.raises(TypeError):
no_src_region.set_sampling_config("test")

with pytest.raises(ValueError):
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config2))

with pytest.raises(ValueError):
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config3))

with pytest.raises(ValueError):
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config4))

no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
with pytest.raises(KeyError):
no_src_region.sampling(geom_type="BC")

no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config1))
with pytest.raises(KeyError):
no_src_region.sampling(geom_type="domain")
with pytest.raises(ValueError):
no_src_region.sampling(geom_type="test")

no_src_region.sampling(geom_type="BC")

+ 142
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_nd.py

@@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: nd cases"""
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Cuboid

cuboid_random = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False,
}),
})

cuboid_random2 = edict({
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False,
}),
})

cuboid_mesh = edict({
'domain': edict({
'random_sampling': False,
'size': [20, 30, 10],
}),
'BC': edict({
'random_sampling': False,
'size': 900,
'with_normal': True,
}),
})


cuboid_mesh2 = edict({
'domain': edict({
'random_sampling': False,
'size': [20, 30],
}),
'BC': edict({
'random_sampling': False,
'size': 900,
'with_normal': True,
}),
})

def check_cuboid_random(cuboid_config):
"""check_cuboid_random"""

cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1])
for samplers in ["uniform", "lhs", "halton", "sobol"]:
print("check random sampler: {}".format(samplers))
if "domain" in cuboid_config.keys():
cuboid_config.domain.sampler = samplers
if "BC" in cuboid_config.keys():
cuboid_config.BC.sampler = samplers

try:
cuboid.set_sampling_config("test")
except TypeError:
print("set_sampling_config TypeError")

try:
cuboid.sampling(geom_type="domain")
except ValueError:
print("sampling ValueError")

cuboid.set_sampling_config(create_config_from_edict(cuboid_config))
domain = cuboid.sampling(geom_type="domain")
bc = cuboid.sampling(geom_type="BC")
assert domain.shape == (1000, 3)
assert bc.shape == (199, 3)


def check_cuboid_mesh(cuboid_config):
"""check_cuboid_mesh"""
cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1], sampling_config=create_config_from_edict(cuboid_config))
domain = cuboid.sampling(geom_type="domain")
bc, bc_normal = cuboid.sampling(geom_type="BC")
assert domain.shape == (6000, 3)
assert bc.shape == (556, 3)
assert bc_normal.shape == (556, 3)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_cuboid_random():
"""test_check_cuboid_random"""
check_cuboid_random(cuboid_random)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_cuboid_random_nodomain_error():
"""test_check_cuboid_random"""
with pytest.raises(KeyError):
check_cuboid_random(cuboid_random2)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_cuboid_mesh():
"""test_check_cuboid_mesh"""
check_cuboid_mesh(cuboid_mesh)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_cuboid_mesh_meshsize_error():
"""test_check_cuboid_mesh"""
with pytest.raises(ValueError):
check_cuboid_mesh(cuboid_mesh2)

+ 293
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_td.py

@@ -0,0 +1,293 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test geometry module: geometry with time cases"""

import copy
import pytest
from easydict import EasyDict as edict

from mindelec.geometry import create_config_from_edict
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain

rectangle_random = edict({
'domain': edict({
'random_sampling': True,
'size': 1000,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform',
'with_normal': False,
}),
})

rectangle_mesh = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': False,
'size': 50,
'with_normal': False,
}),
})

time_random = edict({
'domain': edict({
'random_sampling': True,
'size': 10,
'sampler': 'lhs'
})
})

time_mesh = edict({
'domain': edict({
'random_sampling': False,
'size': 10,
})
})

time_mesh2 = edict({
'domain': edict({
'random_sampling': False,
'size': 10000,
})
})

time_mesh3 = edict({
'IC': edict({
'random_sampling': False,
'size': 10000,
})
})

reset_geom_time_config = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': False,
'size': 50,
'with_normal': True,
}),
'IC': edict({
'random_sampling': False,
'size': [10, 10],
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

reset_geom_time_config2 = edict({
'domain': edict({
'random_sampling': False,
'size': [10, 20],
}),
'BC': edict({
'random_sampling': False,
'size': 50,
'with_normal': True,
}),
'IC': edict({
'random_sampling': False,
'size': [10, 10],
}),
'time': edict({
'random_sampling': True,
'size': 10,
})
})

reset_geom_time_config3 = edict({
'domain': edict({
'random_sampling': True,
'size': 200,
}),
'BC': edict({
'random_sampling': True,
'size': 50,
'with_normal': True,
}),
'IC': edict({
'random_sampling': False,
'size': [10, 10],
}),
'time': edict({
'random_sampling': True,
'size': 10,
})
})

reset_geom_time_config4 = edict({
'domain': edict({
'random_sampling': True,
'size': 200,
}),
'BC': edict({
'random_sampling': True,
'size': 50,
'with_normal': True,
}),
'IC': edict({
'random_sampling': True,
'size': 100,
}),
'time': edict({
'random_sampling': False,
'size': 10,
})
})

reset_geom_time_config5 = edict({
'time': edict({
'random_sampling': False,
'size': 10,
})
})

def check_rect_with_time_init_config(rect_config, time_config):
"""check_rect_with_time_init_config"""
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5], sampling_config=create_config_from_edict(rect_config))
time = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(time_config))
rect_with_time = GeometryWithTime(rect, time)

# check info
print("check rect_with_time initial config: {}".format(rect_with_time.__dict__))
if rect_with_time.sampling_config is not None:
for key, value in rect_with_time.sampling_config.__dict__.items():
if value is not None:
print(" get attr: {}, value: {}".format(key, value.__dict__))

# sampling
config = rect_with_time.sampling_config
if config is None:
raise ValueError
if config.domain is not None:
domain = rect_with_time.sampling(geom_type="domain")
print("check domain points: {}".format(domain.shape))
if config.bc is not None:
if config.bc.with_normal:
bc, bc_normal = rect_with_time.sampling(geom_type="BC")
print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape))
else:
bc = rect_with_time.sampling(geom_type="BC")
print("check bc points: {}".format(bc.shape))
if config.ic is not None:
ic = rect_with_time.sampling(geom_type="IC")
print("check ic points: {}".format(ic.shape))


def check_rect_with_time_set_config(config):
"""check_rect_with_time_set_config"""
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
time = TimeDomain("time", 0.0, 1.0)
try:
GeometryWithTime(rect, time, create_config_from_edict(config))
except ValueError:
print("create_config_from_edict ValueError")

rect_with_time = GeometryWithTime(rect, time)
try:
rect_with_time.sampling(geom_type="domain")
except ValueError:
print("sampling ValueError")

rect_with_time.set_sampling_config(create_config_from_edict(config))

try:
rect_with_time.sampling(geom_type="test")
except ValueError:
print("sampling ValueError")
# sampling
config = rect_with_time.sampling_config
if config.domain is not None:
domain = rect_with_time.sampling(geom_type="domain")
print("check domain points: {}".format(domain.shape))
else:
try:
rect_with_time.sampling(geom_type="domain")
except ValueError:
print("sampling KeyError")
if config.bc is not None:
if config.bc.with_normal:
bc, bc_normal = rect_with_time.sampling(geom_type="BC")
print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape))
normal = copy.deepcopy(bc)
normal[:, :2] = bc[:, :2] + bc_normal[:, :]
else:
bc = rect_with_time.sampling(geom_type="BC")
print("check bc points: {}".format(bc.shape))
else:
try:
rect_with_time.sampling(geom_type="BC")
except ValueError:
print("sampling KeyError")
if config.ic is not None:
ic = rect_with_time.sampling(geom_type="IC")
print("check ic points: {}".format(ic.shape))
else:
try:
rect_with_time.sampling(geom_type="IC")
except ValueError:
print("sampling KeyError")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_init_config():
"""test_check_rect_with_time_init_config"""
check_rect_with_time_init_config(rectangle_random, time_random)
check_rect_with_time_init_config(rectangle_mesh, time_random)
check_rect_with_time_init_config(rectangle_mesh, time_mesh2)
check_rect_with_time_init_config(rectangle_random, time_mesh)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_init_config_error():
"""test_check_rect_with_time_init_config"""
with pytest.raises(ValueError):
check_rect_with_time_init_config(rectangle_mesh, time_mesh3)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_set_config():
"""test_check_rect_with_time_set_config"""
check_rect_with_time_set_config(reset_geom_time_config)
check_rect_with_time_set_config(reset_geom_time_config2)
check_rect_with_time_set_config(reset_geom_time_config3)
check_rect_with_time_set_config(reset_geom_time_config4)
check_rect_with_time_set_config(reset_geom_time_config5)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_check_rect_with_time_set_config2():
"""test_check_rect_with_time_set_config"""
with pytest.raises(ValueError):
check_rect_with_time_set_config(rectangle_random)

+ 39
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/loss/test_constraints.py

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test constraints"""
import pytest
from mindelec.data import Dataset
from mindelec.loss import Constraints
from mindelec.geometry import Rectangle


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_constraints_dataset_type_error():
with pytest.raises(TypeError):
Constraints(1, 1)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_existed_data_config_type_error():
rectangle = Rectangle("rect", (-1, -1), (1, 1))
geom_dict = {rectangle: ["domain"]}
dataset = Dataset(geom_dict)
with pytest.raises(TypeError):
Constraints(dataset, 1)

+ 217
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/net_with_loss/test_netwithloss.py

@@ -0,0 +1,217 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
test net_with_loss
"""
import os
from easydict import EasyDict as edict
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context, ms_function
from mindspore.common import set_seed
from mindelec.geometry import Rectangle, create_config_from_edict
from mindelec.data import Dataset, ExistedDataConfig
from mindelec.loss import Constraints, NetWithLoss, NetWithEval
from mindelec.architecture import ResBlock, LinearBlock
from mindelec.solver import Problem
from mindelec.common import L2
from mindelec.operators import Grad

set_seed(1)
np.random.seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

path = os.getcwd()
data_config = edict({
'data_dir': [path+'/inputs.npy', path+'/label.npy'], # absolute dir
'columns_list': ['input_data', 'label'],
'data_format': 'npy',
'constraint_type': 'Equation',
'name': 'exist'
})

if not os.path.exists(data_config['data_dir'][0]):
data_in = np.ones((32, 2), dtype=np.float32)
np.save(path+"/inputs.npy", data_in)

if not os.path.exists(data_config['data_dir'][1]):
data_label = np.ones((32, 1), dtype=np.float32)
np.save(path+"/label.npy", data_label)


rectangle_config = edict({
'domain': edict({
'random_sampling': True,
'size': 2500,
'sampler': 'uniform'
}),
'BC': edict({
'random_sampling': True,
'size': 200,
'sampler': 'uniform'
})
})

rectangle_config1 = edict({
'domain': edict({
'random_sampling': True,
'size': 2500,
'sampler': 'uniform'
}),
})


class Net(nn.Cell):
"""net definition"""
def __init__(self, input_dim, output_dim, hidden_layer=128, activation="sin"):
super(Net, self).__init__()
self.resblock = ResBlock(hidden_layer, hidden_layer, activation=activation)
self.fc1 = LinearBlock(input_dim, hidden_layer)
self.fc2 = LinearBlock(hidden_layer, output_dim)

def construct(self, *inputs):
x = inputs[0]
out = self.fc1(x)
out = self.resblock(out)
out = self.fc2(out)
return out


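# the Problem hooks below receive the network output positionally and the batched
# dataset columns as keyword arguments keyed by column name (e.g.
# "rectangle_domain_points"), which is why each problem stores at construction time
# the names of the columns it needs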
class RectPde(Problem):
"""rectangle pde problem"""
def __init__(self, domain_name=None, bc_name=None, label_name=None, net=None):
super(RectPde, self).__init__()
self.domain_name = domain_name
self.bc_name = bc_name
self.label_name = label_name
self.type = "Equation"
self.jacobian = Grad(net)

@ms_function
def governing_equation(self, *output, **kwargs):
u = output[0]
data = kwargs[self.domain_name]
u_x = self.jacobian(data, 0, 0, u)
return u_x

@ms_function
def boundary_condition(self, *output, **kwargs):
u = output[0]
x = kwargs[self.bc_name][:, 0]
y = kwargs[self.bc_name][:, 1]
return u - ops.sin(x) * ops.cos(y)

@ms_function
def constraint_function(self, *output, **kwargs):
u = output[0]
label = kwargs[self.label_name]
return u - label


class RectPde1(Problem):
"""rectangle pde problem with no boundary condition"""
def __init__(self, domain_name, net):
super(RectPde1, self).__init__()
self.domain_name = domain_name
self.type = "Equation"
self.jacobian = Grad(net)

@ms_function
def governing_equation(self, *output, **kwargs):
u = output[0]
data = kwargs[self.domain_name]
u_x = self.jacobian(data, 0, 0, u)
return u_x


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_netwithloss():
"""test netwithloss function"""
model = Net(2, 1, 128, "sin")
rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0],
sampling_config=create_config_from_edict(rectangle_config))

geom_dict = {rect_space: ["domain", "BC"]}
dataset = Dataset(geom_dict)
dataset.create_dataset(batch_size=4, shuffle=True)
prob_dict = {rect_space.name: RectPde(domain_name="rectangle_domain_points", bc_name="rectangle_BC_points",
net=model)}
train_constraints = Constraints(dataset, prob_dict)
metrics = {'l2': L2(), 'distance': nn.MAE()}
train_input_map = {'rectangle_domain': ['rectangle_domain_points'], 'rectangle_BC': ['rectangle_BC_points']}
loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map)

domain_points = Tensor(np.ones([32, 2]).astype(np.float32))
bc_points = Tensor(np.ones([32, 2]).astype(np.float32))
bc_normal = Tensor(np.ones([32, 2]).astype(np.float32))
out = loss_network(domain_points, bc_points, bc_normal)
print(out)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_netwithloss1():
"""test netwithloss function1"""
model = Net(2, 1, 128, "sin")
rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0],
sampling_config=create_config_from_edict(rectangle_config1))
geom_dict = {rect_space: ["domain"]}
dataset = Dataset(geom_dict)
dataset.create_dataset(batch_size=4, shuffle=True)
prob_dict = {rect_space.name: RectPde1(domain_name="rectangle_domain_points", net=model)}
train_constraints = Constraints(dataset, prob_dict)
metrics = {'l2': L2(), 'distance': nn.MAE()}
train_input_map = {'rectangle_domain': ['rectangle_domain_points']}
loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map)

domain_points = Tensor(np.ones([32, 2]).astype(np.float32))
out = loss_network(domain_points)
print(out)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_netwitheval():
"""test netwitheval function"""
model = Net(2, 1, 128, "sin")
src_domain = ExistedDataConfig(name="src_domain",
data_dir=[path+"/inputs.npy", path+"/label.npy"],
columns_list=["inputs", "label"],
data_format="npy",
constraint_type="Label",
random_merge=False)
test_prob_dict = {src_domain.name: RectPde(domain_name=src_domain.name+"_inputs",
label_name=src_domain.name+"_label", net=model)}
test_dataset = Dataset(existed_data_list=[src_domain])
test_dataset.create_dataset(batch_size=4, shuffle=False)
test_constraints = Constraints(test_dataset, test_prob_dict)
metrics = {'l2': L2(), 'distance': nn.MAE()}
loss_network = NetWithEval(model, test_constraints, metrics)

data = Tensor(np.ones([32, 2]).astype(np.float32))
label = Tensor(np.ones([32, 1]).astype(np.float32))
out = loss_network(data, label)
print(out)

+ 30
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/config.py

@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""

# config
config = {
'base_channels': 8,
'input_channels': 4,
'epochs': 2000,
'batch_size': 8,
'save_epoch': 100,
'lr': 0.01,
'lr_decay_milestones': 5,
'eval_interval': 20,
'patch_shape': [25, 50, 25],
}
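# a minimal sketch of how these keys are presumably consumed together with
# step_lr_generator from src/lr_generator.py below; train.py itself is not part of
# this diff, so the wiring (and the steps_per_epoch value) is illustrative only:
#
# from src.lr_generator import step_lr_generator
# steps_per_epoch = 100  # hypothetical, normally taken from the data pipeline
# milestones, lrs = step_lr_generator(steps_per_epoch, config['epochs'],
#                                     config['lr'], config['lr_decay_milestones'])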

+ 82
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/dataset.py

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset generation and loading"""

import numpy as np
from mindspore.common import set_seed

from mindelec.data import Dataset, ExistedDataConfig

np.random.seed(0)
set_seed(0)

PATCH_DIM = [25, 50, 25]
NUM_SAMPLE = 10000
INPUT_PATH = ""
DATA_CONFIG_PATH = "./data_config.npy"
SAVE_DATA_PATH = "./"


def generate_data(input_path):
"""generate training data and data configuration"""
space_temp = np.load(input_path)

print("data load finish")
print("random cropping...")
space_data = np.ones((NUM_SAMPLE,
PATCH_DIM[0],
PATCH_DIM[1],
PATCH_DIM[2],
space_temp.shape[-1])).astype(np.float32)
rand_pos = np.random.randint(low=0, high=(space_temp.shape[0] - PATCH_DIM[0])*\
(space_temp.shape[1] - PATCH_DIM[1])*\
(space_temp.shape[2] - PATCH_DIM[2]), size=NUM_SAMPLE)
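# decode each flat index into (x, y, z) crop origins over the valid origin ranges,
# row-major; equivalent to np.unravel_index(pos, (X - dx, Y - dy, Z - dz))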
for i, pos in enumerate(rand_pos):
z = pos % (space_temp.shape[2] - PATCH_DIM[2])
y = (pos // (space_temp.shape[2] - PATCH_DIM[2])) % (space_temp.shape[1] - PATCH_DIM[1])
x = (pos // (space_temp.shape[2] - PATCH_DIM[2])) // (space_temp.shape[1] - PATCH_DIM[1])
space_data[i] = space_temp[x : x+PATCH_DIM[0], y : y+PATCH_DIM[1], z : z+PATCH_DIM[2], :]
print("random crop finished")

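# channel 2 is log-compressed first, then every channel is rescaled into [-1, 1]
# by its maximum absolute value; the per-channel scales are kept in data_config
# so the normalization can be undone later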
space_data[:, :, :, :, 2] = np.log10(space_data[:, :, :, :, 2] + 1.0)
data_config = np.ones(4)
for i in range(4):
data_config[i] = np.max(np.abs(space_data[:, :, :, :, i]))
space_data[:, :, :, :, i] = space_data[:, :, :, :, i] / data_config[i]

# move channels first (N, C, D, H, W) before the split so the saved train/test
# files match the network's Conv3d input layout
space_data = space_data.transpose((0, 4, 1, 2, 3))

length = space_data.shape[0] // 10
test_data = space_data[:length]
train_data = space_data[length:]

np.save(DATA_CONFIG_PATH, data_config)
np.save(SAVE_DATA_PATH+"/train_data.npy", train_data)
np.save(SAVE_DATA_PATH+"/test_data.npy", test_data)
print("AE train and test data and data config is saved")


def create_dataset(input_path, label_path, batch_size=8, shuffle=True):
electromagnetic = ExistedDataConfig(name="electromagnetic",
data_dir=[input_path, label_path],
columns_list=["inputs", "label"],
data_format=input_path.split('.')[-1])
dataset = Dataset(existed_data_list=[electromagnetic])
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle)

return data_loader

if __name__ == "__main__":
generate_data(INPUT_PATH)

+ 30
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/lr_generator.py

@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""


def step_lr_generator(step_size, epochs, lr, lr_decay_milestones):
"""generate step decayed learnig rate"""

total_steps = epochs * step_size

milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)]
milestones.append(total_steps)
learning_rates = [lr*0.5**i for i in range(0, lr_decay_milestones - 1)]
learning_rates.append(lr*0.5**(lr_decay_milestones - 1))

print("total_steps: %s, milestones: %s, learning_rates: %s " %(total_steps, milestones, learning_rates))

return milestones, learning_rates
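# a minimal usage sketch, assuming the (milestones, learning_rates) pair feeds
# MindSpore's piecewise-constant schedule; `net` and the optimizer choice are
# illustrative, not part of this module:
#
# import mindspore.nn as nn
# milestones, lrs = step_lr_generator(step_size=100, epochs=2000, lr=0.01,
#                                     lr_decay_milestones=5)
# lr_each_step = nn.piecewise_constant_lr(milestones, lrs)
# optimizer = nn.Adam(net.trainable_params(), learning_rate=lr_each_step)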

+ 49
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/metric.py

@@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import functional as F

class MyMSELoss(nn.LossBase):
"""mse loss function"""
def construct(self, base, target):
bs, _, _, _, _ = F.shape(target)
x = F.square(base - target)
return 2*bs*self.get_loss(x)


class EvalMetric(nn.Metric):
"""eval metric"""

def __init__(self, length):
super(EvalMetric, self).__init__()
self.clear()
self.length = length

def clear(self):
self.error_sum_l2_error = 0
self.error_mean_l1_error = 0

def update(self, *inputs):
test_predict = self._convert_data(inputs[0])
test_label = self._convert_data(inputs[1])

for i in range(len(test_label)):
self.error_mean_l1_error += np.mean(np.abs(test_label[i] - test_predict[i]))

def eval(self):
return {'mean_l1_error': self.error_mean_l1_error / self.length}
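# a minimal usage sketch, assuming EvalMetric plugs into MindSpore's Model as a
# custom nn.Metric with `length` set to the number of evaluation samples; the names
# `network`, `optimizer`, `num_eval_samples` and `eval_dataset` are illustrative:
#
# from mindspore import Model
# model = Model(network, loss_fn=MyMSELoss(), optimizer=optimizer,
#               metrics={"eval_metric": EvalMetric(length=num_eval_samples)})
# print(model.eval(eval_dataset)["eval_metric"])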

+ 298
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/model.py

@@ -0,0 +1,298 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EncoderDecoder for MindElec."""

import mindspore.nn as nn
import mindspore.ops as ops


class EncoderDecoder(nn.Cell):
"""
EncoderDecoder architecture for MindElec.

Args:
input_dim (int): input channel.
target_shape (list or tuple): Output DWH shape.
base_channels (int): base channel count; all intermediate layers' channels are multiples of this value.
decoding (bool): enable the decoder; set True when the reconstructed input is needed,
for example during training. Default: False

Returns:
Tensor, output tensor: compressed encodings (decoding=False) or reconstructed input (decoding=True).

Examples:
>>> training encoder: Encoder_Decoder(input_dim=4, target_shape=[25, 50, 25], base_channels=8, decoding=True)
>>> applying encoder(data compression): Encoder_Decoder(input_dim=4,
... target_shape=[25, 50, 25],
... base_channels=8,
... decoding=False)
"""

def __init__(self, input_dim, target_shape, base_channels=8, decoding=False):
super(EncoderDecoder, self).__init__()
self.decoding = decoding
self.encoder = Encoder(input_dim, base_channels)
if self.decoding:
self.decoder = Decoder(input_dim, target_shape, base_channels)

def construct(self, x):
encoding = self.encoder(x)
if self.decoding:
output = self.decoder(encoding)
else:
output = encoding
return output


class Encoder(nn.Cell):
"""
Encoder architecture.

Args:
input_dim (int): input channel.
base_channels (int): base channel count; all intermediate layers' channels are multiples of this value.

Returns:
Tensor, output tensor.

Examples:
>>> Encoder(input_dim=4, base_channels=8)
"""

def __init__(self, input_dim, base_channels):
super(Encoder, self).__init__()
print("BASE_CHANNELS: %d" %base_channels)
self.input_dim = input_dim
self.channels = base_channels

self.conv0 = nn.Conv3d(self.input_dim, self.channels*2, kernel_size=3, pad_mode='pad', padding=1)
self.conv0_1 = nn.Conv3d(self.channels*2, self.channels*2, kernel_size=3, pad_mode='pad', padding=1)
self.conv1 = nn.Conv3d(self.channels*2, self.channels*4, kernel_size=3, pad_mode='pad', padding=1)
self.conv1_1 = nn.Conv3d(self.channels*4, self.channels*4, kernel_size=3, pad_mode='pad', padding=1)
self.conv2 = nn.Conv3d(self.channels*4, self.channels*8, kernel_size=3, pad_mode='pad', padding=1)
self.conv2_1 = nn.Conv3d(self.channels*8, self.channels*8, kernel_size=3, pad_mode='pad', padding=1)
self.conv3 = nn.Conv3d(self.channels*8, self.channels*16, kernel_size=(2, 3, 2))
self.conv4 = nn.Conv2d(self.channels*16, self.channels*32, kernel_size=(1, 3), pad_mode='pad', padding=0)

self.bn0 = nn.BatchNorm3d(self.channels*2)
self.bn1 = nn.BatchNorm3d(self.channels*4)
self.bn2 = nn.BatchNorm3d(self.channels*8)
self.bn3 = nn.BatchNorm3d(self.channels*16)
self.bn4 = nn.BatchNorm2d(self.channels*32)

self.down1 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
self.down2 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
self.down3 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
self.down4 = nn.MaxPool2d(kernel_size=(2, 6), stride=(2, 6))

self.down_1_1 = ops.MaxPool3D(kernel_size=(4, 5, 4), strides=(4, 5, 4))
self.down_1 = nn.MaxPool2d(kernel_size=(3, 5*3))

self.down_2_1 = ops.MaxPool3D(kernel_size=(3, 4, 3), strides=(3, 4, 3))
self.down_2 = nn.MaxPool2d(kernel_size=(2, 3*2))

self.down_3 = nn.MaxPool2d(kernel_size=(3, 18))

self.act = nn.Sigmoid()

self.concat = ops.Concat(axis=1)
self.expand_dims = ops.ExpandDims()

def construct(self, x):
"""forward"""
bs = x.shape[0]

x = self.conv0(x)
x = self.conv0_1(x)
x = self.bn0(x)
x = self.act(x)
x = self.down1(x)
x_1 = self.down_1_1(x)
x_1 = self.down_1(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1))

x = self.conv1(x)
x = self.conv1_1(x)
x = self.bn1(x)
x = self.act(x)
x = self.down2(x)
x_2 = self.down_2_1(x)
x_2 = self.down_2(x_2.view(bs, x_2.shape[1], x_2.shape[2], -1))

x = self.conv2(x)
x = self.conv2_1(x)
x = self.bn2(x)
x = self.act(x)
x = self.down3(x)
x_3 = self.down_3(x.view(bs, x.shape[1], x.shape[2], -1))

x = self.act(self.bn3(self.conv3(x)))
x = x.view((bs, x.shape[1], x.shape[2], -1))
x = self.down4(x)

x = self.act(self.bn4(self.conv4(x)))
x = self.concat((x, x_1, x_2, x_3))
x = self.expand_dims(x, 3)

return x


class Decoder(nn.Cell):
"""
Decoder architecture.

Args:
output_dim (int): output channel.
target_shape (list or tuple): Output DWH shape.
        base_channels (int): base channel; the channel counts of all intermediate layers are multiples of this value.

Returns:
Tensor, output tensor.

Examples:
>>> Decoder(output_dim=4, target_shape=[25, 50, 25], base_channels=8)
"""

def __init__(self, output_dim, target_shape, base_channels):
super(Decoder, self).__init__()

self.output_dim = output_dim
self.base_channels = base_channels
self.up0 = Up((32 + 8 + 4 + 2) * self.base_channels,
self.base_channels * 32,
[1, 1, 1],
[x // 8 for x in target_shape],
pad=True)
self.up1 = Up(self.base_channels * 32,
self.base_channels * 16,
[x // 8 for x in target_shape],
[x // 4 for x in target_shape],
pad=False)
self.up2 = Up(self.base_channels * 16,
                      self.base_channels * 4,
[x // 4 for x in target_shape],
[x // 2 for x in target_shape],
pad=False)
self.up3 = Up(self.base_channels * 4,
self.output_dim,
[x // 2 for x in target_shape],
target_shape,
pad=False)

def construct(self, x):
x = self.up0(x)
x = self.up1(x)
x = self.up2(x)
x = self.up3(x)

return x


class DoubleConvTranspose(nn.Cell):
"""
DoubleConvTranspose architecture

Args:
        input_dim (int): Input channel.
        out_channels (int): Output channel.
        mid_channels (int): Intermediate channel count; defaults to out_channels when None.

Returns:
Tensor, output tensor.

Examples:
>>> DoubleConvTranspose(input_dim=4, out_channels=8)
"""

def __init__(self, input_dim, out_channels, mid_channels=None):
super(DoubleConvTranspose, self).__init__()
if not mid_channels:
mid_channels = out_channels
        self.conv = nn.Conv3dTranspose(input_dim, mid_channels, kernel_size=3)
        self.conv1 = nn.Conv3dTranspose(mid_channels, out_channels, kernel_size=3)
        self.bn = nn.BatchNorm3d(out_channels)
        self.act = nn.Sigmoid()

def construct(self, x):
x = self.conv(x)
x = self.conv1(x)
x = self.bn(x)
x = self.act(x)
return x


def calculate_difference(input_shape, target_shape):
    """Calculate the difference between the target shape and the output shape of the
    preceding Conv3dTranspose, together with the corresponding padding sizes."""

    # output shape of a Conv3dTranspose with stride=2, kernel_size=2, no padding:
    # out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
    input_shape = [(dim - 1)*2 - 2*0 + 1*(2 - 1) + 0 + 1 for dim in input_shape]

diff = [x - y for x, y in zip(target_shape, input_shape)]
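    # e.g. (illustrative values): input_shape=[3, 6, 3] upsamples to [6, 12, 6];
    # with target_shape=[6, 12, 7] this gives diff=(0, 0, 1), i.e. one zero
    # plane prepended on the last spatial axis by the paddings built below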
paddings = [(0, 0), (0, 0)]
for i in range(3):
paddings.append((diff[i], 0))
return tuple(diff), tuple(paddings)


class Up(nn.Cell):
    """
    Upscaling followed by Conv3dTranspose decoding.

    Args:
        input_dim (int): Input channel.
        out_channels (int): Output channel.
        input_shape (list or tuple): Input DWH shape. Default: None.
        target_shape (list or tuple): Output DWH shape. Default: None.
        pad (bool): Enable manual padding; only needed in the first layer of the Decoder. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> Up(8, 4, input_shape=[10, 10, 10], target_shape=[20, 20, 20], pad=False)
    """

    def __init__(self, input_dim, out_channels, input_shape=None, target_shape=None, pad=True):
        super(Up, self).__init__()
        self.use_pad = pad

        diff, paddings = calculate_difference(input_shape, target_shape)
        if self.use_pad:
            self.pad = ops.Pad(paddings)
            self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=2, stride=2, pad_mode='pad', padding=0)
        else:
            self.up = nn.Conv3dTranspose(input_dim,
                                         input_dim // 2,
                                         kernel_size=2,
                                         stride=2,
                                         pad_mode='pad',
                                         padding=0,
                                         output_padding=diff)

        self.conv = DoubleConvTranspose(input_dim // 2, out_channels)

    def construct(self, x):
        x = self.up(x)
        if self.use_pad:
            x = self.pad(x)
        x = self.conv(x)

        return x
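
A minimal smoke test of the blocks above (a sketch, assuming a MindSpore build that supports the 3-D ops used here; the shapes follow the docstring example and are otherwise arbitrary):

    import numpy as np
    from mindspore import Tensor

    # hypothetical round trip: compress and reconstruct a random batch
    net = EncoderDecoder(input_dim=4, target_shape=[25, 50, 25], base_channels=8, decoding=True)
    x = Tensor(np.random.randn(2, 4, 25, 50, 25).astype(np.float32))
    out = net(x)  # reconstruction with the same shape as x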

+ 137
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/test_data_compression.py

@@ -0,0 +1,137 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ae_train"""

import os
import time
import pytest

import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.callback import LossMonitor, Callback

from mindelec.solver import Solver

from src.dataset import create_dataset
from src.model import EncoderDecoder
from src.lr_generator import step_lr_generator
from src.metric import MyMSELoss, EvalMetric
from src.config import config

train_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/train_data.npy"
test_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/test_data.npy"
set_seed(0)

print("pid:", os.getpid())
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class TimeMonitor(Callback):
"""
Monitor the time in training.
"""

    def __init__(self, data_size=None):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        self.epoch_time = time.time()
        self.per_step_time = 0

    def epoch_begin(self, run_context):
        """
        Record time at the beginning of epoch.
        """
        run_context.original_args()
        self.epoch_time = time.time()

def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
step_size = cb_params.batch_num

self.per_step_time = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)

    def get_step_time(self):
return self.per_step_time


def init_weight(net):
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv3d, nn.Dense)):
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
cell.weight.shape,
cell.weight.dtype))

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_auto_encoder():
"""training"""

model_net = EncoderDecoder(config["input_channels"], config["patch_shape"], config["base_channels"], decoding=True)
init_weight(net=model_net)

train_dataset = create_dataset(input_path=train_data_path,
label_path=train_data_path,
batch_size=config["batch_size"],
shuffle=True)

eval_dataset = create_dataset(input_path=test_data_path,
label_path=test_data_path,
batch_size=config["batch_size"],
shuffle=False)


step_size = train_dataset.get_dataset_size()
milestones, learning_rates = step_lr_generator(step_size,
config["epochs"],
config["lr"],
config["lr_decay_milestones"])

optimizer = nn.Adam(model_net.trainable_params(),
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))

loss_net = MyMSELoss()
eval_step_size = eval_dataset.get_dataset_size() * config["batch_size"]
evl_error_mrc = EvalMetric(eval_step_size)

solver = Solver(model_net,
train_input_map={'train': ['train_input_data']},
test_input_map={'test': ['test_input_data']},
optimizer=optimizer,
                    metrics={'evl_mrc': evl_error_mrc},
amp_level="O2",
loss_fn=loss_net)

time_cb = TimeMonitor()
solver.model.train(20, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False)
res = solver.model.eval(eval_dataset, dataset_sink_mode=False)
per_step_time = time_cb.get_step_time()
l1_error = res['evl_mrc']['mean_l1_error']
print('test_res:', f'l1_error: {l1_error:.10f} ')
print(f'per step time: {per_step_time:.10f} ')
assert l1_error <= 0.03
assert per_step_time <= 30

+ 124
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/callback.py

@@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
call back functions
"""
import time

import numpy as np

from mindspore.train.callback import Callback
from mindspore import Tensor
import mindspore.common.dtype as mstype


class PredictCallback(Callback):
"""
Evaluate the model during training.

Args:
        model (Cell): The network under test.
        predict_interval (int): Number of epochs to train between two evaluations.
        input_data (Array): Input test data.
        label (Array): Label data.
        batch_size (int): Batch size for prediction. Default: 8192.
"""

def __init__(self, model, predict_interval, input_data, label, batch_size=8192):
super(PredictCallback, self).__init__()
self.model = model
self.input_data = input_data
self.label = label
self.predict_interval = predict_interval
self.batch_size = min(batch_size, len(input_data))
self.l2_error = 1.0
print("check test dataset shape: {}, {}".format(self.input_data.shape, self.label.shape))

def epoch_end(self, run_context):
"""
Evaluate the model at the end of epoch.

Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_epoch_num % self.predict_interval == 0:
print("================================Start Evaluation================================")

test_input_data = self.input_data.reshape(-1, 2)
label = self.label.reshape(-1, 1)
index = 0
prediction = np.zeros(label.shape)
time_beg = time.time()
while index < len(test_input_data):
index_end = min(index + self.batch_size, len(test_input_data))
test_batch = Tensor(test_input_data[index: index_end, :], dtype=mstype.float32)
predict = self.model(test_batch)
predict = predict.asnumpy()
prediction[index: index_end, :] = predict[:, :]
index = index_end
print("Total prediction time: {} s".format(time.time() - time_beg))
error = label - prediction
l2_error = np.sqrt(np.sum(np.square(error[:, 0]))) / np.sqrt(np.sum(np.square(label[:, 0])))
print("l2_error: ", l2_error)
print("=================================End Evaluation=================================")
self.l2_error = l2_error

def get_l2_error(self):
return self.l2_error


class TimeMonitor(Callback):
"""
Monitor the time in training.

Args:
data_size (int): Iteration steps to run one epoch of the whole dataset.
"""
def __init__(self, data_size=None):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_time = time.time()
self.per_step_time = 0

def epoch_begin(self, run_context):
"""
Set begin time at the beginning of epoch.

Args:
run_context (RunContext): Context of the train running.
"""
run_context.original_args()
self.epoch_time = time.time()

def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
step_size = cb_params.batch_num

self.per_step_time = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)

def get_step_time(self):
return self.per_step_time

+ 44
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/config.py

@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# config
rectangle_sampling_config = ed({
'domain': ed({
'random_sampling': False,
'size': [100, 100],
}),
'BC': ed({
'random_sampling': True,
'size': 128,
'with_normal': False,
})
})

# config
helmholtz_2d_config = ed({
"name": "Helmholtz2D",
"columns_list": ["input", "label"],
"epochs": 10,
"batch_size": 128,
"lr": 0.001,
"coord_min": [0.0, 0.0],
"coord_max": [1.0, 1.0],
"axis_size": 101,
"wave_number": 2
})

+ 42
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/dataset.py

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
dataset
"""
import numpy as np

# prepare test input and label
def test_data_prepare(config):
"""create test dataset"""
coord_min = config["coord_min"]
coord_max = config["coord_max"]
axis_size = config["axis_size"]
wave_number = config.get("wave_number", 2.0)

# input
axis_x = np.linspace(coord_min[0], coord_max[0], num=axis_size, endpoint=True)
axis_y = np.linspace(coord_min[1], coord_max[1], num=axis_size, endpoint=True)
mesh_x, mesh_y = np.meshgrid(axis_y, axis_x)
input_data = np.hstack((mesh_x.flatten()[:, None], mesh_y.flatten()[:, None])).astype(np.float32)

# label
label = np.zeros((axis_size, axis_size, 1))
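    # analytic field u(x, y) = sin(k * x); equivalently:
    # label[:, :, 0] = np.sin(wave_number * axis_x)[None, :]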
for i in range(axis_size):
for j in range(axis_size):
label[i, j, 0] = np.sin(wave_number * axis_x[j])

label = label.reshape(-1, 1).astype(np.float32)

return input_data, label

+ 55
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/model.py

@@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
feedforward neural network
"""

import mindspore.nn as nn
from mindelec.architecture import get_activation, LinearBlock


class FFNN(nn.Cell):
"""
    Fully connected network.

    Args:
        input_dim (int): Input dimension.
        output_dim (int): Output dimension.
        hidden_layer (int): Width of each hidden layer.
        activation (str or Cell): Activation function.
"""

def __init__(self, input_dim, output_dim, hidden_layer=64, activation="sin"):
super(FFNN, self).__init__()
self.activation = get_activation(activation)
self.fc1 = LinearBlock(input_dim, hidden_layer)
self.fc2 = LinearBlock(hidden_layer, hidden_layer)
self.fc3 = LinearBlock(hidden_layer, hidden_layer)
self.fc4 = LinearBlock(hidden_layer, hidden_layer)
self.fc5 = LinearBlock(hidden_layer, output_dim)

def construct(self, *inputs):
"""fc network"""
x = inputs[0]
out = self.fc1(x)
out = self.activation(out)
out = self.fc2(out)
out = self.activation(out)
out = self.fc3(out)
out = self.activation(out)
out = self.fc4(out)
out = self.activation(out)
out = self.fc5(out)
return out
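
# usage sketch (illustrative): FFNN(input_dim=2, output_dim=1) maps a batch of
# (x, y) coordinates with shape (N, 2) to the scalar field u with shape (N, 1)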

+ 144
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/test_frequency_domain_maxwell.py

@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
train
"""
import os
import pytest
import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, ms_function
from mindspore.common import set_seed
from mindspore.train.callback import LossMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

from mindelec.solver import Solver, Problem
from mindelec.geometry import Rectangle, create_config_from_edict
from mindelec.common import L2
from mindelec.data import Dataset
from mindelec.operators import SecondOrderGrad as Hessian
from mindelec.loss import Constraints

from src.config import rectangle_sampling_config, helmholtz_2d_config
from src.model import FFNN
from src.dataset import test_data_prepare
from src.callback import PredictCallback, TimeMonitor

set_seed(0)
np.random.seed(0)

print("pid:", os.getpid())
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")


# define problem
class Helmholtz2D(Problem):
"""2D Helmholtz equation"""
def __init__(self, domain_name, bc_name, net, wavenumber=2):
super(Helmholtz2D, self).__init__()
self.domain_name = domain_name
self.bc_name = bc_name
self.type = "Equation"
self.wave_number = wavenumber
self.grad_xx = Hessian(net, input_idx1=0, input_idx2=0, output_idx=0)
self.grad_yy = Hessian(net, input_idx1=1, input_idx2=1, output_idx=0)
self.reshape = ops.Reshape()

@ms_function
def governing_equation(self, *output, **kwargs):
"""governing equation"""
u = output[0]
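        # residual of the Helmholtz equation: u_xx + u_yy + k^2 * u = 0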
x = kwargs[self.domain_name][:, 0]
y = kwargs[self.domain_name][:, 1]
x = self.reshape(x, (-1, 1))
y = self.reshape(y, (-1, 1))

u_xx = self.grad_xx(kwargs[self.domain_name])
u_yy = self.grad_yy(kwargs[self.domain_name])

return u_xx + u_yy + self.wave_number**2 * u

@ms_function
def boundary_condition(self, *output, **kwargs):
"""boundary condition"""
u = output[0]
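        # Dirichlet condition u = sin(k * x) on the boundary; the factor 100
        # up-weights this residual relative to the domain residual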
x = kwargs[self.bc_name][:, 0]
y = kwargs[self.bc_name][:, 1]
x = self.reshape(x, (-1, 1))
y = self.reshape(y, (-1, 1))

test_label = ops.sin(self.wave_number * x)
return 100 * (u - test_label)
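
# Sanity check of the manufactured solution: u(x, y) = sin(k*x) gives
# u_xx = -k^2 * sin(k*x) and u_yy = 0, so the domain residual u_xx + u_yy + k^2*u
# and the boundary residual 100*(u - sin(k*x)) both vanish on the exact solution.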


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_frequency_domain_maxwell():
"""train process"""
net = FFNN(input_dim=2, output_dim=1, hidden_layer=64)

# define geometry
geom_name = "rectangle"
rect_space = Rectangle(geom_name,
coord_min=helmholtz_2d_config["coord_min"],
coord_max=helmholtz_2d_config["coord_max"],
sampling_config=create_config_from_edict(rectangle_sampling_config))
geom_dict = {rect_space: ["domain", "BC"]}

# create dataset for train and test
train_dataset = Dataset(geom_dict)
train_data = train_dataset.create_dataset(batch_size=helmholtz_2d_config.get("batch_size", 128),
shuffle=True, drop_remainder=False)
test_input, test_label = test_data_prepare(helmholtz_2d_config)

# define problem and constraints
train_prob_dict = {geom_name: Helmholtz2D(domain_name=geom_name + "_domain_points",
bc_name=geom_name + "_BC_points",
net=net,
                                              wavenumber=helmholtz_2d_config.get("wave_number", 2)),
}
train_constraints = Constraints(train_dataset, train_prob_dict)

# optimizer
optim = nn.Adam(net.trainable_params(), learning_rate=helmholtz_2d_config.get("lr", 1e-4))

# solver
solver = Solver(net,
optimizer=optim,
mode="PINNs",
train_constraints=train_constraints,
test_constraints=None,
amp_level="O2",
metrics={'l2': L2(), 'distance': nn.MAE()},
loss_scale_manager=DynamicLossScaleManager()
)

# train
time_cb = TimeMonitor()
loss_cb = PredictCallback(model=net, predict_interval=10, input_data=test_input, label=test_label)
solver.train(epoch=helmholtz_2d_config.get("epochs", 10),
train_dataset=train_data,
callbacks=[time_cb, LossMonitor(), loss_cb])
per_step_time = time_cb.get_step_time()
l2_error = loss_cb.get_l2_error()

print(f'l2 error: {l2_error:.10f}')
print(f'per step time: {per_step_time:.10f}')
assert l2_error <= 0.05
assert per_step_time <= 10.0

+ 31
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/config.py

@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# config
config = ed({
"epochs": 500,
"batch_size": 8,
"lr": 0.0001,
"t_solution": 162,
"x_solution": 50,
"y_solution": 50,
"z_solution": 8,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 20
})

+ 35
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/dataset.py

@@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
dataset
"""
import numpy as np
from mindelec.data import Dataset, ExistedDataConfig


def create_dataset(data_path, batch_size=8, shuffle=True, drop_remainder=True, is_train=True):
"""create dataset"""
input_path = data_path + "inputs.npy"
label_path = data_path + "label.npy"
electromagnetic = ExistedDataConfig(name="electromagnetic",
data_dir=[input_path, label_path],
columns_list=["inputs", "label"],
data_format="npy")
dataset = Dataset(existed_data_list=[electromagnetic])
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle, drop_remainder=drop_remainder)
scale = None
if not is_train:
scale = np.load(data_path+"scale.npy")
return data_loader, scale

+ 80
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/loss.py

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
loss
"""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import functional as F
from src.config import config


class MyMSELoss(nn.LossBase):
"""mse loss function"""
def construct(self, base, target):
bs, _, _, _, _ = F.shape(target)
x = F.square(base - target)
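        # loss = 2 * batch_size * mean((base - target) ** 2)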
        return 2 * bs * self.get_loss(x)


class EvalMetric(nn.Metric):
"""eval metric"""
def __init__(self, length, scale, batch_size):
        super(EvalMetric, self).__init__()
self.clear()
self.length = length
self.batch_size = batch_size
self.t = config.t_solution
self.x = config.x_solution
self.y = config.y_solution
self.z = config.z_solution
self.predict_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32)
self.label_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32)
self.scale = scale
self.iter_idx = 0

def clear(self):
"""clear"""
self.iter_idx = 0

def update(self, *inputs):
"""update"""
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])

predict, label = y_pred, y
self.predict_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = predict
self.label_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = label
self.iter_idx += 1

def eval(self):
"""eval"""
predict_real = np.reshape(self.predict_real, (self.length, self.t, self.x, self.y, self.z, 6))
label_real = np.reshape(self.label_real, (self.length, self.t, self.x, self.y, self.z, 6))
l2_time = 0.0
for i in range(self.length):
predict_real_temp = predict_real[i:i+1]
label_real_temp = label_real[i:i+1]
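            # rescale the predictions by the stored per-field, per-time-step
            # factors (undoing the normalization) before measuring relative L2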
for j in range(self.t):
predict_real_temp[0, j, :, :, :, 0] = predict_real_temp[0, j, :, :, :, 0] * self.scale[0][j]
predict_real_temp[0, j, :, :, :, 1] = predict_real_temp[0, j, :, :, :, 1] * self.scale[1][j]
predict_real_temp[0, j, :, :, :, 2] = predict_real_temp[0, j, :, :, :, 2] * self.scale[2][j]
predict_real_temp[0, j, :, :, :, 3] = predict_real_temp[0, j, :, :, :, 3] * self.scale[3][j]
predict_real_temp[0, j, :, :, :, 4] = predict_real_temp[0, j, :, :, :, 4] * self.scale[4][j]
predict_real_temp[0, j, :, :, :, 5] = predict_real_temp[0, j, :, :, :, 5] * self.scale[5][j]
l2_time += (np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) /
np.sqrt(np.sum(np.square(label_real_temp))))
return {'l2_error': l2_time / self.length}

+ 176
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/maxwell_model.py

@@ -0,0 +1,176 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
maxwell 3d model
"""

import mindspore.nn as nn
from mindspore.ops import operations as P
import mindspore.ops.functional as F
from mindelec.architecture import get_activation


class Maxwell3D(nn.Cell):
"""maxwell3d"""
def __init__(self, output_dim):
super(Maxwell3D, self).__init__()

self.output_dim = output_dim
width = 64
self.net0 = ModelHead(4, width)
self.net1 = ModelHead(4, width)
self.net2 = ModelHead(4, width)
self.net3 = ModelHead(4, width)
self.net4 = ModelHead(4, width)

self.fc0 = nn.Dense(width+33, 128)
self.net = ModelOut(128, output_dim, (2, 2, 1), (2, 2, 1))
self.cat = P.Concat(axis=-1)

def construct(self, x):
"""forward"""
x_location = x[..., :4]
x_media = x[..., 4:]
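        # split the input into location features (first 4 columns) and media
        # features, then sum multi-scale embeddings of the location sampled at
        # 1x, 2x, 4x, 8x and 16x input frequencies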
out1 = self.net0(x_location)
out2 = self.net1(2*x_location)
out3 = self.net2(4*x_location)
out4 = self.net3(8*x_location)
out5 = self.net4(16.0*x_location)
out = out1 + out2 + out3 + out4 + out5
out = self.cat((out, x_media))
out = self.fc0(out)
out = self.net(out)
return out


class ModelHead(nn.Cell):
"""model_head"""
def __init__(self, input_dim, output_dim):
super(ModelHead, self).__init__()
self.output_dim = output_dim
self.fc0 = nn.Dense(input_dim, output_dim)
self.fc1 = nn.Dense(output_dim, output_dim)
self.act0 = get_activation('srelu')
self.act1 = get_activation('srelu')

def construct(self, x):
"""forward"""
x = self.fc0(x)
x = self.act0(x)
x = self.fc1(x)
x = self.act1(x)

return x


class ModelOut(nn.Cell):
"""model out"""
def __init__(self, input_dim, output_dim, kernel_size=2, strides=2):
super(ModelOut, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.base_channels = 64
self.inc = DoubleConv(self.input_dim, self.base_channels)
self.down1 = Down(self.base_channels, self.base_channels * 2, kernel_size, strides)
self.down2 = Down(self.base_channels * 2, self.base_channels * 4, kernel_size, strides)
self.down3 = Down(self.base_channels * 4, self.base_channels * 8, kernel_size, strides)
self.down4 = Down(self.base_channels * 8, self.base_channels * 16, kernel_size, strides)
self.up1 = Up(self.base_channels * 16, self.base_channels * 8, kernel_size, strides)
self.up2 = Up(self.base_channels * 8, self.base_channels * 4, kernel_size, strides)
self.up3 = Up(self.base_channels * 4, self.base_channels * 2, kernel_size, strides)
self.up4 = Up(self.base_channels * 2, self.base_channels, kernel_size, strides)

self.fc1 = nn.Dense(self.base_channels+128, 64)
self.fc2 = nn.Dense(64, output_dim)
self.relu = nn.ReLU()
self.transpose = P.Transpose()
self.cat = P.Concat(axis=1)

def construct(self, x):
"""forward"""
x0 = self.transpose(x, (0, 4, 1, 2, 3))
x1 = self.inc(x0)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.cat((x, x0))
x = self.transpose(x, (0, 2, 3, 4, 1))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)

return x


class DoubleConv(nn.Cell):
"""double conv"""
def __init__(self, input_dim, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.SequentialCell(
nn.Conv3d(input_dim, mid_channels, kernel_size=3),
nn.BatchNorm3d(mid_channels),
nn.ReLU(),
nn.Conv3d(mid_channels, out_channels, kernel_size=3),
nn.BatchNorm3d(out_channels),
nn.ReLU()
)

def construct(self, x):
"""forward"""
return self.double_conv(x)


class Down(nn.Cell):
"""down"""
def __init__(self, input_dim, out_channels, kernel_size=2, strides=2):
super().__init__()
self.conv = DoubleConv(input_dim, out_channels)
self.maxpool = P.MaxPool3D(kernel_size=kernel_size, strides=strides)

def construct(self, x):
"""forward"""
x = self.maxpool(x)
return self.conv(x)


class Up(nn.Cell):
"""up"""
def __init__(self, input_dim, out_channels, kernel_size=2, strides=2):
super().__init__()
self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=kernel_size, stride=strides)
self.conv = DoubleConv(input_dim, out_channels)
self.cat = P.Concat(axis=1)

def construct(self, x1, x2):
"""forward"""
x1 = self.up(x1)
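        # pad the upsampled tensor so its spatial dims match the skip
        # connection before channel-wise concatenation (U-Net style)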

_, _, h1, w1, c1 = F.shape(x1)
_, _, h2, w2, c2 = F.shape(x2)
diff_z = c2 - c1
diff_y = w2 - w1
diff_x = h2 - h1

x1 = P.Pad(((0, 0), (0, 0), (diff_x // 2, diff_x - diff_x // 2), (diff_y // 2, diff_y - diff_y // 2),
(diff_z // 2, diff_z - diff_z // 2)))(x1)
x = self.cat((x2, x1))
return self.conv(x)

+ 37
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/sample.py

@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
sample fake data for train and test
"""
import os
import numpy as np


def generate_data(train_path, test_path):
"""generate fake data"""
if not os.path.exists(train_path):
os.makedirs(train_path)
if not os.path.exists(test_path):
os.makedirs(test_path)

inputs = np.ones((162, 50, 50, 8, 37), dtype=np.float32)
label = np.ones((162, 50, 50, 8, 6), dtype=np.float32)
np.save(train_path + "inputs.npy", inputs)
np.save(train_path + "label.npy", label)

scale = np.ones((6, 162), dtype=np.float32)
np.save(test_path + "inputs.npy", inputs)
np.save(test_path + "label.npy", label)
np.save(test_path + "scale.npy", scale)

+ 152
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/test_full_em.py

@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
train
"""
import os
import time
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore import context
from mindspore.train.callback import Callback, LossMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindelec.solver import Solver
from src.dataset import create_dataset
from src.loss import MyMSELoss, EvalMetric
from src.maxwell_model import Maxwell3D
from src.config import config
from src.sample import generate_data

set_seed(0)
np.random.seed(0)

train_data_path = "./train_data_em/"
test_data_path = "./test_data_em/"

print("pid:", os.getpid())
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class TimeMonitor(Callback):
"""
Monitor the time in training.
"""

def __init__(self, data_size=None):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_time = time.time()
self.per_step_time = 0

def epoch_begin(self, run_context):
"""
        Record time at the beginning of epoch.
"""
run_context.original_args()
self.epoch_time = time.time()

def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
step_size = cb_params.batch_num

self.per_step_time = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)

    def get_step_time(self):
return self.per_step_time


def get_lr(lr_init, steps_per_epoch, total_epochs):
"""get lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
for i in range(total_steps):
epoch = i // steps_per_epoch
lr_local = lr_init
if epoch <= 15:
lr_local = lr_init
elif epoch <= 45:
lr_local = lr_init * 0.5
elif epoch <= 300:
lr_local = lr_init * 0.25
elif epoch <= 600:
lr_local = lr_init * 0.125
lr_each_step.append(lr_local)
learning_rate = np.array(lr_each_step).astype(np.float32)
print(learning_rate)
return learning_rate


def init_weight(net):
"""init weight"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv3d, nn.Dense)):
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
cell.weight.shape,
cell.weight.dtype))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_full_em():
"""train"""
generate_data(train_data_path, test_data_path)
train_dataset, _ = create_dataset(train_data_path, batch_size=config.batch_size, shuffle=True)
test_dataset, config_scale = create_dataset(test_data_path, batch_size=config.batch_size,
shuffle=False, drop_remainder=False, is_train=False)
model_net = Maxwell3D(6)
init_weight(net=model_net)
train_step_size = train_dataset.get_dataset_size()
lr = get_lr(config.lr, train_step_size, config.epochs)
optimizer = nn.Adam(model_net.trainable_params(), learning_rate=Tensor(lr))
loss_net = MyMSELoss()
loss_scale = DynamicLossScaleManager()

    test_step_size = test_dataset.get_dataset_size()
    test_batch_size = test_dataset.get_batch_size()
    data_length = test_step_size * test_batch_size // config.t_solution
    evl_error_mrc = EvalMetric(data_length, config_scale, test_batch_size)
solver = Solver(model_net,
optimizer=optimizer,
loss_scale_manager=loss_scale,
amp_level="O2",
keep_batchnorm_fp32=False,
loss_fn=loss_net,
metrics={"evl_mrc": evl_error_mrc})

time_cb = TimeMonitor()
solver.model.train(5, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False)
res = solver.model.eval(test_dataset, dataset_sink_mode=False)
per_step_time = time_cb.get_step_time()
l2_s11 = res['evl_mrc']['l2_error']
print('test_res:', f'l2_error: {l2_s11:.10f} ')
print(f'per step time: {per_step_time:.10f} ')
assert l2_s11 <= 0.05
assert per_step_time <= 150

+ 44
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/pretrain.json

@@ -0,0 +1,44 @@
{
    "Description" : [ "PINNs for solving Maxwell's equations" ],

"Case" : "2D_Mur_Src_Gauss_Mscale_MTL_PIAD",
"coord_min" : [0, 0],
"coord_max" : [1, 1],
"src_pos" : [0.4975, 0.4975],
"SrcFrq": 1e+9,
"range_t" : 4e-9,
"input_center": [0.5, 0.5, 2.0e-9],
"input_scale": [2.0, 2.0, 5.0e+8],
"output_scale": [37.67303, 37.67303, 0.1],
"src_radius": 0.01,
"input_size" : 3,
"output_size" : 3,
"residual" : true,
"num_scales" : 4,
"layers" : 7,
"neurons" : 64,
"amp_factor" : 2,
"scale_factor" : 2,
"save_ckpt" : true,
"load_ckpt" : false,
"save_ckpt_path" : "./ckpt",
"load_ckpt_path" : "",
"train_with_eval": true,
"test_data_path" : "./benchmark/",
"lr" : 0.001,
"milestones" : [2000],
"lr_gamma" : 0.25,
"train_epoch" : 50,
"train_batch_size" : 1024,
"test_batch_size" : 8192,
"predict_interval" : 500,
"vision_path" : "./vision",
"summary_path" : "./summary",

"EPS_candidates": [1, 3, 5],
"MU_candidates": [1, 3, 5],
"num_scenarios": 9,
"latent_vector_size": 16,
"latent_reg": 1.0,
"latent_init_std": 1.0
}

+ 28
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/__init__.py

@@ -0,0 +1,28 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
init
"""
from .dataset import get_test_data, create_random_dataset
from .maxwell import Maxwell2DMur
from .lr_scheduler import MultiStepLR


__all__ = [
"create_random_dataset",
"get_test_data",
"Maxwell2DMur",
"MultiStepLR",
]

+ 57
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/dataset.py

@@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create dataset
"""
import numpy as np
from mindelec.data import Dataset
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
from mindelec.geometry import create_config_from_edict

from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config

def get_test_data(test_data_path):
    """load labeled data for evaluation"""
# check data
paths = [test_data_path + '/input.npy', test_data_path + '/output.npy']
input_data = np.load(paths[0])
label_data = np.load(paths[1])
return input_data, label_data

def create_random_dataset(config):
"""create training dataset by online sampling"""
radius = config["src_radius"]
origin = config["src_pos"]

disk = Disk("src", origin, radius)
rect = Rectangle("rect", config["coord_min"], config["coord_max"])
diff = rect - disk
interval = TimeDomain("time", 0.0, config["range_t"])
no_src = GeometryWithTime(diff, interval)
no_src.set_name("no_src")
no_src.set_sampling_config(create_config_from_edict(no_src_sampling_config))
src = GeometryWithTime(disk, interval)
src.set_name("src")
src.set_sampling_config(create_config_from_edict(src_sampling_config))
bc = GeometryWithTime(rect, interval)
bc.set_name("bc")
bc.set_sampling_config(create_config_from_edict(bc_sampling_config))

geom_dict = {src: ["domain", "IC"],
no_src: ["domain", "IC"],
bc: ["BC"]}

dataset = Dataset(geom_dict)
return dataset

+ 73
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/lr_scheduler.py

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Learning rate scheduler."""
from collections import Counter
import numpy as np

class _LRScheduler:
"""
Basic class for learning rate scheduler
"""

def __init__(self, lr, max_epoch, steps_per_epoch):
self.base_lr = lr
self.steps_per_epoch = steps_per_epoch
self.total_steps = int(max_epoch * steps_per_epoch)

def get_lr(self):
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError


class MultiStepLR(_LRScheduler):
"""
Multi-step learning rate scheduler

    Decays the learning rate by gamma once the number of epochs reaches one of the milestones.

Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.

Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)

Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
        >>> scheduler = MultiStepLR(lr=0.05, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""

def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
self.milestones = Counter(milestones)
self.gamma = gamma
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)

def get_lr(self):
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
current_lr = current_lr * self.gamma
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)

+ 184
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/maxwell.py

@@ -0,0 +1,184 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#pylint: disable=W0613
"""
2D maxwell problem with Mur bc
"""
import numpy as np

import mindspore.numpy as ms_np
from mindspore import ms_function
from mindspore import ops
from mindspore import Tensor
import mindspore.common.dtype as ms_type

from mindelec.solver import Problem
from mindelec.common import MU, EPS, PI
from mindelec.operators import Grad


class Maxwell2DMur(Problem):
r"""
The 2D Maxwell's equations with 2nd-order Mur absorbed boundary condition.

Args:
network (Cell): The solving network.
config (dict): Setting information.
domain_column (str): The corresponding column name of data which governed by maxwell's equation.
bc_column (str): The corresponding column name of data which governed by boundary condition.
bc_normal (str): The column name of normal direction vector corresponding to specified boundary.
ic_column (str): The corresponding column name of data which governed by initial condition.
"""
def __init__(self, network, config, domain_column=None, bc_column=None, ic_column=None):
super(Maxwell2DMur, self).__init__()
self.domain_column = domain_column
self.bc_column = bc_column
self.ic_column = ic_column
self.network = network

# operations
self.gradient = Grad(self.network)
self.reshape = ops.Reshape()
        self.cast = ops.Cast()
        self.mul = ops.Mul()
self.split = ops.Split(1, 3)
self.concat = ops.Concat(1)
self.sqrt = ops.Sqrt()

# gauss-type pulse source
self.pi = Tensor(PI, ms_type.float32)
self.src_frq = config.get("src_frq", 1e+9)
self.tau = Tensor((2.3 ** 0.5) / (PI * self.src_frq), ms_type.float32)
self.amp = Tensor(1.0, ms_type.float32)
self.t0 = Tensor(3.65 * self.tau, ms_type.float32)

# src space
self.src_x0 = Tensor(config["src_pos"][0], ms_type.float32)
self.src_y0 = Tensor(config["src_pos"][1], ms_type.float32)
self.src_sigma = Tensor(config["src_radius"] / 4.0, ms_type.float32)
self.src_coord_min = config["coord_min"]
self.src_coord_max = config["coord_max"]

input_scales = config.get("input_scales", [1.0, 1.0, 2.5e+8])
output_scales = config.get("output_scales", [37.67303, 37.67303, 0.1])
self.s_x = Tensor(input_scales[0], ms_type.float32)
self.s_y = Tensor(input_scales[1], ms_type.float32)
self.s_t = Tensor(input_scales[2], ms_type.float32)
self.s_ex = Tensor(output_scales[0], ms_type.float32)
self.s_ey = Tensor(output_scales[1], ms_type.float32)
self.s_hz = Tensor(output_scales[2], ms_type.float32)

# set up eps, mu candidates
eps_candidates = np.array(config["eps_list"], dtype=np.float32) * EPS
mu_candidates = np.array(config["mu_list"], dtype=np.float32) * MU
self.epsilon_x = Tensor(eps_candidates, ms_type.float32).view((-1, 1))
self.epsilon_y = Tensor(eps_candidates, ms_type.float32).view((-1, 1))
self.mu_z = Tensor(mu_candidates, ms_type.float32).view((-1, 1))
        self.light_speed = 1.0 / self.sqrt(self.mul(self.epsilon_x, self.mu_z))

def smooth_src(self, x, y, t):
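        # Gaussian pulse in time (center t0, width tau), localized in space by
        # an isotropic Gaussian of std src_sigma around (src_x0, src_y0)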
source = self.amp * ops.exp(- ((t - self.t0) / self.tau)**2)
gauss = 1 / (2 * self.pi * self.src_sigma**2) * \
ops.exp(- ((x - self.src_x0)**2 + (y - self.src_y0)**2) / (2 * (self.src_sigma**2)))
return self.mul(source, gauss)

@ms_function
def governing_equation(self, *output, **kwargs):
"""maxwell equation of TE mode wave"""
out = output[0]
data = kwargs[self.domain_column]
x = self.reshape(data[:, 0], (-1, 1))
y = self.reshape(data[:, 1], (-1, 1))
t = self.reshape(data[:, 2], (-1, 1))

dex_dxyt = self.gradient(data, None, 0, out)
_, dex_dy, dex_dt = self.split(dex_dxyt)
dey_dxyt = self.gradient(data, None, 1, out)
dey_dx, _, dey_dt = self.split(dey_dxyt)
dhz_dxyt = self.gradient(data, None, 2, out)
dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)

dex_dy = self.cast(dex_dy, ms_type.float32)
dex_dt = self.cast(dex_dt, ms_type.float32)
dey_dx = self.cast(dey_dx, ms_type.float32)
dey_dt = self.cast(dey_dt, ms_type.float32)
dhz_dx = self.cast(dhz_dx, ms_type.float32)
dhz_dy = self.cast(dhz_dy, ms_type.float32)
dhz_dt = self.cast(dhz_dt, ms_type.float32)

loss_a1 = (self.s_hz * dhz_dy) / (self.s_ex * self.s_t * self.epsilon_x)
loss_a2 = dex_dt / self.s_t

loss_b1 = -(self.s_hz * dhz_dx) / (self.s_ey * self.s_t * self.epsilon_y)
loss_b2 = dey_dt / self.s_t

loss_c1 = (self.s_ey * dey_dx - self.s_ex * dex_dy) / (self.s_hz * self.s_t * self.mu_z)
loss_c2 = - dhz_dt / self.s_t

source = self.smooth_src(x, y, t) / (self.s_hz * self.s_t * self.mu_z)

pde_res1 = loss_a1 - loss_a2
pde_res2 = loss_b1 - loss_b2
pde_res3 = loss_c1 - loss_c2 - source
        pde_r = self.concat((pde_res1, pde_res2, pde_res3))
return pde_r

@ms_function
def boundary_condition(self, *output, **kwargs):
"""2nd-order mur boundary condition"""
u = output[0]
data = kwargs[self.bc_column]

coord_min = self.src_coord_min
coord_max = self.src_coord_max
batch_size, _ = data.shape
bc_attr = ms_np.zeros(shape=(batch_size, 4))
bc_attr[:, 0] = ms_np.where(ms_np.isclose(data[:, 0], coord_min[0]), 1.0, 0.0)
bc_attr[:, 1] = ms_np.where(ms_np.isclose(data[:, 0], coord_max[0]), 1.0, 0.0)
bc_attr[:, 2] = ms_np.where(ms_np.isclose(data[:, 1], coord_min[1]), 1.0, 0.0)
bc_attr[:, 3] = ms_np.where(ms_np.isclose(data[:, 1], coord_max[1]), 1.0, 0.0)

dex_dxyt = self.gradient(data, None, 0, u)
_, dex_dy, _ = self.split(dex_dxyt)
dey_dxyt = self.gradient(data, None, 1, u)
dey_dx, _, _ = self.split(dey_dxyt)
dhz_dxyt = self.gradient(data, None, 2, u)
dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)

dex_dy = self.cast(dex_dy, ms_type.float32)
dey_dx = self.cast(dey_dx, ms_type.float32)
dhz_dx = self.cast(dhz_dx, ms_type.float32)
dhz_dy = self.cast(dhz_dy, ms_type.float32)
dhz_dt = self.cast(dhz_dt, ms_type.float32)

bc_r1 = dhz_dx / self.s_x - dhz_dt / (self.light_speed * self.s_x) + \
self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=0
bc_r2 = dhz_dx / self.s_x + dhz_dt / (self.light_speed * self.s_x) - \
self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=L
bc_r3 = dhz_dy / self.s_y - dhz_dt / (self.light_speed * self.s_y) - \
self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=0
bc_r4 = dhz_dy / self.s_y + dhz_dt / (self.light_speed * self.s_y) + \
self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=L

bc_r_all = self.concat((bc_r1, bc_r2, bc_r3, bc_r4))
bc_r = self.mul(bc_r_all, bc_attr)
return bc_r

@ms_function
def initial_condition(self, *output, **kwargs):
"""initial condition: u = 0"""
net_out = output[0]
return net_out

+ 52
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/sampling_config.py

@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""sampling information"""
import copy
from easydict import EasyDict as edict


src_sampling_config = edict({
'domain': edict({
'random_sampling': True,
'size': 262144,
'sampler': 'uniform'
}),
'IC': edict({
'random_sampling': True,
'size': 262144,
'sampler': 'uniform',
}),
'time': edict({
'random_sampling': True,
'size': 262144,
'sampler': 'uniform',
}),
})

no_src_sampling_config = copy.deepcopy(src_sampling_config)

bc_sampling_config = edict({
'BC': edict({
'random_sampling': True,
'size': 262144,
'sampler': 'uniform',
'with_normal': False
}),
'time': edict({
'random_sampling': True,
'size': 262144,
'sampler': 'uniform',
}),
})

+ 192
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/test_incremental_learning.py

@@ -0,0 +1,192 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""reconstruct process."""
import os
import json
import math
import pytest
import numpy as np

from mindspore.common import set_seed
from mindspore import context, Tensor, nn, Parameter
from mindspore.train import DynamicLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.common.dtype as ms_type
from mindspore.common.initializer import HeUniform


from mindelec.loss import Constraints
from mindelec.solver import Solver, LossAndTimeMonitor
from mindelec.common import L2
from mindelec.architecture import MultiScaleFCCell, MTLWeightedLossCell

from src.dataset import create_random_dataset
from src.lr_scheduler import MultiStepLR
from src.maxwell import Maxwell2DMur

set_seed(123456)
np.random.seed(123456)

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend", save_graphs_path="./solver")
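
def preprocess_config(config):
    """Map the raw pretrain.json keys onto the names consumed below and in src.maxwell.

    NOTE: the original helper is not part of this diff; this is a minimal,
    assumed reconstruction inferred from how `config` is read afterwards
    (e.g. "EPS_candidates" -> "eps_list", "train_batch_size" -> "batch_size").
    """
    config["eps_list"] = config.get("EPS_candidates", [1])
    config["mu_list"] = config.get("MU_candidates", [1])
    config["batch_size"] = config.get("train_batch_size", 1024)
    config["input_scales"] = config.get("input_scale")
    config["output_scales"] = config.get("output_scale")
    config["src_frq"] = config.get("SrcFrq", 1e+9)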


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_incremental_learning():
"""pretraining process"""
print("pid:", os.getpid())
mode = "pretrain"
    with open("./pretrain.json") as f:
        config = json.load(f)
preprocess_config(config)
elec_train_dataset = create_random_dataset(config)
train_dataset = elec_train_dataset.create_dataset(batch_size=config["batch_size"],
shuffle=True,
prebatched_data=True,
drop_remainder=True)
epoch_steps = len(elec_train_dataset)
print("check train dataset size: ", len(elec_train_dataset))

# load ckpt
if config.get("load_ckpt", False):
param_dict = load_checkpoint(config["load_ckpt_path"])
if mode == "pretrain":
loaded_ckpt_dict = param_dict
else:
loaded_ckpt_dict = {}
latent_vector_ckpt = 0
for name in param_dict:
if name == "model.latent_vector":
latent_vector_ckpt = param_dict[name].data.asnumpy()
elif "network" in name and "moment" not in name:
loaded_ckpt_dict[name] = param_dict[name]

# initialize latent vector
num_scenarios = config["num_scenarios"]
latent_size = config["latent_vector_size"]
if mode == "pretrain":
latent_init = np.random.randn(num_scenarios, latent_size) / np.sqrt(latent_size)
else:
latent_norm = np.mean(np.linalg.norm(latent_vector_ckpt, axis=1))
print("check mean latent vector norm: ", latent_norm)
latent_init = np.zeros((num_scenarios, latent_size))
latent_vector = Parameter(Tensor(latent_init, ms_type.float32), requires_grad=True)

network = MultiScaleFCCell(config["input_size"],
config["output_size"],
layers=config["layers"],
neurons=config["neurons"],
residual=config["residual"],
weight_init=HeUniform(negative_slope=math.sqrt(5)),
act="sin",
num_scales=config["num_scales"],
amp_factor=config["amp_factor"],
scale_factor=config["scale_factor"],
input_scale=config["input_scale"],
input_center=config["input_center"],
latent_vector=latent_vector
)

network = network.to_float(ms_type.float16)
network.input_scale.to_float(ms_type.float32)

if config.get("enable_mtl", True):
mtl_cell = MTLWeightedLossCell(num_losses=elec_train_dataset.num_dataset)
else:
mtl_cell = None

# define problem
train_prob = {}
for dataset in elec_train_dataset.all_datasets:
train_prob[dataset.name] = Maxwell2DMur(network=network, config=config,
domain_column=dataset.name + "_points",
ic_column=dataset.name + "_points",
bc_column=dataset.name + "_points")
print("check problem: ", train_prob)
train_constraints = Constraints(elec_train_dataset, train_prob)

# optimizer
if mode == "pretrain":
params = network.trainable_params() + mtl_cell.trainable_params()
if config.get("load_ckpt", False):
load_param_into_net(network, loaded_ckpt_dict)
load_param_into_net(mtl_cell, loaded_ckpt_dict)
else:
if config.get("finetune_model"):
model_params = network.trainable_params()
else:
model_params = [param for param in network.trainable_params()
if ("bias" not in param.name and "weight" not in param.name)]
params = (model_params + mtl_cell.trainable_params()) if mtl_cell else model_params
load_param_into_net(network, loaded_ckpt_dict)

lr_scheduler = MultiStepLR(config["lr"], config["milestones"], config["lr_gamma"],
epoch_steps, config["train_epoch"])
optimizer = nn.Adam(params, learning_rate=Tensor(lr_scheduler.get_lr()))

# problem solver
solver = Solver(network,
optimizer=optimizer,
mode="PINNs",
train_constraints=train_constraints,
test_constraints=None,
metrics={'l2': L2(), 'distance': nn.MAE()},
loss_fn='smooth_l1_loss',
loss_scale_manager=DynamicLossScaleManager(),
mtl_weighted_cell=mtl_cell,
latent_vector=latent_vector,
latent_reg=config["latent_reg"]
)

loss_time_callback = LossAndTimeMonitor(epoch_steps)
callbacks = [loss_time_callback]
if config["save_ckpt"]:
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=2)
prefix = 'pretrain_maxwell_frq1e9' if mode == "pretrain" else 'reconstruct_maxwell_frq1e9'
ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=config["save_ckpt_path"], config=config_ck)
callbacks += [ckpoint_cb]

solver.train(config["train_epoch"], train_dataset, callbacks=callbacks, dataset_sink_mode=True)
assert loss_time_callback.get_loss() <= 2.0
assert loss_time_callback.get_step_time() <= 115.0


def preprocess_config(config):
"""preprocess to get the coefficients of electromagnetic field for each scenario"""
eps_candidates = config["EPS_candidates"]
mu_candidates = config["MU_candidates"]
config["num_scenarios"] = len(eps_candidates) * len(mu_candidates)
batch_size_single_scenario = config["train_batch_size"]
config["batch_size"] = batch_size_single_scenario * config["num_scenarios"]
eps_list = []
for eps in eps_candidates:
eps_list.extend([eps] * (batch_size_single_scenario * len(mu_candidates)))
mu_list = []
for mu in mu_candidates:
mu_list.extend([mu] * batch_size_single_scenario)
mu_list = mu_list * (len(eps_candidates))

exp_name = "_" + config["Case"] + '_num_scenarios_' + str(config["num_scenarios"]) \
+ "_latent_reg_" + str(config["latent_reg"])
if config["save_ckpt"]:
config["save_ckpt_path"] += exp_name

config["vision_path"] += exp_name
config["summary_path"] += exp_name
print("check config: {}".format(config))
config["eps_list"] = eps_list
config["mu_list"] = mu_list

BIN
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_input.npy View File


BIN
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_label.npy View File


BIN
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_input.npy View File


BIN
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_label.npy View File


+ 130
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/dataset.py View File

@@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
dataset
"""
import os
import shutil
import numpy as np
from mindelec.data import Dataset, ExistedDataConfig


def custom_normalize(data):
"""
get normalize data
"""
print("Custom normalization is called")
ori_shape = data.shape
data = data.reshape(ori_shape[0], -1)
data = np.transpose(data)
mean = np.mean(data, axis=1)
data = data - mean[:, None]
std = np.std(data, axis=1)
std += (np.abs(std) < 0.0000001)
data = data / std[:, None]
data = np.transpose(data)
data = data.reshape(ori_shape)
return data


def create_dataset(opt):
"""
load data
"""
data_input_path = opt.input_path
data_label_path = opt.label_path

data_input = np.load(data_input_path)
data_label = np.load(data_label_path)

frequency = data_label[0, :, 0]
data_label = data_label[:, :, 1]

print(data_input.shape)
print(data_label.shape)
print("data load finish")

data_input = custom_normalize(data_input)

config_data_prepare = {}

config_data_prepare["scale_input"] = 0.5 * np.max(np.abs(data_input), axis=0)
config_data_prepare["scale_S11"] = 0.5 * np.max(np.abs(data_label))

data_input[:, :] = data_input[:, :] / config_data_prepare["scale_input"]
data_label[:, :] = data_label[:, :] / config_data_prepare["scale_S11"]

permutation = np.random.permutation(data_input.shape[0])
data_input = data_input[permutation]
data_label = data_label[permutation]

length = data_input.shape[0] // 10
train_input, train_label = data_input[length:], data_label[length:]
eval_input, eval_label = data_input[:length], data_label[:length]

print(np.shape(train_input))
print(np.shape(train_label))
print(np.shape(eval_input))
print(np.shape(eval_label))

if not os.path.exists('./data_prepare'):
os.mkdir('./data_prepare')
else:
shutil.rmtree('./data_prepare')
os.mkdir('./data_prepare')

train_input = train_input.astype(np.float32)
np.save('./data_prepare/train_input', train_input)
train_label = train_label.astype(np.float32)
np.save('./data_prepare/train_label', train_label)
eval_input = eval_input.astype(np.float32)
np.save('./data_prepare/eval_input', eval_input)
eval_label = eval_label.astype(np.float32)
np.save('./data_prepare/eval_label', eval_label)

electromagnetic_train = ExistedDataConfig(name="electromagnetic_train",
data_dir=['./data_prepare/train_input.npy',
'./data_prepare/train_label.npy'],
columns_list=["inputs", "label"],
data_format="npy")
electromagnetic_eval = ExistedDataConfig(name="electromagnetic_eval",
data_dir=['./data_prepare/eval_input.npy',
'./data_prepare/eval_label.npy'],
columns_list=["inputs", "label"],
data_format="npy")
train_batch_size = opt.batch_size
eval_batch_size = len(eval_input)

train_dataset = Dataset(existed_data_list=[electromagnetic_train])
train_loader = train_dataset.create_dataset(batch_size=train_batch_size, shuffle=True)

eval_dataset = Dataset(existed_data_list=[electromagnetic_eval])
eval_loader = eval_dataset.create_dataset(batch_size=eval_batch_size, shuffle=False)

data = {
"train_loader": train_loader,
"eval_loader": eval_loader,

"train_data": train_input,
"train_label": train_label,
"eval_data": eval_input,
"eval_label": eval_label,

"train_data_length": len(train_label),
"eval_data_length": len(eval_label),
"frequency": frequency,
}

return data, config_data_prepare
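
A quick self-check of custom_normalize (illustrative, assuming non-degenerate random input): every flattened feature should come out zero-mean and unit-std across samples.

if __name__ == "__main__":
    demo = np.random.rand(16, 2, 3).astype(np.float32)
    flat = custom_normalize(demo).reshape(16, -1)
    assert np.allclose(flat.mean(axis=0), 0.0, atol=1e-5)
    assert np.allclose(flat.std(axis=0), 1.0, atol=1e-3)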

+ 98
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/loss.py View File

@@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
loss
"""

import os
import shutil
import mindspore.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import cv2


class EvalMetric(nn.Metric):
"""
eval metric
"""

def __init__(self, scale_s11, length, frequency, show_pic_number, file_path):
super(EvalMetric, self).__init__()
self.clear()
self.scale_s11 = scale_s11
self.length = length
self.frequency = frequency
self.show_pic_number = show_pic_number
self.file_path = file_path
self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False)

def clear(self):
"""
clear error
"""
self.error_sum_l2_error = 0
self.error_sum_loss_error = 0
self.pic_res = None

def update(self, *inputs):
"""
update error
"""
if not os.path.exists(self.file_path):
os.mkdir(self.file_path)
else:
shutil.rmtree(self.file_path)
os.mkdir(self.file_path)

y_pred = self._convert_data(inputs[0])
y_label = self._convert_data(inputs[1])

test_predict, test_label = y_pred, y_label
test_predict[:, :] = test_predict[:, :] * self.scale_s11
test_label[:, :] = test_label[:, :] * self.scale_s11
self.pic_res = []

for i in range(len(test_label)):
predict_real_temp = test_predict[i]
label_real_temp = test_label[i]
l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \
np.sqrt(np.sum(np.square(label_real_temp)))
self.error_sum_l2_error += l2_error_temp
self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2)

s11_label, s11_predict = label_real_temp, predict_real_temp
plt.figure(dpi=250)
plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2)
plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1)
plt.title('s11(dB)')
plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10])
plt.ylabel('dB')
plt.legend()
plt.savefig(self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg')
plt.close()
if i in self.show_pic_id:
self.pic_res.append(cv2.imread(
self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg'))

self.pic_res = np.array(self.pic_res).astype(np.float32)

def eval(self):
"""
compute final error
"""
return {'l2_error': self.error_sum_l2_error / self.length,
'loss_error': self.error_sum_loss_error / self.length,
'pic_res': self.pic_res}
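
The metric follows MindSpore's clear/update/eval lifecycle. A minimal sketch with plain numpy inputs (note that update() also writes one .jpg per sample under file_path; values here are illustrative):

metric = EvalMetric(scale_s11=1.0, length=2, frequency=np.linspace(1.0, 10.0, 1001),
                    show_pic_number=0, file_path='./eval_demo')
metric.clear()
metric.update(0.9 * np.ones((2, 1001), np.float32), np.ones((2, 1001), np.float32))
print(metric.eval()['l2_error'])  # 0.1: relative L2 error of a uniform 10% under-prediction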

+ 47
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/maxwell_model.py View File

@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
maxwell S11 model
"""

import mindspore.nn as nn


class S11Predictor(nn.Cell):
"""
Maxwell S11 prediction model definition
"""
def __init__(self, input_dimension):
super(S11Predictor, self).__init__()
self.fc1 = nn.Dense(input_dimension, 128)
self.fc2 = nn.Dense(128, 128)
self.fc3 = nn.Dense(128, 128)
self.fc4 = nn.Dense(128, 128)
self.fc5 = nn.Dense(128, 128)
self.fc6 = nn.Dense(128, 128)
self.fc7 = nn.Dense(128, 1001)
self.relu = nn.ReLU()

def construct(self, x):
"""forward"""
x0 = x
x1 = self.relu(self.fc1(x0))
x2 = self.relu(self.fc2(x1))
x3 = self.relu(self.fc3(x1 + x2))
x4 = self.relu(self.fc4(x1 + x2 + x3))
x5 = self.relu(self.fc5(x1 + x2 + x3 + x4))
x6 = self.relu(self.fc6(x1 + x2 + x3 + x4 + x5))
x = self.fc7(x1 + x2 + x3 + x4 + x5 + x6)
return x
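
A smoke run of the dense predictor (a sketch, assuming a working MindSpore context): each block feeds the sum of all previous activations forward, and fc7 maps onto the 1001 frequency points.

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

net = S11Predictor(input_dimension=3)
out = net(Tensor(np.random.rand(8, 3), mstype.float32))
print(out.shape)  # (8, 1001)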

+ 166
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/test_parameterization.py View File

@@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
test parameterization
"""

import time

import pytest
import numpy as np
import mindspore.nn as nn
from mindspore.common import set_seed
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.callback import Callback

from easydict import EasyDict as edict
from mindelec.solver import Solver
from mindelec.vision import MonitorTrain, MonitorEval

from src.dataset import create_dataset
from src.maxwell_model import S11Predictor
from src.loss import EvalMetric

set_seed(123456)
np.random.seed(123456)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

opt = edict({
'epochs': 100,
'print_interval': 50,
'batch_size': 8,
'save_epoch': 10,
'lr': 0.0001,
'input_dim': 3,
'device_num': 1,
'device_target': "Ascend",
'checkpoint_dir': './ckpt/',
'save_graphs_path': './graph_result/',
'input_path': './dataset/Butterfly_antenna/data_input.npy',
'label_path': './dataset/Butterfly_antenna/data_label.npy',
})


class TimeMonitor(Callback):
"""
Monitor the time in training.
"""

def __init__(self, data_size=None):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_time = time.time()
self.per_step_time = 0
self.t = 0

def epoch_begin(self, run_context):
"""
Record time at the begin of epoch.
"""
self.t = run_context
self.epoch_time = time.time()

def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
step_size = cb_params.batch_num

self.per_step_time = epoch_seconds / step_size

def get_step_time(self,):
return self.per_step_time


def get_lr(data):
"""
get_lr
"""
num_milestones = 10
if data['train_data_length'] % opt.batch_size == 0:
iter_number = int(data['train_data_length'] / opt.batch_size)
else:
iter_number = int(data['train_data_length'] / opt.batch_size) + 1
iter_number = opt.epochs * iter_number
milestones = [int(iter_number * i / num_milestones) for i in range(1, num_milestones)]
milestones.append(iter_number)
learning_rates = [opt.lr * 0.5 ** i for i in range(0, num_milestones - 1)]
learning_rates.append(opt.lr * 0.5 ** (num_milestones - 1))
return milestones, learning_rates


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameterization():
"""
test parameterization
"""
data, config_data = create_dataset(opt)

model_net = S11Predictor(opt.input_dim)
model_net.to_float(mstype.float16)

milestones, learning_rates = get_lr(data)

optim = nn.Adam(model_net.trainable_params(),
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))

eval_error_mrc = EvalMetric(scale_s11=config_data["scale_S11"],
length=data["eval_data_length"],
frequency=data["frequency"],
show_pic_number=4,
file_path='./eval_res')

solver = Solver(network=model_net,
mode="Data",
optimizer=optim,
metrics={'eval_mrc': eval_error_mrc},
loss_fn=nn.MSELoss())

monitor_train = MonitorTrain(per_print_times=1,
summary_dir='./summary_dir_train')

monitor_eval = MonitorEval(summary_dir='./summary_dir_eval',
model=solver,
eval_ds=data["eval_loader"],
eval_interval=opt.print_interval,
draw_flag=True)

time_monitor = TimeMonitor()
callbacks_train = [monitor_train, time_monitor, monitor_eval]

solver.model.train(epoch=opt.epochs,
train_dataset=data["train_loader"],
callbacks=callbacks_train,
dataset_sink_mode=True)

loss_print, l2_s11_print = monitor_eval.loss_final, monitor_eval.l2_s11_final
per_step_time = time_monitor.get_step_time()

print('loss_mse:', loss_print)
print('l2_s11:', l2_s11_print)
print('per_step_time:', per_step_time)

assert l2_s11_print <= 1.0
assert per_step_time <= 10.0

+ 26
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/config.py View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""

# config
config = {
'input_channels': 496,
'epochs': 20,
'batch_size': 8,
'lr': 0.0001,
'lr_decay_milestones': 10,
}

+ 79
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/dataset.py View File

@@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset utilities"""

import os
import numpy as np

from mindelec.data import Dataset, ExistedDataConfig

INPUT_PATH = ""
LABEL_PATH = ""
DATA_CONFIG_PATH = "./data_config.npz"
SAVE_DATA_PATH = "./"

def custom_normalize(dataset, mean=None, std=None):
""" custom normalization """
ori_shape = dataset.shape
dataset = dataset.reshape(ori_shape[0], -1)
dataset = np.transpose(dataset)
if mean is None:
mean = np.mean(dataset, axis=1)
dataset = dataset - mean[:, None]
if std is None:
std = np.std(dataset, axis=1)
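# lift near-zero stds to 1.0 so constant features divide safely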
std += (np.abs(std) < 0.0000001)
dataset = dataset / std[:, None]
dataset = np.transpose(dataset)
dataset = dataset.reshape(ori_shape)
return dataset, mean, std

def generate_data(input_path, label_path):
"""generate dataset for s11 parameter prediction"""

data_input = np.load(input_path)
mean, std = None, None
if os.path.exists(DATA_CONFIG_PATH):
data_config = np.load(DATA_CONFIG_PATH)
mean = data_config["mean"]
std = data_config["std"]
# reuse saved statistics when they exist so train/eval normalization stays consistent
data_input, mean, std = custom_normalize(data_input, mean, std)
data_label = np.load(label_path)

print(data_input.shape)
print(data_label.shape)

data_input = data_input.transpose((0, 4, 1, 2, 3))
data_label[:, :] = np.log10(-data_label[:, :] + 1.0)
scale_s11 = 0.5 * np.max(np.abs(data_label[:, :]))
data_label[:, :] = data_label[:, :] / scale_s11

np.savez(DATA_CONFIG_PATH, scale_s11=scale_s11, mean=mean, std=std)
np.save(os.path.join(SAVE_DATA_PATH, 'data_input.npy'), data_input)
np.save(os.path.join(SAVE_DATA_PATH, 'data_label.npy'), data_label)
print("data saved in target path")


def create_dataset(input_path, label_path, batch_size=8, shuffle=True):
electromagnetic = ExistedDataConfig(name="electromagnetic",
data_dir=[input_path, label_path],
columns_list=["inputs", "label"],
data_format=input_path.split('.')[-1])
dataset = Dataset(existed_data_list=[electromagnetic])
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle)

return data_loader

if __name__ == "__main__":
generate_data(input_path=INPUT_PATH, label_path=LABEL_PATH)
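
The label pipeline compresses S11 (negative dB values) into a log domain, and src/metric.py inverts it with 1 - 10**x. A round-trip check with hypothetical dB values:

s11_db = np.array([[-20.0, -3.0]], np.float32)        # hypothetical S11 magnitudes in dB
y = np.log10(-s11_db + 1.0)                           # forward transform from generate_data
scale_s11 = 0.5 * np.max(np.abs(y))
y_scaled = y / scale_s11
recovered = 1.0 - np.power(10, y_scaled * scale_s11)  # inverse applied at evaluation time
print(np.allclose(recovered, s11_db))                 # True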

+ 29
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/lr_generator.py View File

@@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""

def step_lr_generator(step_size, epochs, lr, lr_decay_milestones):
"""generate step decayed learning rate"""

total_steps = epochs * step_size

milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)]
milestones.append(total_steps)
learning_rates = [lr*0.5**i for i in range(0, lr_decay_milestones - 1)]
learning_rates.append(lr*0.5**(lr_decay_milestones - 1))

print("total_steps: %s, milestones: %s, learning_rates: %s " %(total_steps, milestones, learning_rates))

return milestones, learning_rates
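
A worked call, taking epochs, lr, and lr_decay_milestones from src/config.py and a hypothetical step_size of 100: 2000 total steps, a milestone every 200 steps, and the learning rate halved at each one.

milestones, learning_rates = step_lr_generator(step_size=100, epochs=20,
                                               lr=0.0001, lr_decay_milestones=10)
# milestones == [200, 400, ..., 2000]; learning_rates == [1e-4, 5e-5, ..., 1e-4 * 0.5**9]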

+ 108
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/metric.py View File

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics"""

import os
import shutil
import mindspore.nn as nn
from mindspore.ops import functional as F
import matplotlib.pyplot as plt
import numpy as np
import cv2

class MyMSELoss(nn.LossBase):
"""mse loss function"""
def construct(self, base, target):
x = F.square(base - target)
return self.get_loss(x)


class EvalMetric(nn.Metric):
"""
eval metric
"""

def __init__(self, scale_s11, length, frequency, show_pic_number, file_path):
super(EvalMetric, self).__init__()
self.clear()
self.scale_s11 = scale_s11
self.length = length
self.frequency = frequency
self.show_pic_number = show_pic_number
self.file_path = file_path
self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False)

if not os.path.exists(self.file_path):
os.mkdir(self.file_path)
else:
shutil.rmtree(self.file_path)
os.mkdir(self.file_path)

def clear(self):
"""
clear error
"""
self.error_sum_l2_error = 0
self.error_sum_loss_error = 0
self.pic_res = None
self.index = 0

def update(self, *inputs):
"""
update error
"""

y_pred = self._convert_data(inputs[0])
y_label = self._convert_data(inputs[1])

test_predict, test_label = y_pred, y_label
test_predict[:, :] = test_predict[:, :] * self.scale_s11
test_label[:, :] = test_label[:, :] * self.scale_s11
test_predict[:, :] = 1.0 - np.power(10, test_predict[:, :])
test_label[:, :] = 1.0 - np.power(10, test_label[:, :])
self.pic_res = []

for i in range(len(test_label)):
self.index += 1
predict_real_temp = test_predict[i]
label_real_temp = test_label[i]
l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \
np.sqrt(np.sum(np.square(label_real_temp)))
self.error_sum_l2_error += l2_error_temp
self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2)

s11_label, s11_predict = label_real_temp, predict_real_temp
plt.figure(dpi=250)
plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2)
plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1)
plt.title('s11(dB)')
plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10])
plt.ylabel('dB')
plt.legend()
plt.savefig(self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg')
plt.close()
# use the global sample counter so the filename read back matches the one just saved
if (self.index - 1) in self.show_pic_id:
self.pic_res.append(cv2.imread(
self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg'))

self.pic_res = np.array(self.pic_res).astype(np.float32)

def eval(self):
"""
compute final error
"""
return {'l2_error': self.error_sum_l2_error / self.length,
'loss_error': self.error_sum_loss_error / self.length,
'pic_res': self.pic_res}

+ 89
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/model.py View File

@@ -0,0 +1,89 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""S parameter prediction model"""

import mindspore.nn as nn
import mindspore.ops as ops


class S11Predictor(nn.Cell):
"""
S11Predictor architecture for MindElec.

Args:
input_dim (int): input channel.

Returns:
Tensor, output tensor.

Examples:
>>> S11Predictor(input_dim=128)
"""

def __init__(self, input_dim):
super(S11Predictor, self).__init__()
# input shape is [20, 40, 3]
self.conv1 = nn.Conv3d(input_dim, 512, kernel_size=(3, 3, 1))
self.conv2 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1))
self.conv3 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1))
self.conv4 = nn.Conv3d(512, 512, kernel_size=(2, 1, 3), pad_mode='pad', padding=0)

self.down1 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))
self.down2 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))
self.down3 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))

self.down_1_1 = ops.MaxPool3D(kernel_size=(1, 13, 1), strides=(1, 13, 1))
self.down_1_2 = nn.MaxPool2d(kernel_size=(10, 3))

self.down_2 = nn.MaxPool2d((5, 4*3))

self.fc1 = nn.Dense(1536, 2048)
self.fc2 = nn.Dense(2048, 2048)
self.fc3 = nn.Dense(2048, 1001)

self.concat = ops.Concat(axis=1)
self.relu = nn.ReLU()


def construct(self, x):
"""forward"""
bs = x.shape[0]

x = self.conv1(x)
x = self.relu(x)
x = self.down1(x)
x_1 = self.down_1_1(x)
x_1 = self.down_1_2(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1)).view((bs, -1))

x = self.conv2(x)
x = self.relu(x)
x = self.down2(x)
x_2 = self.down_2(x.view(bs, x.shape[1], x.shape[2], -1)).view((bs, -1))

x = self.conv3(x)
x = self.relu(x)
x = self.down3(x)

x = self.conv4(x)
x = self.relu(x).view((bs, -1))

x = self.concat([x, x_1, x_2])
x = self.relu(x).view(bs, -1)

x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)

return x
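
A shape-level smoke test (a sketch; Conv3d and MaxPool3D here typically require an Ascend or GPU context, matching the tests below): the input layout is (batch, channels, 20, 40, 3) and the head maps onto 1001 frequency points.

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

net = S11Predictor(input_dim=496)
x = Tensor(np.ones((2, 496, 20, 40, 3)), mstype.float32)
print(net(x).shape)  # expected: (2, 1001)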

+ 144
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/test_s_parameter.py View File

@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""s parameter prediction model training test"""
import time
import pytest
import numpy as np

import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
import mindspore.common.initializer as weight_init
from mindspore.train.callback import LossMonitor, Callback

from mindelec.solver import Solver

from src.dataset import create_dataset
from src.model import S11Predictor
from src.lr_generator import step_lr_generator
from src.config import config
from src.metric import MyMSELoss, EvalMetric

set_seed(0)
np.random.seed(0)

context.set_context(mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")


class TimeMonitor(Callback):
"""
Monitor the time in training.
"""

def __init__(self, data_size=None):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_time = time.time()
self.per_step_time = 0
self._tmp = None

def epoch_begin(self, run_context):
"""
Record time at the begin of epoch.
"""
self.epoch_time = time.time()
self._tmp = run_context

def epoch_end(self, run_context):
"""
Print process cost time at the end of epoch.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
step_size = cb_params.batch_num

self.per_step_time = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)

def get_step_time(self,):
return self.per_step_time

def init_weight(net):
"""init_weight"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv3d, nn.Dense)):
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
cell.weight.shape,
cell.weight.dtype))

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_s_predictor_train():
"""training"""
train_input_path = "input.npy"
train_label_path = "label.npy"
train_input = np.ones((100, 496, 20, 40, 3), np.float32)
train_label = np.ones((100, 1001), np.float32)
np.save(train_input_path, train_input)
np.save(train_label_path, train_label)

model_net = S11Predictor(input_dim=config["input_channels"])
init_weight(net=model_net)

train_dataset = create_dataset(input_path=train_input_path,
label_path=train_label_path,
batch_size=config["batch_size"],
shuffle=True)

step_size = train_dataset.get_dataset_size()
milestones, learning_rates = step_lr_generator(step_size,
config["epochs"],
config["lr"],
config["lr_decay_milestones"])
optimizer = nn.Adam(model_net.trainable_params(),
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))

loss_net = MyMSELoss()

eval_step_size = train_dataset.get_dataset_size() * config["batch_size"]
evl_error_mrc = EvalMetric(scale_s11=1,
length=eval_step_size,
frequency=np.linspace(0, 4*10**8, 1001),
show_pic_number=4,
file_path='./eval_result')

solver = Solver(model_net,
train_input_map={'train': ['train_input_data']},
test_input_map={'test': ['test_input_data']},
optimizer=optimizer,
metrics={'evl_mrc': evl_error_mrc,},
amp_level="O2",
loss_fn=loss_net)

time_cb = TimeMonitor()
solver.model.train(config["epochs"],
train_dataset,
callbacks=[LossMonitor(), time_cb],
dataset_sink_mode=False)
res = solver.model.eval(train_dataset, dataset_sink_mode=False)
per_step_time = time_cb.get_step_time()
l2_error = res['evl_mrc']['l2_error']
print('test_res:', f'l2_error: {l2_error:.10f} ')
print(f'per step time: {per_step_time:.10f} ')
assert l2_error <= 0.01
assert per_step_time <= 100

+ 37
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/config.json View File

@@ -0,0 +1,37 @@
{
"Description" : [ "PINNs for solve Maxwell's equations" ],

"Case" : "2D_Mur_Src_Gauss_Mscale_MTL",
"random_sampling" : true,
"coord_min" : [0, 0],
"coord_max" : [1, 1],
"src_pos" : [0.4975, 0.4975],
"src_frq": 1e+9,
"range_t" : 4e-9,
"input_scale": [1.0, 1.0, 2.5e+8],
"output_scale": [37.67303, 37.67303, 0.1],
"src_radius": 0.01,
"input_size" : 3,
"output_size" : 3,
"residual" : true,
"num_scales" : 4,
"layers" : 7,
"neurons" : 64,
"amp_factor" : 10.0,
"scale_factor" : 2.0,
"save_ckpt" : true,
"load_ckpt" : false,
"save_ckpt_path" : "./ckpt",
"load_ckpt_path" : "",
"train_data_path" : "./input/",
"test_data_path" : "./input/benchmark/",
"lr" : 0.002,
"milestones" : [300],
"lr_gamma" : 0.1,
"train_epoch" : 300,
"train_batch_size" : 8192,
"test_batch_size" : 32768,
"predict_interval" : 6000,
"vision_path" : "./vision",
"summary_path" : "./summary"
}

+ 32
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/__init__.py View File

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
init
"""
from .dataset import create_train_dataset, get_test_data, create_random_dataset
from .maxwell import Maxwell2DMur
from .lr_scheduler import MultiStepLR
from .callback import PredictCallback
from .utils import visual_result

__all__ = [
"create_train_dataset",
"create_random_dataset",
"get_test_data",
"Maxwell2DMur",
"MultiStepLR",
"PredictCallback",
"visual_result"
]

+ 127
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/callback.py View File

@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
call back functions
"""
import time
import copy

import numpy as np
from mindspore.train.callback import Callback
from mindspore.train.summary import SummaryRecord
from mindspore import Tensor
import mindspore.common.dtype as mstype

class PredictCallback(Callback):
"""
Monitor the prediction accuracy in training.

Args:
model (Cell): Prediction network cell.
inputs (Array): Input data of prediction.
label (Array): Label data of prediction.
config (dict): config info of prediction.
visual_fn (dict): Visualization function. Default: None.
"""
def __init__(self, model, inputs, label, config, visual_fn=None):
super(PredictCallback, self).__init__()
self.model = model
self.inputs = inputs
self.label = label
self.label_shape = label.shape
self.visual_fn = visual_fn
self.vision_path = config.get("vision_path", "./vision")
self.summary_dir = config.get("summary_path", "./summary")

self.output_size = config.get("output_size", 3)
self.input_size = config.get("input_size", 3)
self.output_scale = np.array(config["output_scale"], dtype=np.float32)
self.predict_interval = config.get("predict_interval", 10)
self.batch_size = config.get("test_batch_size", 8192*4)

self.dx = inputs[0, 1, 0, 0] - inputs[0, 0, 0, 0]
self.dy = inputs[0, 0, 1, 1] - inputs[0, 0, 0, 1]
self.dt = inputs[1, 0, 0, 2] - inputs[0, 0, 0, 2]
print("check yee delta: {}, {}, {}".format(self.dx, self.dy, self.dt))

self.ex_inputs = copy.deepcopy(inputs)
self.ey_inputs = copy.deepcopy(inputs)
self.hz_inputs = copy.deepcopy(inputs)
self.ex_inputs = self.ex_inputs.reshape(-1, self.input_size)
self.ey_inputs = self.ey_inputs.reshape(-1, self.input_size)
self.hz_inputs = self.hz_inputs.reshape(-1, self.input_size)
self.ex_inputs[:, 1] += self.dy / 2.0
self.ex_inputs[:, 2] += self.dt / 2.0
self.ey_inputs[:, 0] += self.dx / 2.0
self.ey_inputs[:, 2] += self.dt / 2.0
self.inputs_each = [self.ex_inputs, self.ey_inputs, self.hz_inputs]
self._step_counter = 0
self.l2_error = (1.0, 1.0, 1.0)

def __enter__(self):
self.summary_record = SummaryRecord(self.summary_dir)
return self

def __exit__(self, *exc_args):
self.summary_record.close()

def epoch_end(self, run_context):
"""
Evaluate the model at the end of epoch.

Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_epoch_num % self.predict_interval == 0:
# predict each quantity
index = 0
prediction_each = np.zeros(self.label_shape)
prediction_each = prediction_each.reshape((-1, self.output_size))
time_beg = time.time()
while index < len(self.inputs_each[0]):
index_end = min(index + self.batch_size, len(self.inputs_each[0]))
for i in range(self.output_size):
test_batch = Tensor(self.inputs_each[i][index: index_end, :], mstype.float32)
predict = self.model(test_batch)
predict = predict.asnumpy()
prediction_each[index: index_end, i] = predict[:, i] * self.output_scale[i]
index = index_end
print("==================================================================================================")
print("predict total time: {} s".format(time.time() - time_beg))
prediction = prediction_each.reshape(self.label_shape)
if self.visual_fn is not None:
self.visual_fn(self.inputs, self.label, prediction, path=self.vision_path,
name="epoch" + str(cb_params.cur_epoch_num))

label = self.label.reshape((-1, self.output_size))
prediction = prediction.reshape((-1, self.output_size))
self.l2_error = self._calculate_error(label, prediction)

def _calculate_error(self, label, prediction):
"""calculate l2-error to evaluate accuracy"""
self._step_counter += 1
error = label - prediction
l2_error_ex = np.sqrt(np.sum(np.square(error[..., 0]))) / np.sqrt(np.sum(np.square(label[..., 0])))
l2_error_ey = np.sqrt(np.sum(np.square(error[..., 1]))) / np.sqrt(np.sum(np.square(label[..., 1])))
l2_error_hz = np.sqrt(np.sum(np.square(error[..., 2]))) / np.sqrt(np.sum(np.square(label[..., 2])))
print("l2_error, Ex: ", l2_error_ex, ", Ey: ", l2_error_ey, ", Hz: ", l2_error_hz)
self.summary_record.add_value('scalar', 'l2_ex', Tensor(l2_error_ex))
self.summary_record.add_value('scalar', 'l2_ey', Tensor(l2_error_ey))
self.summary_record.add_value('scalar', 'l2_hz', Tensor(l2_error_hz))
return l2_error_ex, l2_error_ey, l2_error_hz

def get_l2_error(self):
return self.l2_error
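
The dx/dy/dt probes in __init__ assume inputs laid out as a dense (t, x, y) meshgrid with channels ordered (x, y, t); a minimal construction consistent with that indexing (grid sizes hypothetical):

import numpy as np

x = np.linspace(0.0, 1.0, 5)
y = np.linspace(0.0, 1.0, 4)
t = np.linspace(0.0, 4e-9, 3)
tt, xx, yy = np.meshgrid(t, x, y, indexing='ij')
inputs = np.stack([xx, yy, tt], axis=-1)          # shape (3, 5, 4, 3)
print(inputs[0, 1, 0, 0] - inputs[0, 0, 0, 0])    # dx = 0.25
print(inputs[0, 0, 1, 1] - inputs[0, 0, 0, 1])    # dy ~= 0.333
print(inputs[1, 0, 0, 2] - inputs[0, 0, 0, 2])    # dt = 2e-9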

+ 95
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/dataset.py View File

@@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create dataset
"""
import numpy as np

from mindelec.data import Dataset, ExistedDataConfig
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
from mindelec.geometry import create_config_from_edict

from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config

def get_test_data(test_data_path):
"""load labeled data for evaluation"""
# check data
paths = [test_data_path + '/input.npy', test_data_path + '/output.npy']
inputs = np.load(paths[0])
label = np.load(paths[1])
return inputs, label

def create_train_dataset(train_data_path):
"""create training dataset from existed data files"""
src_ic = ExistedDataConfig(name="src_ic",
data_dir=[train_data_path + "elec_src_ic.npy"],
columns_list=["points"],
data_format="npy",
constraint_type="IC",
random_merge=False)
boundary = ExistedDataConfig(name="boundary",
data_dir=[train_data_path + "elec_no_src_bc.npy"],
columns_list=["points"],
data_format="npy",
constraint_type="BC",
random_merge=True)
no_src_ic = ExistedDataConfig(name="no_src_ic",
data_dir=[train_data_path + "elec_no_src_ic.npy"],
columns_list=["points"],
data_format="npy",
constraint_type="IC",
random_merge=True)
src_domain = ExistedDataConfig(name="src_domain",
data_dir=[train_data_path + "elec_src_domain.npy"],
columns_list=["points"],
data_format="npy",
constraint_type="Equation",
random_merge=True)
no_src_domain = ExistedDataConfig(name="no_src_domain",
data_dir=[train_data_path + "elec_no_src_domain.npy"],
columns_list=["points"],
data_format="npy",
constraint_type="Equation",
random_merge=True)
dataset = Dataset(existed_data_list=[no_src_domain, no_src_ic, boundary, src_domain, src_ic])
return dataset

def create_random_dataset(config):
"""create training dataset by online sampling"""
disk_radius = config["src_radius"]
disk_origin = config["src_pos"]
coord_min = config["coord_min"]
coord_max = config["coord_max"]

disk = Disk("src", disk_origin, disk_radius)
rectangle = Rectangle("rect", coord_min, coord_max)
diff = rectangle - disk
time_interval = TimeDomain("time", 0.0, config["range_t"])
no_src_region = GeometryWithTime(diff, time_interval)
no_src_region.set_name("no_src")
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
src_region = GeometryWithTime(disk, time_interval)
src_region.set_name("src")
src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
boundary = GeometryWithTime(rectangle, time_interval)
boundary.set_name("bc")
boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))

geom_dict = {src_region: ["domain", "IC"],
no_src_region: ["domain", "IC"],
boundary: ["BC"]}

dataset = Dataset(geom_dict)
return dataset

+ 73
- 0
reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/lr_scheduler.py View File

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Learning rate scheduler."""
from collections import Counter
import numpy as np

class _LRScheduler():
"""
Basic class for learning rate scheduler
"""

def __init__(self, lr, max_epoch, steps_per_epoch):
self.base_lr = lr
self.steps_per_epoch = steps_per_epoch
self.total_steps = int(max_epoch * steps_per_epoch)

def get_lr(self):
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError


class MultiStepLR(_LRScheduler):
"""
Multi-step learning rate scheduler

Decays the learning rate by gamma once the number of epoch reaches one of the milestones.

Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.

Outputs:
numpy.ndarray, shape=(steps_per_epoch*max_epoch,)

Example:
>>> # Assuming the base learning rate is 0.1
>>> # lr = 0.1 if epoch < 30
>>> # lr = 0.01 if 30 <= epoch < 80
>>> # lr = 0.001 if epoch >= 80
>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""

def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
self.milestones = Counter(milestones)
self.gamma = gamma
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)

def get_lr(self):
lr_each_step = []
current_lr = self.base_lr
for i in range(self.total_steps):
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
current_lr = current_lr * self.gamma
lr = current_lr
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
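
A small runnable check of the schedule (values are my own): with three steps per epoch and a single milestone at epoch 2, the first six steps stay at the base lr and the remaining six are decayed once.

if __name__ == "__main__":
    scheduler = MultiStepLR(lr=0.1, milestones=[2], gamma=0.1, steps_per_epoch=3, max_epoch=4)
    print(scheduler.get_lr())  # [0.1 0.1 0.1 0.1 0.1 0.1 0.01 0.01 0.01 0.01 0.01 0.01]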

Some files were not shown because too many files changed in this diff
