diff --git a/reproduce/AlphaFold2-Chinese/.gitignore b/reproduce/AlphaFold2-Chinese/.gitignore
new file mode 100644
index 0000000..1eaa6c5
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/.gitignore
@@ -0,0 +1,21 @@
+/tmp/
+*.coverage
+/.idea
+/.vscode
+.vscode
+__pycache__/
+*.pyc
+*.so
+*.so.*
+*.o
+*.out
+*.gch
+build
+*.egg-info
+dist
+version.py
+local_script/
+output
+.ipynb_checkpoints
+somas_meta
+analyze_fail.dat
\ No newline at end of file
diff --git a/reproduce/AlphaFold2-Chinese/AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb b/reproduce/AlphaFold2-Chinese/AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb
new file mode 100644
index 0000000..d8cea6b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb
@@ -0,0 +1,479 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+        ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G4yBrceuFbf3"
+ },
+ "source": [
+        "\n",
+ "\n",
+        "## AlphaFold2_CN: AlphaFold2 with MMseqs2\n",
+ "\n",
+        "A simple Chinese-language quick-start guide to protein structure prediction, based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignment is powered by [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kOblAo-xetgx",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Input protein sequence (defaults to the test sequence from the video)\n",
+ "from google.colab import files\n",
+ "import os.path\n",
+ "import re\n",
+ "import hashlib\n",
+ "import random\n",
+ "\n",
+ "def add_hash(x,y):\n",
+ " return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
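+        "# e.g. add_hash('test', seq) returns 'test_' plus the first 5 hex chars of\n",
+        "# sha1(seq), e.g. 'test_3f2a1' (hash value illustrative)\n",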
+ "\n",
+ "query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
+        "#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
+ "\n",
+ "# remove whitespaces\n",
+ "query_sequence = \"\".join(query_sequence.split())\n",
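+        "# for a complex, join the chains with ':' in the input above, e.g. a\n",
+        "# (hypothetical) homodimer: 'MAAHKG...APKPH:MAAHKG...APKPH'\n",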
+ "\n",
+ "jobname = 'test' #@param {type:\"string\"}\n",
+ "# remove whitespaces\n",
+ "basejobname = \"\".join(jobname.split())\n",
+ "basejobname = re.sub(r'\\W+', '', basejobname)\n",
+ "jobname = add_hash(basejobname, query_sequence)\n",
+ "while os.path.isfile(f\"{jobname}.csv\"):\n",
+ " jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
+ "\n",
+ "with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
+ " text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
+ "\n",
+ "queries_path=f\"{jobname}.csv\"\n",
+ "\n",
+ "# number of models to use\n",
+ "use_amber = False #@param {type:\"boolean\"}\n",
+ "template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
+ "#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
+ "\n",
+ "if template_mode == \"pdb70\":\n",
+ " use_templates = True\n",
+ " custom_template_path = None\n",
+ "elif template_mode == \"custom\":\n",
+ " custom_template_path = f\"{jobname}_template\"\n",
+ " os.mkdir(custom_template_path)\n",
+ " uploaded = files.upload()\n",
+ " use_templates = True\n",
+ " for fn in uploaded.keys():\n",
+ " os.rename(fn, f\"{jobname}_template/{fn}\")\n",
+ "else:\n",
+ " custom_template_path = None\n",
+ " use_templates = False\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n",
+ "msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
+ "pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
+        "#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = separate MSA for each chain, \"paired\" = only use paired sequences.\n",
+ "\n",
+ "# decide which a3m to use\n",
+ "if msa_mode.startswith(\"MMseqs2\"):\n",
+ " a3m_file = f\"{jobname}.a3m\"\n",
+ "elif msa_mode == \"custom\":\n",
+ " a3m_file = f\"{jobname}.custom.a3m\"\n",
+ " if not os.path.isfile(a3m_file):\n",
+ " custom_msa_dict = files.upload()\n",
+ " custom_msa = list(custom_msa_dict.keys())[0]\n",
+ " header = 0\n",
+ " import fileinput\n",
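+        "        # fileinput(..., inplace=1) redirects print() back into the file, so\n",
+        "        # this loop rewrites the uploaded MSA in place: blank lines are dropped\n",
+        "        # and the first non-header line is captured as the query sequence\n",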
+ " for line in fileinput.FileInput(custom_msa,inplace=1):\n",
+ " if line.startswith(\">\"):\n",
+ " header = header + 1\n",
+ " if not line.rstrip():\n",
+ " continue\n",
+ " if line.startswith(\">\") == False and header == 1:\n",
+ " query_sequence = line.rstrip()\n",
+ " print(line, end='')\n",
+ "\n",
+ " os.rename(custom_msa, a3m_file)\n",
+ " queries_path=a3m_file\n",
+ " print(f\"moving {custom_msa} to {a3m_file}\")\n",
+ "else:\n",
+ " a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
+ " with open(a3m_file, \"w\") as text_file:\n",
+ " text_file.write(\">1\\n%s\" % query_sequence)"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "C2_sh2uAonJH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "#@markdown ### Parameter settings\n",
+ "model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
+        "#@markdown - \"auto\" = use \"AlphaFold2-ptm\" for single-chain prediction and \"AlphaFold2-multimer-v2\" for complexes. For complexes, \"AlphaFold2-multimer-v[1,2]\" or \"AlphaFold2-ptm\" can also be selected explicitly.\n",
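+        "# the \"auto\" rule above is resolved later by colabfold.batch.set_model_type:\n",
+        "# monomer -> \"AlphaFold2-ptm\", complex -> \"AlphaFold2-multimer-v2\"\n",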
+ "num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
+ "save_to_google_drive = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
+ "dpi = 200 #@param {type:\"integer\"}\n",
+ "#@markdown - set dpi for image resolution\n",
+ "\n",
+ "#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
+ "\n",
+ "\n",
+ "if save_to_google_drive:\n",
+ " from pydrive.drive import GoogleDrive\n",
+ " from pydrive.auth import GoogleAuth\n",
+ " from google.colab import auth\n",
+ " from oauth2client.client import GoogleCredentials\n",
+ " auth.authenticate_user()\n",
+ " gauth = GoogleAuth()\n",
+ " gauth.credentials = GoogleCredentials.get_application_default()\n",
+ " drive = GoogleDrive(gauth)\n",
+ " print(\"You are logged into Google Drive and are good to go!\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "ADDuaolKmjGW"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iccGdbe_Pmt9",
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Install dependencies\n",
+ "%%bash -s $use_amber $use_templates\n",
+ "\n",
+ "set -e\n",
+ "\n",
+ "USE_AMBER=$1\n",
+ "USE_TEMPLATES=$2\n",
+ "\n",
+ "if [ ! -f COLABFOLD_READY ]; then\n",
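+        "  # the *_READY marker files below make each install step idempotent,\n",
+        "  # so re-running this cell skips work that already completed\n",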
+ " # install dependencies\n",
+ " # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
+ " pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
+ " # high risk high gain\n",
+ " pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
+ " touch COLABFOLD_READY\n",
+ "fi\n",
+ "\n",
+ "# setup conda\n",
+ "if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
+ " if [ ! -f CONDA_READY ]; then\n",
+ " wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
+ " bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
+ " rm Miniconda3-latest-Linux-x86_64.sh\n",
+ " touch CONDA_READY\n",
+ " fi\n",
+ "fi\n",
+ "# setup template search\n",
+ "if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
+ " conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
+ " touch HH_READY\n",
+ "fi\n",
+ "# setup openmm for amber refinement\n",
+ "if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
+ " conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
+ " touch AMBER_READY\n",
+ "fi"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_sztQyz29DIC",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Run prediction\n",
+ "\n",
+ "import sys\n",
+ "\n",
+ "from colabfold.download import download_alphafold_params, default_data_dir\n",
+ "from colabfold.utils import setup_logging\n",
+ "from colabfold.batch import get_queries, run, set_model_type\n",
+ "K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
+ "if \"1\" in K80_chk:\n",
+ " print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
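+        "    # drop the JAX/TF unified-memory settings if present; they do not work\n",
+        "    # reliably on the K80, hence the reduced length limit above\n",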
+ " if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
+ " del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
+ " if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
+ " del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
+ "\n",
+ "from colabfold.colabfold import plot_protein\n",
+ "from pathlib import Path\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "# For some reason we need that to get pdbfixer to import\n",
+ "if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
+ " sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
+ "\n",
+ "def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
+ " fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
+ " plt.show()\n",
+ " plt.close()\n",
+ "\n",
+ "result_dir=\".\"\n",
+ "setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
+ "queries, is_complex = get_queries(queries_path)\n",
+ "model_type = set_model_type(is_complex, model_type)\n",
+ "download_alphafold_params(model_type, Path(\".\"))\n",
+ "run(\n",
+ " queries=queries,\n",
+ " result_dir=result_dir,\n",
+ " use_templates=use_templates,\n",
+ " custom_template_path=custom_template_path,\n",
+ " use_amber=use_amber,\n",
+ " msa_mode=msa_mode, \n",
+ " model_type=model_type,\n",
+ " num_models=5,\n",
+ " num_recycles=num_recycles,\n",
+ " model_order=[1, 2, 3, 4, 5],\n",
+ " is_complex=is_complex,\n",
+ " data_dir=Path(\".\"),\n",
+ " keep_existing_results=False,\n",
+ " recompile_padding=1.0,\n",
+ " rank_by=\"auto\",\n",
+ " pair_mode=pair_mode,\n",
+ " stop_at_score=float(100),\n",
+ " prediction_callback=prediction_callback,\n",
+ " dpi=dpi\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KK7X9T44pWb7",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Display 3D structure {run: \"auto\"}\n",
+ "import py3Dmol\n",
+ "import glob\n",
+ "import matplotlib.pyplot as plt\n",
+ "from colabfold.colabfold import plot_plddt_legend\n",
+ "rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
+ "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
+ "show_sidechains = False #@param {type:\"boolean\"}\n",
+ "show_mainchains = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
+ "if use_amber:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "else:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "\n",
+ "pdb_file = glob.glob(pdb_filename)\n",
+ "\n",
+ "def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
+ " model_name = f\"rank_{rank_num}\"\n",
+ " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
+ " view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
+ "\n",
+ " if color == \"lDDT\":\n",
+ " view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
+ " elif color == \"rainbow\":\n",
+ " view.setStyle({'cartoon': {'color':'spectrum'}})\n",
+ " elif color == \"chain\":\n",
+ " chains = len(queries[0][1]) + 1 if is_complex else 1\n",
+ " for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
+ " [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
+ " view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
+ " if show_sidechains:\n",
+ " BB = ['C','O','N']\n",
+ " view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
+ " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
+ " if show_mainchains:\n",
+ " BB = ['C','O','N','CA']\n",
+ " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ "\n",
+ " view.zoomTo()\n",
+ " return view\n",
+ "\n",
+ "\n",
+ "show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
+ "if color == \"lDDT\":\n",
+ " plot_plddt_legend().show() "
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "11l8k--10q0C",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Plots {run: \"auto\"}\n",
+ "from IPython.display import display, HTML\n",
+ "import base64\n",
+ "from html import escape\n",
+ "\n",
+ "# see: https://stackoverflow.com/a/53688522\n",
+ "def image_to_data_url(filename):\n",
+ " ext = filename.split('.')[-1]\n",
+ " prefix = f'data:image/{ext};base64,'\n",
+ " with open(filename, 'rb') as f:\n",
+ " img = f.read()\n",
+ " return prefix + base64.b64encode(img).decode('utf-8')\n",
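+        "# e.g. image_to_data_url('plot.png') -> 'data:image/png;base64,iVBORw0KG...'\n",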
+ "\n",
+ "pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
+ "cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
+ "plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
+ "display(HTML(f\"\"\"\n",
+ "\n",
+        "<div>\n",
+        "<h1>Plots for {escape(jobname)}</h1>\n",
+        "<img src=\"{pae}\" />\n",
+        "<img src=\"{cov}\" />\n",
+        "<img src=\"{plddt}\" />\n",
+        "</div>\n",
+ "\"\"\"))\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "33g5IIegij5R",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Download results\n",
+ "#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
+ "\n",
+ "if msa_mode == \"custom\":\n",
+ " print(\"Don't forget to cite your custom MSA generation method.\")\n",
+ "\n",
+ "!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
+ "files.download(f\"{jobname}.result.zip\")\n",
+ "\n",
+        "if save_to_google_drive and drive:\n",
+ " uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
+ " uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
+ " uploaded.Upload()\n",
+ " print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UGUBLzB3C6WN",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+        "# Instructions\n",
+        "**Quick start**\n",
+        "1. Paste the amino-acid sequence you want to predict into the input field.\n",
+        "2. Click \"Runtime\" -> \"Run all\".\n",
+        "3. The pipeline currently runs 5 models and finally renders a 3D structure view.\n",
+        "\n",
+        "**Output files**\n",
+        "\n",
+        "1. Model structures in PDB format.\n",
+        "2. Model quality plots.\n",
+        "3. MSA coverage plots.\n",
+        "4. Miscellaneous outputs.\n",
+ "\n",
+ "**Acknowledgments**\n",
+ "- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
+ "\n",
+ "- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
+ "\n",
+        "- Richard Evans for helping to benchmark ColabFold's AlphaFold-multimer support\n",
+ "\n",
+        "- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without which these notebooks would be quite boring!\n",
+ "\n",
+ "- Do-Yoon Kim for creating the ColabFold logo.\n",
+ "\n",
+ "- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb b/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb
new file mode 100644
index 0000000..d8cea6b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb
@@ -0,0 +1,479 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+        ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G4yBrceuFbf3"
+ },
+ "source": [
+        "\n",
+ "\n",
+        "## AlphaFold2_CN: AlphaFold2 with MMseqs2\n",
+ "\n",
+        "A simple Chinese-language quick-start guide to protein structure prediction, based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignment is powered by [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kOblAo-xetgx",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Input protein sequence (defaults to the test sequence from the video)\n",
+ "from google.colab import files\n",
+ "import os.path\n",
+ "import re\n",
+ "import hashlib\n",
+ "import random\n",
+ "\n",
+ "def add_hash(x,y):\n",
+ " return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
+ "\n",
+ "query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
+        "#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
+ "\n",
+ "# remove whitespaces\n",
+ "query_sequence = \"\".join(query_sequence.split())\n",
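+        "# for a complex, join the chains with ':' in the input above, e.g. a\n",
+        "# (hypothetical) homodimer: 'MAAHKG...APKPH:MAAHKG...APKPH'\n",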
+ "\n",
+ "jobname = 'test' #@param {type:\"string\"}\n",
+ "# remove whitespaces\n",
+ "basejobname = \"\".join(jobname.split())\n",
+ "basejobname = re.sub(r'\\W+', '', basejobname)\n",
+ "jobname = add_hash(basejobname, query_sequence)\n",
+ "while os.path.isfile(f\"{jobname}.csv\"):\n",
+ " jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
+ "\n",
+ "with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
+ " text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
+ "\n",
+ "queries_path=f\"{jobname}.csv\"\n",
+ "\n",
+ "# number of models to use\n",
+ "use_amber = False #@param {type:\"boolean\"}\n",
+ "template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
+ "#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
+ "\n",
+ "if template_mode == \"pdb70\":\n",
+ " use_templates = True\n",
+ " custom_template_path = None\n",
+ "elif template_mode == \"custom\":\n",
+ " custom_template_path = f\"{jobname}_template\"\n",
+ " os.mkdir(custom_template_path)\n",
+ " uploaded = files.upload()\n",
+ " use_templates = True\n",
+ " for fn in uploaded.keys():\n",
+ " os.rename(fn, f\"{jobname}_template/{fn}\")\n",
+ "else:\n",
+ " custom_template_path = None\n",
+ " use_templates = False\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n",
+ "msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
+ "pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
+        "#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = separate MSA for each chain, \"paired\" = only use paired sequences.\n",
+ "\n",
+ "# decide which a3m to use\n",
+ "if msa_mode.startswith(\"MMseqs2\"):\n",
+ " a3m_file = f\"{jobname}.a3m\"\n",
+ "elif msa_mode == \"custom\":\n",
+ " a3m_file = f\"{jobname}.custom.a3m\"\n",
+ " if not os.path.isfile(a3m_file):\n",
+ " custom_msa_dict = files.upload()\n",
+ " custom_msa = list(custom_msa_dict.keys())[0]\n",
+ " header = 0\n",
+ " import fileinput\n",
+ " for line in fileinput.FileInput(custom_msa,inplace=1):\n",
+ " if line.startswith(\">\"):\n",
+ " header = header + 1\n",
+ " if not line.rstrip():\n",
+ " continue\n",
+ " if line.startswith(\">\") == False and header == 1:\n",
+ " query_sequence = line.rstrip()\n",
+ " print(line, end='')\n",
+ "\n",
+ " os.rename(custom_msa, a3m_file)\n",
+ " queries_path=a3m_file\n",
+ " print(f\"moving {custom_msa} to {a3m_file}\")\n",
+ "else:\n",
+ " a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
+ " with open(a3m_file, \"w\") as text_file:\n",
+ " text_file.write(\">1\\n%s\" % query_sequence)"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "C2_sh2uAonJH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+        "#@markdown ### Parameter settings\n",
+ "model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
+        "#@markdown - \"auto\" = use \"AlphaFold2-ptm\" for single-chain prediction and \"AlphaFold2-multimer-v2\" for complexes. For complexes, \"AlphaFold2-multimer-v[1,2]\" or \"AlphaFold2-ptm\" can also be selected explicitly.\n",
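+        "# the \"auto\" rule above is resolved later by colabfold.batch.set_model_type:\n",
+        "# monomer -> \"AlphaFold2-ptm\", complex -> \"AlphaFold2-multimer-v2\"\n",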
+ "num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
+ "save_to_google_drive = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
+ "dpi = 200 #@param {type:\"integer\"}\n",
+ "#@markdown - set dpi for image resolution\n",
+ "\n",
+ "#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
+ "\n",
+ "\n",
+ "if save_to_google_drive:\n",
+ " from pydrive.drive import GoogleDrive\n",
+ " from pydrive.auth import GoogleAuth\n",
+ " from google.colab import auth\n",
+ " from oauth2client.client import GoogleCredentials\n",
+ " auth.authenticate_user()\n",
+ " gauth = GoogleAuth()\n",
+ " gauth.credentials = GoogleCredentials.get_application_default()\n",
+ " drive = GoogleDrive(gauth)\n",
+ " print(\"You are logged into Google Drive and are good to go!\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "ADDuaolKmjGW"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iccGdbe_Pmt9",
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Install dependencies\n",
+ "%%bash -s $use_amber $use_templates\n",
+ "\n",
+ "set -e\n",
+ "\n",
+ "USE_AMBER=$1\n",
+ "USE_TEMPLATES=$2\n",
+ "\n",
+ "if [ ! -f COLABFOLD_READY ]; then\n",
+ " # install dependencies\n",
+ " # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
+ " pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
+ " # high risk high gain\n",
+ " pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
+ " touch COLABFOLD_READY\n",
+ "fi\n",
+ "\n",
+ "# setup conda\n",
+ "if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
+ " if [ ! -f CONDA_READY ]; then\n",
+ " wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
+ " bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
+ " rm Miniconda3-latest-Linux-x86_64.sh\n",
+ " touch CONDA_READY\n",
+ " fi\n",
+ "fi\n",
+ "# setup template search\n",
+ "if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
+ " conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
+ " touch HH_READY\n",
+ "fi\n",
+ "# setup openmm for amber refinement\n",
+ "if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
+ " conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
+ " touch AMBER_READY\n",
+ "fi"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_sztQyz29DIC",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Run prediction\n",
+ "\n",
+ "import sys\n",
+ "\n",
+ "from colabfold.download import download_alphafold_params, default_data_dir\n",
+ "from colabfold.utils import setup_logging\n",
+ "from colabfold.batch import get_queries, run, set_model_type\n",
+ "K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
+ "if \"1\" in K80_chk:\n",
+ " print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
+ " if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
+ " del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
+ " if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
+ " del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
+ "\n",
+ "from colabfold.colabfold import plot_protein\n",
+ "from pathlib import Path\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "# For some reason we need that to get pdbfixer to import\n",
+ "if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
+ " sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
+ "\n",
+ "def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
+ " fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
+ " plt.show()\n",
+ " plt.close()\n",
+ "\n",
+ "result_dir=\".\"\n",
+ "setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
+ "queries, is_complex = get_queries(queries_path)\n",
+ "model_type = set_model_type(is_complex, model_type)\n",
+ "download_alphafold_params(model_type, Path(\".\"))\n",
+ "run(\n",
+ " queries=queries,\n",
+ " result_dir=result_dir,\n",
+ " use_templates=use_templates,\n",
+ " custom_template_path=custom_template_path,\n",
+ " use_amber=use_amber,\n",
+ " msa_mode=msa_mode, \n",
+ " model_type=model_type,\n",
+ " num_models=5,\n",
+ " num_recycles=num_recycles,\n",
+ " model_order=[1, 2, 3, 4, 5],\n",
+ " is_complex=is_complex,\n",
+ " data_dir=Path(\".\"),\n",
+ " keep_existing_results=False,\n",
+ " recompile_padding=1.0,\n",
+ " rank_by=\"auto\",\n",
+ " pair_mode=pair_mode,\n",
+ " stop_at_score=float(100),\n",
+ " prediction_callback=prediction_callback,\n",
+ " dpi=dpi\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KK7X9T44pWb7",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Display 3D structure {run: \"auto\"}\n",
+ "import py3Dmol\n",
+ "import glob\n",
+ "import matplotlib.pyplot as plt\n",
+ "from colabfold.colabfold import plot_plddt_legend\n",
+ "rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
+ "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
+ "show_sidechains = False #@param {type:\"boolean\"}\n",
+ "show_mainchains = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
+ "if use_amber:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "else:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "\n",
+ "pdb_file = glob.glob(pdb_filename)\n",
+ "\n",
+ "def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
+ " model_name = f\"rank_{rank_num}\"\n",
+ " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
+ " view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
+ "\n",
+ " if color == \"lDDT\":\n",
+ " view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
+ " elif color == \"rainbow\":\n",
+ " view.setStyle({'cartoon': {'color':'spectrum'}})\n",
+ " elif color == \"chain\":\n",
+ " chains = len(queries[0][1]) + 1 if is_complex else 1\n",
+ " for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
+ " [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
+ " view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
+ " if show_sidechains:\n",
+ " BB = ['C','O','N']\n",
+ " view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
+ " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
+ " if show_mainchains:\n",
+ " BB = ['C','O','N','CA']\n",
+ " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ "\n",
+ " view.zoomTo()\n",
+ " return view\n",
+ "\n",
+ "\n",
+ "show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
+ "if color == \"lDDT\":\n",
+ " plot_plddt_legend().show() "
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "11l8k--10q0C",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Plots {run: \"auto\"}\n",
+ "from IPython.display import display, HTML\n",
+ "import base64\n",
+ "from html import escape\n",
+ "\n",
+ "# see: https://stackoverflow.com/a/53688522\n",
+ "def image_to_data_url(filename):\n",
+ " ext = filename.split('.')[-1]\n",
+ " prefix = f'data:image/{ext};base64,'\n",
+ " with open(filename, 'rb') as f:\n",
+ " img = f.read()\n",
+ " return prefix + base64.b64encode(img).decode('utf-8')\n",
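+        "# e.g. image_to_data_url('plot.png') -> 'data:image/png;base64,iVBORw0KG...'\n",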
+ "\n",
+ "pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
+ "cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
+ "plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
+ "display(HTML(f\"\"\"\n",
+ "\n",
+ "\n",
+        "<div>\n",
+        "<h1>Plots for {escape(jobname)}</h1>\n",
+        "<img src=\"{pae}\" />\n",
+        "<img src=\"{cov}\" />\n",
+        "<img src=\"{plddt}\" />\n",
+        "</div>\n",
+ "\"\"\"))\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "33g5IIegij5R",
+ "cellView": "form"
+ },
+ "source": [
+        "#@title Download results\n",
+ "#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
+ "\n",
+ "if msa_mode == \"custom\":\n",
+ " print(\"Don't forget to cite your custom MSA generation method.\")\n",
+ "\n",
+ "!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
+ "files.download(f\"{jobname}.result.zip\")\n",
+ "\n",
+        "if save_to_google_drive and drive:\n",
+ " uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
+ " uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
+ " uploaded.Upload()\n",
+ " print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UGUBLzB3C6WN",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+        "# Instructions\n",
+        "**Quick start**\n",
+        "1. Paste the amino-acid sequence you want to predict into the input field.\n",
+        "2. Click \"Runtime\" -> \"Run all\".\n",
+        "3. The pipeline currently runs 5 models and finally renders a 3D structure view.\n",
+        "\n",
+        "**Output files**\n",
+        "\n",
+        "1. Model structures in PDB format.\n",
+        "2. Model quality plots.\n",
+        "3. MSA coverage plots.\n",
+        "4. Miscellaneous outputs.\n",
+ "\n",
+ "**Acknowledgments**\n",
+ "- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
+ "\n",
+ "- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
+ "\n",
+        "- Richard Evans for helping to benchmark ColabFold's AlphaFold-multimer support\n",
+ "\n",
+        "- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without which these notebooks would be quite boring!\n",
+ "\n",
+ "- Do-Yoon Kim for creating the ColabFold logo.\n",
+ "\n",
+ "- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb.txt b/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb.txt
new file mode 100644
index 0000000..56496e5
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/Fold_CN.ipynb.txt
@@ -0,0 +1,478 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "AlphaFold2_CN.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+        ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G4yBrceuFbf3"
+ },
+ "source": [
+        "\n",
+ "\n",
+        "## ColabFold: AlphaFold2 using MMseqs2\n",
+ "\n",
+        "A simple Chinese-language quick-start guide to protein structure prediction, based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignment is powered by [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "kOblAo-xetgx",
+ "cellView": "form"
+ },
+ "source": [
+        "# Input protein sequence (defaults to the test sequence from the video), then click `Runtime` -> `Run all` in the menu above\n",
+ "from google.colab import files\n",
+ "import os.path\n",
+ "import re\n",
+ "import hashlib\n",
+ "import random\n",
+ "\n",
+ "def add_hash(x,y):\n",
+ " return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
+ "\n",
+ "query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
+        "#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
+ "\n",
+ "# remove whitespaces\n",
+ "query_sequence = \"\".join(query_sequence.split())\n",
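+        "# for a complex, join the chains with ':' in the input above, e.g. a\n",
+        "# (hypothetical) homodimer: 'MAAHKG...APKPH:MAAHKG...APKPH'\n",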
+ "\n",
+ "jobname = 'test' #@param {type:\"string\"}\n",
+ "# remove whitespaces\n",
+ "basejobname = \"\".join(jobname.split())\n",
+ "basejobname = re.sub(r'\\W+', '', basejobname)\n",
+ "jobname = add_hash(basejobname, query_sequence)\n",
+ "while os.path.isfile(f\"{jobname}.csv\"):\n",
+ " jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n",
+ "\n",
+ "with open(f\"{jobname}.csv\", \"w\") as text_file:\n",
+ " text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
+ "\n",
+ "queries_path=f\"{jobname}.csv\"\n",
+ "\n",
+ "# number of models to use\n",
+ "use_amber = False #@param {type:\"boolean\"}\n",
+ "template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n",
+ "#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n",
+ "\n",
+ "if template_mode == \"pdb70\":\n",
+ " use_templates = True\n",
+ " custom_template_path = None\n",
+ "elif template_mode == \"custom\":\n",
+ " custom_template_path = f\"{jobname}_template\"\n",
+ " os.mkdir(custom_template_path)\n",
+ " uploaded = files.upload()\n",
+ " use_templates = True\n",
+ " for fn in uploaded.keys():\n",
+ " os.rename(fn, f\"{jobname}_template/{fn}\")\n",
+ "else:\n",
+ " custom_template_path = None\n",
+ " use_templates = False\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n",
+ "msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n",
+ "pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
+        "#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = separate MSA for each chain, \"paired\" = only use paired sequences.\n",
+ "\n",
+ "# decide which a3m to use\n",
+ "if msa_mode.startswith(\"MMseqs2\"):\n",
+ " a3m_file = f\"{jobname}.a3m\"\n",
+ "elif msa_mode == \"custom\":\n",
+ " a3m_file = f\"{jobname}.custom.a3m\"\n",
+ " if not os.path.isfile(a3m_file):\n",
+ " custom_msa_dict = files.upload()\n",
+ " custom_msa = list(custom_msa_dict.keys())[0]\n",
+ " header = 0\n",
+ " import fileinput\n",
+ " for line in fileinput.FileInput(custom_msa,inplace=1):\n",
+ " if line.startswith(\">\"):\n",
+ " header = header + 1\n",
+ " if not line.rstrip():\n",
+ " continue\n",
+ " if line.startswith(\">\") == False and header == 1:\n",
+ " query_sequence = line.rstrip()\n",
+ " print(line, end='')\n",
+ "\n",
+ " os.rename(custom_msa, a3m_file)\n",
+ " queries_path=a3m_file\n",
+ " print(f\"moving {custom_msa} to {a3m_file}\")\n",
+ "else:\n",
+ " a3m_file = f\"{jobname}.single_sequence.a3m\"\n",
+ " with open(a3m_file, \"w\") as text_file:\n",
+ " text_file.write(\">1\\n%s\" % query_sequence)"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "C2_sh2uAonJH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@markdown ### Advanced settings\n",
+ "model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n",
+        "#@markdown - \"auto\" = use \"AlphaFold2-ptm\" for single-chain prediction and \"AlphaFold2-multimer-v2\" for complexes. For complexes, \"AlphaFold2-multimer-v[1,2]\" or \"AlphaFold2-ptm\" can also be selected explicitly.\n",
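+        "# the \"auto\" rule above is resolved later by colabfold.batch.set_model_type:\n",
+        "# monomer -> \"AlphaFold2-ptm\", complex -> \"AlphaFold2-multimer-v2\"\n",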
+ "num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n",
+ "save_to_google_drive = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n",
+ "dpi = 200 #@param {type:\"integer\"}\n",
+ "#@markdown - set dpi for image resolution\n",
+ "\n",
+ "#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n",
+ "\n",
+ "\n",
+ "if save_to_google_drive:\n",
+ " from pydrive.drive import GoogleDrive\n",
+ " from pydrive.auth import GoogleAuth\n",
+ " from google.colab import auth\n",
+ " from oauth2client.client import GoogleCredentials\n",
+ " auth.authenticate_user()\n",
+ " gauth = GoogleAuth()\n",
+ " gauth.credentials = GoogleCredentials.get_application_default()\n",
+ " drive = GoogleDrive(gauth)\n",
+ " print(\"You are logged into Google Drive and are good to go!\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "ADDuaolKmjGW"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iccGdbe_Pmt9",
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Install dependencies\n",
+ "%%bash -s $use_amber $use_templates\n",
+ "\n",
+ "set -e\n",
+ "\n",
+ "USE_AMBER=$1\n",
+ "USE_TEMPLATES=$2\n",
+ "\n",
+ "if [ ! -f COLABFOLD_READY ]; then\n",
+ " # install dependencies\n",
+ " # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
+ " pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n",
+ " # high risk high gain\n",
+ " pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
+ " touch COLABFOLD_READY\n",
+ "fi\n",
+ "\n",
+ "# setup conda\n",
+ "if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n",
+ " if [ ! -f CONDA_READY ]; then\n",
+ " wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
+ " bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n",
+ " rm Miniconda3-latest-Linux-x86_64.sh\n",
+ " touch CONDA_READY\n",
+ " fi\n",
+ "fi\n",
+ "# setup template search\n",
+ "if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n",
+ " conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n",
+ " touch HH_READY\n",
+ "fi\n",
+ "# setup openmm for amber refinement\n",
+ "if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n",
+ " conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n",
+ " touch AMBER_READY\n",
+ "fi"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_sztQyz29DIC",
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Run Prediction\n",
+ "\n",
+ "import sys\n",
+ "\n",
+ "from colabfold.download import download_alphafold_params, default_data_dir\n",
+ "from colabfold.utils import setup_logging\n",
+ "from colabfold.batch import get_queries, run, set_model_type\n",
+ "K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n",
+ "if \"1\" in K80_chk:\n",
+ " print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
+ " if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
+ " del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
+ " if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
+ " del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
+ "\n",
+ "from colabfold.colabfold import plot_protein\n",
+ "from pathlib import Path\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "# For some reason we need that to get pdbfixer to import\n",
+ "if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n",
+ " sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n",
+ "\n",
+ "def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n",
+ " fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n",
+ " plt.show()\n",
+ " plt.close()\n",
+ "\n",
+ "result_dir=\".\"\n",
+ "setup_logging(Path(\".\").joinpath(\"log.txt\"))\n",
+ "queries, is_complex = get_queries(queries_path)\n",
+ "model_type = set_model_type(is_complex, model_type)\n",
+ "download_alphafold_params(model_type, Path(\".\"))\n",
+ "run(\n",
+ " queries=queries,\n",
+ " result_dir=result_dir,\n",
+ " use_templates=use_templates,\n",
+ " custom_template_path=custom_template_path,\n",
+ " use_amber=use_amber,\n",
+ " msa_mode=msa_mode, \n",
+ " model_type=model_type,\n",
+ " num_models=5,\n",
+ " num_recycles=num_recycles,\n",
+ " model_order=[1, 2, 3, 4, 5],\n",
+ " is_complex=is_complex,\n",
+ " data_dir=Path(\".\"),\n",
+ " keep_existing_results=False,\n",
+ " recompile_padding=1.0,\n",
+ " rank_by=\"auto\",\n",
+ " pair_mode=pair_mode,\n",
+ " stop_at_score=float(100),\n",
+ " prediction_callback=prediction_callback,\n",
+ " dpi=dpi\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KK7X9T44pWb7",
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Display 3D structure {run: \"auto\"}\n",
+ "import py3Dmol\n",
+ "import glob\n",
+ "import matplotlib.pyplot as plt\n",
+ "from colabfold.colabfold import plot_plddt_legend\n",
+ "rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
+ "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
+ "show_sidechains = False #@param {type:\"boolean\"}\n",
+ "show_mainchains = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
+ "if use_amber:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "else:\n",
+ " pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n",
+ "\n",
+ "pdb_file = glob.glob(pdb_filename)\n",
+ "\n",
+ "def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
+ " model_name = f\"rank_{rank_num}\"\n",
+ " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
+ " view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
+ "\n",
+ " if color == \"lDDT\":\n",
+ " view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
+ " elif color == \"rainbow\":\n",
+ " view.setStyle({'cartoon': {'color':'spectrum'}})\n",
+ " elif color == \"chain\":\n",
+ " chains = len(queries[0][1]) + 1 if is_complex else 1\n",
+ " for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n",
+ " [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n",
+ " view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
+ " if show_sidechains:\n",
+ " BB = ['C','O','N']\n",
+ " view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
+ " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
+ " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
+ " if show_mainchains:\n",
+ " BB = ['C','O','N','CA']\n",
+ " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
+ "\n",
+ " view.zoomTo()\n",
+ " return view\n",
+ "\n",
+ "\n",
+ "show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n",
+ "if color == \"lDDT\":\n",
+ " plot_plddt_legend().show() "
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "11l8k--10q0C",
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Plots {run: \"auto\"}\n",
+ "from IPython.display import display, HTML\n",
+ "import base64\n",
+ "from html import escape\n",
+ "\n",
+ "# see: https://stackoverflow.com/a/53688522\n",
+ "def image_to_data_url(filename):\n",
+ " ext = filename.split('.')[-1]\n",
+ " prefix = f'data:image/{ext};base64,'\n",
+ " with open(filename, 'rb') as f:\n",
+ " img = f.read()\n",
+ " return prefix + base64.b64encode(img).decode('utf-8')\n",
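+        "# e.g. image_to_data_url('plot.png') -> 'data:image/png;base64,iVBORw0KG...'\n",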
+ "\n",
+ "pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n",
+ "cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n",
+ "plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n",
+ "display(HTML(f\"\"\"\n",
+ "\n",
+ "\n",
+        "<div>\n",
+        "<h1>Plots for {escape(jobname)}</h1>\n",
+        "<img src=\"{pae}\" />\n",
+        "<img src=\"{cov}\" />\n",
+        "<img src=\"{plddt}\" />\n",
+        "</div>\n",
+ "\"\"\"))\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "33g5IIegij5R",
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Package and download results\n",
+ "#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
+ "\n",
+ "if msa_mode == \"custom\":\n",
+ " print(\"Don't forget to cite your custom MSA generation method.\")\n",
+ "\n",
+ "!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n",
+ "files.download(f\"{jobname}.result.zip\")\n",
+ "\n",
+        "if save_to_google_drive and drive:\n",
+ " uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
+ " uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
+ " uploaded.Upload()\n",
+ " print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UGUBLzB3C6WN",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+        "# Instructions\n",
+        "**Quick start**\n",
+        "1. Paste the amino-acid sequence you want to predict into the input field.\n",
+        "2. Click \"Runtime\" -> \"Run all\".\n",
+        "3. The pipeline currently runs 5 models and finally renders a 3D structure view.\n",
+        "\n",
+        "**Output files**\n",
+        "\n",
+        "1. Model structures in PDB format.\n",
+        "2. Model quality plots.\n",
+        "3. MSA coverage plots.\n",
+        "4. Miscellaneous outputs.\n",
+        "\n",
+ "**Acknowledgments**\n",
+ "- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n",
+ "\n",
+ "- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n",
+ "\n",
+        "- Richard Evans for helping to benchmark ColabFold's AlphaFold-multimer support\n",
+ "\n",
+        "- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without which these notebooks would be quite boring!\n",
+ "\n",
+ "- Do-Yoon Kim for creating the ColabFold logo.\n",
+ "\n",
+ "- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/reproduce/AlphaFold2-Chinese/LICENSE b/reproduce/AlphaFold2-Chinese/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/reproduce/AlphaFold2-Chinese/README.md b/reproduce/AlphaFold2-Chinese/README.md
new file mode 100644
index 0000000..38ec153
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/README.md
@@ -0,0 +1,221 @@
+# AlphaFold2-Chinese
+A Chinese-language reproduction of the open-source AlphaFold2 model, based on DeepMind, ColabFold, and MindSpore:
+* https://github.com/sokrypton/ColabFold
+* https://github.com/deepmind/alphafold
+* https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold
+
+
+

+
+
+# Table of Contents
+
+
+
+- [Table of Contents](#table-of-contents)
+  - [Model Description](#model-description)
+  - [Requirements](#requirements)
+    - [Hardware and Framework](#hardware-and-framework)
+    - [Installing MMseqs2](#installing-mmseqs2)
+    - [Installing MindSpore Serving](#installing-mindspore-serving)
+  - [Data Preparation](#data-preparation)
+    - [Databases Required for MSA](#databases-required-for-msa)
+    - [Tools and Data Required for Templates](#tools-and-data-required-for-templates)
+      - [Data](#data)
+      - [Tools](#tools)
+  - [Script Description](#script-description)
+    - [Scripts and Sample Code](#scripts-and-sample-code)
+    - [Inference Example](#inference-example)
+    - [Inference Procedure](#inference-procedure)
+      - [Inference Results](#inference-results)
+  - [Inference Performance](#inference-performance)
+    - [TM-score Comparison](#tm-score-comparison)
+    - [Predicted Structure Comparison](#predicted-structure-comparison)
+  - [Citations](#citations)
+
+
+
+## Model Description
+
+Protein structure prediction tools use efficient computation to derive the spatial structure of a protein. For a long time such computational methods suffered from insufficient accuracy, until DeepMind's [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) [1][2] took first place in protein 3D structure prediction at CASP14 in 2020. The model part of the inference tool open-sourced here is identical to AlphaFold2; in the multiple sequence alignment stage it uses [MMseqs2](https://www.biorxiv.org/content/10.1101/2021.08.15.456425v1.full.pdf) [3] for sequence search, which makes the end-to-end pipeline 2-3x faster than the original.
+
+## Requirements
+
+### Hardware and Framework
+
+This code runs on Ascend processors with the [MindSpore](https://www.mindspore.cn/) AI framework. The current version must be [compiled](https://www.mindspore.cn/install/detail?path=install/r1.5/mindspore_ascend_install_source.md&highlight=%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91) from the latest master code in the repository (code from after 2021-11-08).
+For setting up MindSpore, see the [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html). After installation, run the following command to configure the environment variable:
+
+``` shell
+export MS_DEV_ENABLE_FALLBACK=0
+```
+
+For the remaining Python dependencies, see [requirements.txt](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold/requirements.txt).
+
+### Installing MMseqs2
+
+MMseqs2 is used to generate multiple sequence alignments (MSA). For installation and usage, see the [MMseqs2 User Guide](https://mmseqs.com/latest/userguide.pdf). After installation, run the following command to configure the environment variable:
+
+``` shell
+export PATH=$(pwd)/mmseqs/bin/:$PATH
+```
+
+### Installing MindSpore Serving
+
+We provide a service mode for running inference. This mode uses MindSpore Serving for efficient inference: when predicting multiple sequences it avoids repeated compilation, which greatly improves throughput. For installing and configuring MindSpore Serving, see the [MindSpore Serving installation page](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html).
+
+## Data Preparation
+
+### Databases Required for MSA
+
+- [uniref30_2103](http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz): 375 GB (68 GB download)
+- [colabfold_envdb_202108](http://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz): 949 GB (110 GB download)
+
+For database setup, see [colabfold](http://colabfold.mmseqs.com).
+
+### Tools and Data Required for Templates
+
+#### Data
+
+- [pdb70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz): 56 GB (19 GB download)
+- [mmcif database](https://ftp.rcsb.org/pub/pdb/data/structures/divided/mmCIF/): 206 GB (48 GB download)
+- [obsolete_pdbs](http://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat): 140 KB
+
+#### Tools
+
+- [HHsearch](https://github.com/soedinglab/hh-suite)
+- [kalign](https://msa.sbc.su.se/downloads/kalign/current.tar.gz)
+
+## Script Description
+
+### Scripts and Sample Code
+
+```bash
+├── mindscience
+ ├── MindSPONGE
+ ├── mindsponge
+ ├── fold
+ ├── README_CN.md // Chinese documentation for fold
+ ├── run.py // inference script
+ ├── model.py // main model
+ ├── requirements.txt // dependencies
+ ├── serving_server.py // server script for service mode
+ ├── serving_client.py // client script for service mode
+ ├── fold_service
+ ├── servable_config.py // servable configuration for service mode
+ ├── module
+ ├── basic_module.py // basic modules
+ ├── evoformer_module.py // Evoformer module
+ ├── structure_module.py // structure module
+ ├── data
+ ├── feature
+ ├── data_transforms.py // MSA and template data transforms
+ ├── feature_extraction.py // MSA and template feature extraction
+ ├── tools
+ ├── data_process.py // MSA and template search
+ ├── data_tools.py // data processing utilities
+ ├── mmcif_parsing.py // mmCIF parsing
+ ├── msa_search.sh // shell script for MSA search with MMseqs2
+ ├── parsers.py // file parsers
+ ├── templates.py // template search
+ ├── config
+ ├── config.py // parameter configuration
+ ├── global_config.py // global parameter configuration
+ ├── common
+ ├── generate_pdb.py // PDB generation
+ ├── r3.py // 3D coordinate transformations
+ ├── residue_constants.py // amino acid residue constants
+ ├── utils.py // utility functions
+```
+
+### Inference Example
+
+```bash
+Usage: run.py [--seq_length PADDING_SEQUENCE_LENGTH]
+              [--input_fasta_path INPUT_PATH][--msa_result_path MSA_RESULT_PATH]
+              [--database_dir DATABASE_PATH][--database_envdb_dir DATABASE_ENVDB_PATH]
+              [--hhsearch_binary_path HHSEARCH_PATH][--pdb70_database_path PDB70_PATH]
+              [--template_mmcif_dir TEMPLATE_PATH][--max_template_date TEMPLATE_DATE]
+              [--kalign_binary_path KALIGN_PATH][--obsolete_pdbs_path OBSOLETE_PATH]
+
+
+Options:
+  --seq_length            sequence length after zero padding; 256/512/1024/2048 are currently supported
+  --input_fasta_path      FASTA file with the protein sequence whose structure is to be predicted
+  --msa_result_path       path for saving the MSA results retrieved by MMseqs2
+  --database_dir          database used for the MSA search
+  --database_envdb_dir    extended database used for the MSA search
+  --hhsearch_binary_path  path to the hhsearch executable
+  --pdb70_database_path   path to the pdb70 database used by hhsearch
+  --template_mmcif_dir    directory containing mmCIF structure templates
+  --max_template_date     latest allowed release date for templates
+  --kalign_binary_path    path to the kalign executable
+  --obsolete_pdbs_path    path to the mapping file of obsolete PDB IDs
+```
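+
+For reference, the same invocation can also be assembled from Python. This is a minimal sketch only; every path below is a placeholder that must be adapted to your own databases and binaries:
+
+``` python
+import subprocess
+
+# Placeholder paths -- adapt them to your environment before running.
+subprocess.run([
+    "python", "run.py",
+    "--seq_length", "256",
+    "--input_fasta_path", "./input/query.fasta",
+    "--msa_result_path", "./msa_result",
+    "--database_dir", "/data/uniref30_2103",
+    "--database_envdb_dir", "/data/colabfold_envdb_202108",
+    "--hhsearch_binary_path", "/usr/local/bin/hhsearch",
+    "--pdb70_database_path", "/data/pdb70/pdb70",
+    "--template_mmcif_dir", "/data/pdb_mmcif/mmcif_files",
+    "--max_template_date", "2021-11-01",
+    "--kalign_binary_path", "/usr/local/bin/kalign",
+    "--obsolete_pdbs_path", "/data/pdb_mmcif/obsolete.dat",
+], check=True)
+```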
+
+### Inference Procedure
+
+Load the AlphaFold checkpoint (download it [here](https://download.mindspore.cn/model_zoo/research/hpc/molecular_dynamics/protein_fold_1.ckpt)) and choose the sequence-length configuration that fits your needs; four standard configurations (256/512/1024/2048) are currently provided. The inference procedure is as follows:
+
+1. Configure the input parameters via `fold_service/config.py`; see [Inference Example](#inference-example) for their meaning.
+
+2. After the parameters are configured, start the server process with `serving_server.py`. When the process starts successfully, the log shows:
+
+ ``` log
+ Serving: Serving gRPC server start success, listening on 127.0.0.1:5500
+ Serving: Serving RESTful server start success, listening on 127.0.0.1:1500
+ ```
+
+3. Once the server process has started, run `serving_client.py` to perform inference; the first inference triggers graph compilation. A minimal client sketch is shown below.
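+
+   The following is illustrative only: the `Client` constructor mirrors the MindSpore Serving r1.5 examples, and the method name and input key are placeholders; check `fold_service/servable_config.py` for the real ones.
+
+   ``` python
+   from mindspore_serving.client import Client
+
+   # "fold_service" matches the servable directory name; "predict" and
+   # "fasta_path" are placeholders for the actual method and input names.
+   client = Client("127.0.0.1:5500", "fold_service", "predict")
+   instances = [{"fasta_path": "./input/query.fasta"}]
+   result = client.infer(instances)
+   print(result)
+   ```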
+
+#### Inference Results
+
+Inference results are saved under `./result` as two files: the pdb file contains the predicted protein structure, and the timings file records the run-time and confidence information, for example:
+
+```json
+{"pre_process_time": 418.57, "model_time": 122.86, "pos_process_time": 0.14, "all_time ": 541.56, "confidence ": 94.61789646019058}
+```
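+
+Since the timings file is plain JSON, it can be inspected with the standard library. A small sketch (the file name under `./result` is a placeholder; note that some keys in the current output carry trailing spaces, hence the `strip()`):
+
+``` python
+import json
+
+with open("./result/timings.json") as f:  # placeholder file name
+    timings = json.load(f)
+# Strip keys because some of them ("all_time ", "confidence ") have trailing spaces.
+timings = {key.strip(): value for key, value in timings.items()}
+print(timings["model_time"], timings["confidence"])
+```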
+
+## Inference Performance
+
+| Parameter           | Fold (Ascend)               |
+| ------------------- | --------------------------- |
+| Model version       | AlphaFold                   |
+| Resource            | Ascend 910                  |
+| Upload date         | 2021-11-05                  |
+| MindSpore version   | master                      |
+| Dataset             | CASP14 T1079                |
+| seq_length          | 505                         |
+| confidence          | 94.62                       |
+| TM-score            | 98.01%                      |
+| Runtime             | 541.56 s                    |
+
+### TM-score Comparison
+
+- Comparison with AlphaFold2 on 34 CASP14 targets:
+
+
+

+
+
+### Predicted Structure Comparison
+
+- T1079 (length 505):
+
+
+

+
+
+- T1044 (length 2180):
+
+
+

+
+
+## Citations
+
+[1] Jumper J, Evans R, Pritzel A, et al. Applying and improving AlphaFold at CASP14[J]. Proteins: Structure, Function, and Bioinformatics, 2021.
+
+[2] Jumper J, Evans R, Pritzel A, et al. Highly accurate protein structure prediction with AlphaFold[J]. Nature, 2021, 596(7873): 583-589.
+
+[3] Mirdita M, Ovchinnikov S, Steinegger M. ColabFold - Making protein folding accessible to all[J]. bioRxiv, 2021.
\ No newline at end of file
diff --git a/reproduce/AlphaFold2-Chinese/commons/generate_pdb.py b/reproduce/AlphaFold2-Chinese/commons/generate_pdb.py
new file mode 100644
index 0000000..4f595cf
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/commons/generate_pdb.py
@@ -0,0 +1,118 @@
+"""generate pdb file"""
+
+import dataclasses
+import numpy as np
+
+from commons import residue_constants
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+ """Protein structure representation."""
+
+  # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+  # residue_constants.atom_types, i.e. the first three are N, CA, C.
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3]
+
+ # Amino-acid type for each residue represented as an integer between 0 and
+ # 20, where 20 is 'X'.
+ aatype: np.ndarray # [num_res]
+
+ # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
+ # is present and 0.0 if not. This should be used for loss masking.
+ atom_mask: np.ndarray # [num_res, num_atom_type]
+
+ # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
+ residue_index: np.ndarray # [num_res]
+
+ # B-factors, or temperature factors, of each residue (in sq. angstroms units),
+ # representing the displacement of the residue from its ground truth mean
+ # value.
+ b_factors: np.ndarray # [num_res, num_atom_type]
+
+
+def from_prediction(final_atom_mask, aatype, final_atom_positions, residue_index):
+ """Assembles a protein from a prediction.
+
+ Args:
+ final_atom_mask: atom mask info from structure module.
+ aatype: amino acid type info.
+ final_atom_positions: final atom positions from structure module
+ residue_index: from processed_features
+
+ Returns:
+ A protein instance.
+ """
+ dist_per_residue = np.zeros_like(final_atom_mask)
+
+ return Protein(
+ aatype=aatype,
+ atom_positions=final_atom_positions,
+ atom_mask=final_atom_mask,
+ residue_index=residue_index + 1,
+ b_factors=dist_per_residue)
+
+
+def to_pdb(prot: Protein):
+ """Converts a `Protein` instance to a PDB string.
+
+ Args:
+ prot: The protein to convert to PDB.
+
+ Returns:
+ PDB string.
+ """
+ restypes = residue_constants.restypes + ['X']
+ res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
+ atom_types = residue_constants.atom_types
+
+ pdb_lines = []
+
+ atom_mask = prot.atom_mask
+ aatype = prot.aatype
+ atom_positions = prot.atom_positions
+ residue_index = prot.residue_index.astype(np.int32)
+ b_factors = prot.b_factors
+
+ if (aatype > residue_constants.restype_num).any():
+ raise ValueError('Invalid aatypes.')
+
+ pdb_lines.append('MODEL 1')
+ atom_index = 1
+ chain_id = 'A'
+ # Add all atom sites.
+ for i in range(aatype.shape[0]):
+ res_name_3 = res_1to3(aatype[i])
+ for atom_name, pos, mask, b_factor in zip(
+ atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
+ if mask < 0.5:
+ continue
+
+ record_type = 'ATOM'
+ name = atom_name if len(atom_name) == 4 else f' {atom_name}'
+ alt_loc = ''
+ insertion_code = ''
+ occupancy = 1.00
+ element = atom_name[0] # Protein supports only C, N, O, S, this works.
+ charge = ''
+ # PDB is a columnar format, every space matters here!
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
+ f'{res_name_3:>3} {chain_id:>1}'
+ f'{residue_index[i]:>4}{insertion_code:>1} '
+ f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
+ f'{occupancy:>6.2f}{b_factor:>6.2f} '
+ f'{element:>2}{charge:>2}')
+ pdb_lines.append(atom_line)
+ atom_index += 1
+
+ # Close the chain.
+ chain_end = 'TER'
+ chain_termination_line = (
+ f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
+ f'{chain_id:>1}{residue_index[-1]:>4}')
+ pdb_lines.append(chain_termination_line)
+ pdb_lines.append('ENDMDL')
+
+ pdb_lines.append('END')
+ pdb_lines.append('')
+ return '\n'.join(pdb_lines)
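+
+
+if __name__ == '__main__':
+    # Illustrative smoke test (an added example, not part of the inference
+    # pipeline): build a single-alanine Protein with zeroed coordinates and
+    # print it as a PDB string. Run from the repository root so that the
+    # `commons` package is importable.
+    num_res = 1
+    num_atoms = len(residue_constants.atom_types)
+    demo_mask = np.zeros((num_res, num_atoms))
+    for atom in ('N', 'CA', 'C', 'O'):
+        demo_mask[0, residue_constants.atom_order[atom]] = 1.0
+    demo_prot = from_prediction(
+        final_atom_mask=demo_mask,
+        aatype=np.zeros(num_res, dtype=np.int32),  # 0 == ALA
+        final_atom_positions=np.zeros((num_res, num_atoms, 3)),
+        residue_index=np.arange(num_res))
+    print(to_pdb(demo_prot))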
diff --git a/reproduce/AlphaFold2-Chinese/commons/r3.py b/reproduce/AlphaFold2-Chinese/commons/r3.py
new file mode 100644
index 0000000..2a8f3fb
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/commons/r3.py
@@ -0,0 +1,104 @@
+"""Transformations for 3D coordinates."""
+
+import mindspore.numpy as mnp
+
+
+def vecs_sub(v1, v2):
+ """Computes v1 - v2."""
+ return v1 - v2
+
+
+def vecs_robust_norm(v, epsilon=1e-8):
+ """Computes norm of vectors 'v'."""
+
+ return mnp.sqrt(mnp.sum(mnp.square(v), axis=-1) + epsilon)
+
+
+def vecs_robust_normalize(v, epsilon=1e-8):
+ """Normalizes vectors 'v'."""
+
+ norms = vecs_robust_norm(v, epsilon)
+ return v / norms[..., None]
+
+
+def vecs_dot_vecs(v1, v2):
+ """Dot product of vectors 'v1' and 'v2'."""
+ return mnp.sum(v1 * v2, axis=-1)
+
+
+def vecs_cross_vecs(v1, v2):
+ """Cross product of vectors 'v1' and 'v2'."""
+
+ return mnp.concatenate(((v1[..., 1] * v2[..., 2] - v1[..., 2] * v2[..., 1])[..., None],
+ (v1[..., 2] * v2[..., 0] - v1[..., 0] * v2[..., 2])[..., None],
+ (v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0])[..., None]), axis=-1)
+
+
+def rots_from_two_vecs(e0_unnormalized, e1_unnormalized):
+ """Create rotation matrices from unnormalized vectors for the x and y-axes."""
+
+ # Normalize the unit vector for the x-axis, e0.
+ e0 = vecs_robust_normalize(e0_unnormalized)
+
+ # make e1 perpendicular to e0.
+ c = vecs_dot_vecs(e1_unnormalized, e0)
+ e1 = e1_unnormalized - c[..., None] * e0
+ e1 = vecs_robust_normalize(e1)
+
+ # Compute e2 as cross product of e0 and e1.
+ e2 = vecs_cross_vecs(e0, e1)
+
+ rots = mnp.concatenate(
+ (mnp.concatenate([e0[..., 0][None, ...], e1[..., 0][None, ...], e2[..., 0][None, ...]], axis=0)[None, ...],
+ mnp.concatenate([e0[..., 1][None, ...], e1[..., 1][None, ...], e2[..., 1][None, ...]], axis=0)[None, ...],
+ mnp.concatenate([e0[..., 2][None, ...], e1[..., 2][None, ...], e2[..., 2][None, ...]], axis=0)[None, ...]),
+ axis=0)
+ return rots
+
+
+def rigids_from_3_points(
+ point_on_neg_x_axis, # shape (...)
+ origin, # shape (...)
+ point_on_xy_plane, # shape (...)
+): # shape (...)
+ """Create Rigids from 3 points. """
+
+ m = rots_from_two_vecs(
+ e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
+ e1_unnormalized=vecs_sub(point_on_xy_plane, origin))
+ return m, origin
+
+
+def invert_rots(m):
+    """Computes inverse of rotations 'm' by transposing its two leading axes."""
+
+ return mnp.transpose(m, (1, 0, 2, 3, 4))
+
+
+def rots_mul_vecs(m, v):
+ """Apply rotations 'm' to vectors 'v'."""
+
+ return mnp.concatenate(((m[0][0] * v[..., 0] + m[0][1] * v[..., 1] + m[0][2] * v[..., 2])[..., None],
+ (m[1][0] * v[..., 0] + m[1][1] * v[..., 1] + m[1][2] * v[..., 2])[..., None],
+ (m[2][0] * v[..., 0] + m[2][1] * v[..., 1] + m[2][2] * v[..., 2])[..., None]), axis=-1)
+
+
+def invert_rigids(rot, trans):
+    """Computes the group inverse of the rigid transformation (rot, trans)."""
+
+ inv_rots = invert_rots(rot)
+ t = rots_mul_vecs(inv_rots, trans)
+ inv_trans = -t
+ return inv_rots, inv_trans
+
+
+def vecs_add(v1, v2):
+ """Add two vectors 'v1' and 'v2'."""
+
+ return v1 + v2
+
+
+def rigids_mul_vecs(rot, trans, v):
+ """Apply rigid transforms 'r' to points 'v'."""
+    """Apply the rigid transform (rot, trans) to points 'v'."""
+ return vecs_add(rots_mul_vecs(rot, v), trans)
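+
+
+if __name__ == '__main__':
+    # Illustrative check (an added example, not used by the model): build a
+    # rigid frame from three backbone-like points and map the local origin
+    # through it; the result should equal the frame origin (CA at (0, 0, 0)).
+    n = mnp.array([-0.5, 1.4, 0.0], dtype=mnp.float32)
+    ca = mnp.array([0.0, 0.0, 0.0], dtype=mnp.float32)
+    c = mnp.array([1.5, 0.0, 0.0], dtype=mnp.float32)
+    rot, trans = rigids_from_3_points(
+        point_on_neg_x_axis=n, origin=ca, point_on_xy_plane=c)
+    print(rigids_mul_vecs(rot, trans, mnp.zeros(3, dtype=mnp.float32)))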
diff --git a/reproduce/AlphaFold2-Chinese/commons/residue_constants.py b/reproduce/AlphaFold2-Chinese/commons/residue_constants.py
new file mode 100644
index 0000000..d8446bb
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/commons/residue_constants.py
@@ -0,0 +1,842 @@
+"""residue_constants."""
+
+import collections
+import functools
+from typing import Mapping, List, Tuple
+import numpy as np
+
+from mindspore.common.tensor import Tensor
+
+QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
+QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, -1]]
+
+QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1],
+ [0, 0, -1, 0]]
+
+QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0],
+ [0, 0, 0, -1],
+ [1, 0, 0, 0],
+ [0, 1, 0, 0]]
+
+QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1],
+ [0, 0, 1, 0],
+ [0, -1, 0, 0],
+ [1, 0, 0, 0]]
+
+QUAT_MULTIPLY_BY_VEC = Tensor(QUAT_MULTIPLY[:, 1:, :])
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms = {
+ 'ALA': [],
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
+ 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+ 'CYS': [['N', 'CA', 'CB', 'SG']],
+ 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'OE1']],
+ 'GLY': [],
+ 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
+ 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
+ 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+ ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
+ 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
+ ['CB', 'CG', 'SD', 'CE']],
+ 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
+ 'SER': [['N', 'CA', 'CB', 'OG']],
+ 'THR': [['N', 'CA', 'CB', 'OG1']],
+ 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+ 'VAL': [['N', 'CA', 'CB', 'CG1']],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [1.0, 1.0, 1.0, 1.0], # ARG
+ [1.0, 1.0, 0.0, 0.0], # ASN
+ [1.0, 1.0, 0.0, 0.0], # ASP
+ [1.0, 0.0, 0.0, 0.0], # CYS
+ [1.0, 1.0, 1.0, 0.0], # GLN
+ [1.0, 1.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [1.0, 1.0, 0.0, 0.0], # HIS
+ [1.0, 1.0, 0.0, 0.0], # ILE
+ [1.0, 1.0, 0.0, 0.0], # LEU
+ [1.0, 1.0, 1.0, 1.0], # LYS
+ [1.0, 1.0, 1.0, 0.0], # MET
+ [1.0, 1.0, 0.0, 0.0], # PHE
+ [1.0, 1.0, 0.0, 0.0], # PRO
+ [1.0, 0.0, 0.0, 0.0], # SER
+ [1.0, 0.0, 0.0, 0.0], # THR
+ [1.0, 1.0, 0.0, 0.0], # TRP
+ [1.0, 1.0, 0.0, 0.0], # TYR
+ [1.0, 0.0, 0.0, 0.0], # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [0.0, 0.0, 0.0, 0.0], # ARG
+ [0.0, 0.0, 0.0, 0.0], # ASN
+ [0.0, 1.0, 0.0, 0.0], # ASP
+ [0.0, 0.0, 0.0, 0.0], # CYS
+ [0.0, 0.0, 0.0, 0.0], # GLN
+ [0.0, 0.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [0.0, 0.0, 0.0, 0.0], # HIS
+ [0.0, 0.0, 0.0, 0.0], # ILE
+ [0.0, 0.0, 0.0, 0.0], # LEU
+ [0.0, 0.0, 0.0, 0.0], # LYS
+ [0.0, 0.0, 0.0, 0.0], # MET
+ [0.0, 1.0, 0.0, 0.0], # PHE
+ [0.0, 0.0, 0.0, 0.0], # PRO
+ [0.0, 0.0, 0.0, 0.0], # SER
+ [0.0, 0.0, 0.0, 0.0], # THR
+ [0.0, 0.0, 0.0, 0.0], # TRP
+ [0.0, 1.0, 0.0, 0.0], # TYR
+ [0.0, 0.0, 0.0, 0.0], # VAL
+ [0.0, 0.0, 0.0, 0.0], # UNK
+]
+
+# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
+# psi and chi angles:
+# 0: 'backbone group',
+# 1: 'pre-omega-group', (empty)
+# 2: 'phi-group', (currently empty, because it defines only hydrogens)
+# 3: 'psi-group',
+# 4,5,6,7: 'chi1,2,3,4-group'
+# The atom positions are relative to the axis-end-atom of the corresponding
+# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
+# is defined such that the dihedral-angle-defining atom (the last entry in
+# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
+# format: [atomname, group_idx, rel_position]
+rigid_group_atom_positions = {
+ 'ALA': [
+ ['N', 0, (-0.525, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.529, -0.774, -1.205)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ],
+ 'ARG': [
+ ['N', 0, (-0.524, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.209)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.616, 1.390, -0.000)],
+ ['CD', 5, (0.564, 1.414, 0.000)],
+ ['NE', 6, (0.539, 1.357, -0.000)],
+ ['NH1', 7, (0.206, 2.301, 0.000)],
+ ['NH2', 7, (2.078, 0.978, -0.000)],
+ ['CZ', 7, (0.758, 1.093, -0.000)],
+ ],
+ 'ASN': [
+ ['N', 0, (-0.536, 1.357, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.531, -0.787, -1.200)],
+ ['O', 3, (0.625, 1.062, 0.000)],
+ ['CG', 4, (0.584, 1.399, 0.000)],
+ ['ND2', 5, (0.593, -1.188, 0.001)],
+ ['OD1', 5, (0.633, 1.059, 0.000)],
+ ],
+ 'ASP': [
+ ['N', 0, (-0.525, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, 0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.593, 1.398, -0.000)],
+ ['OD1', 5, (0.610, 1.091, 0.000)],
+ ['OD2', 5, (0.592, -1.101, -0.003)],
+ ],
+ 'CYS': [
+ ['N', 0, (-0.522, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, 0.000)],
+ ['CB', 0, (-0.519, -0.773, -1.212)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['SG', 4, (0.728, 1.653, 0.000)],
+ ],
+ 'GLN': [
+ ['N', 0, (-0.526, 1.361, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.779, -1.207)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.615, 1.393, 0.000)],
+ ['CD', 5, (0.587, 1.399, -0.000)],
+ ['NE2', 6, (0.593, -1.189, -0.001)],
+ ['OE1', 6, (0.634, 1.060, 0.000)],
+ ],
+ 'GLU': [
+ ['N', 0, (-0.528, 1.361, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, -0.000, -0.000)],
+ ['CB', 0, (-0.526, -0.781, -1.207)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG', 4, (0.615, 1.392, 0.000)],
+ ['CD', 5, (0.600, 1.397, 0.000)],
+ ['OE1', 6, (0.607, 1.095, -0.000)],
+ ['OE2', 6, (0.589, -1.104, -0.001)],
+ ],
+ 'GLY': [
+ ['N', 0, (-0.572, 1.337, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.517, -0.000, -0.000)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ],
+ 'HIS': [
+ ['N', 0, (-0.527, 1.360, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.525, -0.778, -1.208)],
+ ['O', 3, (0.625, 1.063, 0.000)],
+ ['CG', 4, (0.600, 1.370, -0.000)],
+ ['CD2', 5, (0.889, -1.021, 0.003)],
+ ['ND1', 5, (0.744, 1.160, -0.000)],
+ ['CE1', 5, (2.030, 0.851, 0.002)],
+ ['NE2', 5, (2.145, -0.466, 0.004)],
+ ],
+ 'ILE': [
+ ['N', 0, (-0.493, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.536, -0.793, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.534, 1.437, -0.000)],
+ ['CG2', 4, (0.540, -0.785, -1.199)],
+ ['CD1', 5, (0.619, 1.391, 0.000)],
+ ],
+ 'LEU': [
+ ['N', 0, (-0.520, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.773, -1.214)],
+ ['O', 3, (0.625, 1.063, -0.000)],
+ ['CG', 4, (0.678, 1.371, 0.000)],
+ ['CD1', 5, (0.530, 1.430, -0.000)],
+ ['CD2', 5, (0.535, -0.774, 1.200)],
+ ],
+ 'LYS': [
+ ['N', 0, (-0.526, 1.362, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, 0.000)],
+ ['CB', 0, (-0.524, -0.778, -1.208)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.619, 1.390, 0.000)],
+ ['CD', 5, (0.559, 1.417, 0.000)],
+ ['CE', 6, (0.560, 1.416, 0.000)],
+ ['NZ', 7, (0.554, 1.387, 0.000)],
+ ],
+ 'MET': [
+ ['N', 0, (-0.521, 1.364, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, 0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.210)],
+ ['O', 3, (0.625, 1.062, -0.000)],
+ ['CG', 4, (0.613, 1.391, -0.000)],
+ ['SD', 5, (0.703, 1.695, 0.000)],
+ ['CE', 6, (0.320, 1.786, -0.000)],
+ ],
+ 'PHE': [
+ ['N', 0, (-0.518, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, 0.000, -0.000)],
+ ['CB', 0, (-0.525, -0.776, -1.212)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.377, 0.000)],
+ ['CD1', 5, (0.709, 1.195, -0.000)],
+ ['CD2', 5, (0.706, -1.196, 0.000)],
+ ['CE1', 5, (2.102, 1.198, -0.000)],
+ ['CE2', 5, (2.098, -1.201, -0.000)],
+ ['CZ', 5, (2.794, -0.003, -0.001)],
+ ],
+ 'PRO': [
+ ['N', 0, (-0.566, 1.351, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, 0.000)],
+ ['CB', 0, (-0.546, -0.611, -1.293)],
+ ['O', 3, (0.621, 1.066, 0.000)],
+ ['CG', 4, (0.382, 1.445, 0.0)],
+ # ['CD', 5, (0.427, 1.440, 0.0)],
+ ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
+ ],
+ 'SER': [
+ ['N', 0, (-0.529, 1.360, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, -0.000)],
+ ['CB', 0, (-0.518, -0.777, -1.211)],
+ ['O', 3, (0.626, 1.062, -0.000)],
+ ['OG', 4, (0.503, 1.325, 0.000)],
+ ],
+ 'THR': [
+ ['N', 0, (-0.517, 1.364, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.526, 0.000, -0.000)],
+ ['CB', 0, (-0.516, -0.793, -1.215)],
+ ['O', 3, (0.626, 1.062, 0.000)],
+ ['CG2', 4, (0.550, -0.718, -1.228)],
+ ['OG1', 4, (0.472, 1.353, 0.000)],
+ ],
+ 'TRP': [
+ ['N', 0, (-0.521, 1.363, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.525, -0.000, 0.000)],
+ ['CB', 0, (-0.523, -0.776, -1.212)],
+ ['O', 3, (0.627, 1.062, 0.000)],
+ ['CG', 4, (0.609, 1.370, -0.000)],
+ ['CD1', 5, (0.824, 1.091, 0.000)],
+ ['CD2', 5, (0.854, -1.148, -0.005)],
+ ['CE2', 5, (2.186, -0.678, -0.007)],
+ ['CE3', 5, (0.622, -2.530, -0.007)],
+ ['NE1', 5, (2.140, 0.690, -0.004)],
+ ['CH2', 5, (3.028, -2.890, -0.013)],
+ ['CZ2', 5, (3.283, -1.543, -0.011)],
+ ['CZ3', 5, (1.715, -3.389, -0.011)],
+ ],
+ 'TYR': [
+ ['N', 0, (-0.522, 1.362, 0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.524, -0.000, -0.000)],
+ ['CB', 0, (-0.522, -0.776, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG', 4, (0.607, 1.382, -0.000)],
+ ['CD1', 5, (0.716, 1.195, -0.000)],
+ ['CD2', 5, (0.713, -1.194, -0.001)],
+ ['CE1', 5, (2.107, 1.200, -0.002)],
+ ['CE2', 5, (2.104, -1.201, -0.003)],
+ ['OH', 5, (4.168, -0.002, -0.005)],
+ ['CZ', 5, (2.791, -0.001, -0.003)],
+ ],
+ 'VAL': [
+ ['N', 0, (-0.494, 1.373, -0.000)],
+ ['CA', 0, (0.000, 0.000, 0.000)],
+ ['C', 0, (1.527, -0.000, -0.000)],
+ ['CB', 0, (-0.533, -0.795, -1.213)],
+ ['O', 3, (0.627, 1.062, -0.000)],
+ ['CG1', 4, (0.540, 1.429, -0.000)],
+ ['CG2', 4, (0.533, -0.776, 1.203)],
+ ],
+}
+
+# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
+residue_atoms = {
+ 'ALA': ['C', 'CA', 'CB', 'N', 'O'],
+ 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
+ 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
+ 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
+ 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
+ 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
+ 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
+ 'GLY': ['C', 'CA', 'N', 'O'],
+ 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
+ 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
+ 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
+ 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
+ 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
+ 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
+ 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
+ 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
+ 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
+ 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
+ 'CH2', 'N', 'NE1', 'O'],
+ 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
+ 'OH'],
+ 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
+}
+
+# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
+van_der_waals_radius = {
+ 'C': 1.7,
+ 'N': 1.55,
+ 'O': 1.52,
+ 'S': 1.8,
+}
+
+Bond = collections.namedtuple(
+ 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
+BondAngle = collections.namedtuple(
+ 'BondAngle',
+ ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
+
+
+@functools.lru_cache(maxsize=None)
+def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
+ Mapping[str, List[Bond]],
+ Mapping[str, List[BondAngle]]]:
+ """Load stereo_chemical_props.txt into a nice structure.
+
+ Load literature values for bond lengths and bond angles and translate
+ bond angles into the length of the opposite edge of the triangle
+ ("residue_virtual_bonds").
+
+ Returns:
+ residue_bonds: dict that maps resname --> list of Bond tuples
+ residue_virtual_bonds: dict that maps resname --> list of Bond tuples
+ residue_bond_angles: dict that maps resname --> list of BondAngle tuples
+ """
+ stereo_chemical_props_path = (
+ 'alphafold/common/stereo_chemical_props.txt')
+ with open(stereo_chemical_props_path, 'rt') as f:
+ stereo_chemical_props = f.read()
+ lines_iter = iter(stereo_chemical_props.splitlines())
+ # Load bond lengths.
+ residue_bonds = {}
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, length, stddev = line.split()
+ atom1, atom2 = bond.split('-')
+ if resname not in residue_bonds:
+ residue_bonds[resname] = []
+ residue_bonds[resname].append(
+ Bond(atom1, atom2, float(length), float(stddev)))
+ residue_bonds['UNK'] = []
+
+ # Load bond angles.
+ residue_bond_angles = {}
+ next(lines_iter) # Skip empty line.
+ next(lines_iter) # Skip header line.
+ for line in lines_iter:
+ if line.strip() == '-':
+ break
+ bond, resname, angle_degree, stddev_degree = line.split()
+ atom1, atom2, atom3 = bond.split('-')
+ if resname not in residue_bond_angles:
+ residue_bond_angles[resname] = []
+ residue_bond_angles[resname].append(
+ BondAngle(atom1, atom2, atom3,
+ float(angle_degree) / 180. * np.pi,
+ float(stddev_degree) / 180. * np.pi))
+ residue_bond_angles['UNK'] = []
+
+ def make_bond_key(atom1_name, atom2_name):
+ """Unique key to lookup bonds."""
+ return '-'.join(sorted([atom1_name, atom2_name]))
+
+ # Translate bond angles into distances ("virtual bonds").
+ residue_virtual_bonds = {}
+ for resname, bond_angles in residue_bond_angles.items():
+ # Create a fast lookup dict for bond lengths.
+ bond_cache = {}
+ for b in residue_bonds[resname]:
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
+ residue_virtual_bonds[resname] = []
+ for ba in bond_angles:
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
+
+ # Compute distance between atom1 and atom3 using the law of cosines
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
+ gamma = ba.angle_rad
+ length = np.sqrt(bond1.length**2 + bond2.length**2
+ - 2 * bond1.length * bond2.length * np.cos(gamma))
+
+ # Propagation of uncertainty assuming uncorrelated errors.
+ dl_outer = 0.5 / length
+ dl_dgamma = (2 * bond1.length * bond2.length *
+ np.sin(gamma)) * dl_outer
+ dl_db1 = (2 * bond1.length - 2 * bond2.length *
+ np.cos(gamma)) * dl_outer
+ dl_db2 = (2 * bond2.length - 2 * bond1.length *
+ np.cos(gamma)) * dl_outer
+ stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
+ (dl_db1 * bond1.stddev)**2 +
+ (dl_db2 * bond2.stddev)**2)
+ residue_virtual_bonds[resname].append(
+ Bond(ba.atom1_name, ba.atom3name, length, stddev))
+
+ return (residue_bonds,
+ residue_virtual_bonds,
+ residue_bond_angles)
+
+
+# This mapping is used when we need to store atom data in a format that requires
+# fixed atom data size for every residue (e.g. a numpy array).
+atom_types = [
+ 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
+ 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
+ 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
+ 'CZ3', 'NZ', 'OXT'
+]
+atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
+atom_type_num = len(atom_types) # := 37.
+
+# A compact atom encoding with 14 columns
+# pylint: disable=line-too-long
+# pylint: disable=bad-whitespace
+restype_name_to_atom14_names = {
+ 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
+ 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
+ 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
+ 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
+ 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
+ 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
+ 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
+ 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
+ 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
+ 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
+ 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
+ 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
+ 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
+ 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
+ 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
+ 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
+ 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
+ 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
+ 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
+ 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
+ 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
+
+}
+
+# This is the standard residue order when coding AA type as a number.
+# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
+restypes = [
+ 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
+ 'S', 'T', 'W', 'Y', 'V'
+]
+restype_order = {restype: i for i, restype in enumerate(restypes)}
+restype_num = len(restypes) # := 20.
+
+restypes_with_x = restypes + ['X']
+restype_order_with_x = {
+ restype: i for i,
+ restype in enumerate(restypes_with_x)}
+
+
+def sequence_to_onehot(
+ sequence: str,
+ mapping: Mapping[str, int],
+ map_unknown_to_x: bool = False) -> np.ndarray:
+ """Maps the given sequence into a one-hot encoded matrix.
+
+ Args:
+ sequence: An amino acid sequence.
+ mapping: A dictionary mapping amino acids to integers.
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
+ the mapping will throw an error.
+
+ Returns:
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
+ the sequence.
+
+ Raises:
+ ValueError: If the mapping doesn't contain values from 0 to
+ num_unique_aas - 1 without any gaps.
+ """
+ num_entries = max(mapping.values()) + 1
+
+ if sorted(set(mapping.values())) != list(range(num_entries)):
+ raise ValueError(
+ 'The mapping must have values from 0 to num_unique_aas-1 '
+ 'without any gaps. Got: %s' %
+ sorted(
+ mapping.values()))
+
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
+
+ for aa_index, aa_type in enumerate(sequence):
+ if map_unknown_to_x:
+ if aa_type.isalpha() and aa_type.isupper():
+ aa_id = mapping.get(aa_type, mapping['X'])
+ else:
+ raise ValueError(
+ f'Invalid character in the sequence: {aa_type}')
+ else:
+ aa_id = mapping[aa_type]
+ one_hot_arr[aa_index, aa_id] = 1
+
+ return one_hot_arr
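+# Illustrative example (added note): sequence_to_onehot('AC', restype_order)
+# returns a (2, 20) int32 array with a 1 at column 0 ('A') and column 4 ('C').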
+
+
+restype_1to3 = {
+ 'A': 'ALA',
+ 'R': 'ARG',
+ 'N': 'ASN',
+ 'D': 'ASP',
+ 'C': 'CYS',
+ 'Q': 'GLN',
+ 'E': 'GLU',
+ 'G': 'GLY',
+ 'H': 'HIS',
+ 'I': 'ILE',
+ 'L': 'LEU',
+ 'K': 'LYS',
+ 'M': 'MET',
+ 'F': 'PHE',
+ 'P': 'PRO',
+ 'S': 'SER',
+ 'T': 'THR',
+ 'W': 'TRP',
+ 'Y': 'TYR',
+ 'V': 'VAL',
+}
+
+
+# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
+# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
+# many more, and less common, three letter names as keys and maps many of these
+# to the same one letter name (including 'X' and 'U' which we don't use here).
+restype_3to1 = {v: k for k, v in restype_1to3.items()}
+
+# Define a restype name for all unknown residues.
+unk_restype = 'UNK'
+
+resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
+resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
+
+
+# The mapping here uses hhblits convention, so that B is mapped to D, J and O
+# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
+# remaining 20 amino acids are kept in alphabetical order.
+# There are 2 non-amino acid codes, X (representing any amino acid) and
+# "-" representing a missing amino acid in an alignment. The id for these
+# codes is put at the end (20 and 21) so that they can easily be ignored if
+# desired.
+HHBLITS_AA_TO_ID = {
+ 'A': 0,
+ 'B': 2,
+ 'C': 1,
+ 'D': 2,
+ 'E': 3,
+ 'F': 4,
+ 'G': 5,
+ 'H': 6,
+ 'I': 7,
+ 'J': 20,
+ 'K': 8,
+ 'L': 9,
+ 'M': 10,
+ 'N': 11,
+ 'O': 20,
+ 'P': 12,
+ 'Q': 13,
+ 'R': 14,
+ 'S': 15,
+ 'T': 16,
+ 'U': 1,
+ 'V': 17,
+ 'W': 18,
+ 'X': 20,
+ 'Y': 19,
+ 'Z': 3,
+ '-': 21,
+}
+
+# Partial inversion of HHBLITS_AA_TO_ID.
+ID_TO_HHBLITS_AA = {
+ 0: 'A',
+ 1: 'C', # Also U.
+ 2: 'D', # Also B.
+ 3: 'E', # Also Z.
+ 4: 'F',
+ 5: 'G',
+ 6: 'H',
+ 7: 'I',
+ 8: 'K',
+ 9: 'L',
+ 10: 'M',
+ 11: 'N',
+ 12: 'P',
+ 13: 'Q',
+ 14: 'R',
+ 15: 'S',
+ 16: 'T',
+ 17: 'V',
+ 18: 'W',
+ 19: 'Y',
+ 20: 'X', # Includes J and O.
+ 21: '-',
+}
+
+restypes_with_x_and_gap = restypes + ['X', '-']
+MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
+ for i in range(len(restypes_with_x_and_gap)))
+
+
+def _make_standard_atom_mask() -> np.ndarray:
+ """Returns [num_res_types, num_atom_types] mask array."""
+ # +1 to account for unknown (all 0s).
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
+ for restype, restype_letter in enumerate(restypes):
+ restype_name = restype_1to3[restype_letter]
+ atom_names = residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = atom_order[atom_name]
+ mask[restype, atom_type] = 1
+ return mask
+
+
+STANDARD_ATOM_MASK = _make_standard_atom_mask()
+
+
+# A one hot representation for the first and second atoms defining the axis
+# of rotation for each chi-angle in each residue.
+def chi_angle_atom(atom_index: int) -> np.ndarray:
+ """Define chi-angle rigid groups via one-hot representations."""
+ chi_angles_index = {}
+ one_hots = []
+
+ for k, v in chi_angles_atoms.items():
+ indices = [atom_types.index(s[atom_index]) for s in v]
+ indices.extend([-1] * (4 - len(indices)))
+ chi_angles_index[k] = indices
+
+ for r in restypes:
+ res3 = restype_1to3[r]
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
+ one_hots.append(one_hot)
+
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
+ one_hot = np.stack(one_hots, axis=0)
+ one_hot = np.transpose(one_hot, [0, 2, 1])
+
+ return one_hot
+
+
+chi_atom_1_one_hot = chi_angle_atom(1)
+chi_atom_2_one_hot = chi_angle_atom(2)
+
+# An array like chi_angles_atoms but using indices rather than names.
+chi_angles_atom_indices = [[], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 23], [5, 11, 23, 32]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 10]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [], [[0, 1, 3, 5], [1, 3, 5, 14]], [[0, 1, 3, 6], [1, 3, 6, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 19], [5, 11, 19, 35]], [[0, 1, 3, 5], [1, 3, 5, 18], [3, 5, 18, 19]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11]], [[0, 1, 3, 8]], [[0, 1, 3, 9]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 6]]]
+chi_angles_atom_indices = np.array([
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
+ for chi_atoms in chi_angles_atom_indices])
+
+# Mapping from (res_name, atom_name) pairs to the atom's chi group index
+# and atom index within that group.
+chi_groups_for_atom = collections.defaultdict(list)
+for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
+ for atom_i, atom in enumerate(chi_group):
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
+chi_groups_for_atom = dict(chi_groups_for_atom)
+
+
+def _make_rigid_transformation_4x4(ex, ey, translation):
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
+ # Normalize ex.
+ ex_normalized = ex / np.linalg.norm(ex)
+
+ # make ey perpendicular to ex
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+ ey_normalized /= np.linalg.norm(ey_normalized)
+
+ # compute ez as cross product
+ eznorm = np.cross(ex_normalized, ey_normalized)
+ m = np.stack([ex_normalized, ey_normalized,
+ eznorm, translation]).transpose()
+ m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
+ return m
+
+
+# create an array with (restype, atomtype) --> rigid_group_idx
+# and an array with (restype, atomtype, coord) for the atom positions
+# and compute affine transformation matrices (4,4) from one rigid group to the
+# previous group
+# Note: np.int32 is used because the bare np.int alias was removed in NumPy 1.24.
+restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
+restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
+restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
+restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
+restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
+restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
+def _make_rigid_group_constants():
+ """Fill the arrays above."""
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[
+ resname]:
+ atomtype = atom_order[atomname]
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
+ restype_atom37_mask[restype, atomtype] = 1
+ restype_atom37_rigid_group_positions[restype,
+ atomtype, :] = atom_position
+
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
+ restype_atom14_mask[restype, atom14idx] = 1
+ restype_atom14_rigid_group_positions[restype,
+ atom14idx, :] = atom_position
+
+ for restype, restype_letter in enumerate(restypes):
+ resname = restype_1to3[restype_letter]
+ atom_positions = {name: np.array(pos) for name, _, pos
+ in rigid_group_atom_positions[resname]}
+
+ # backbone to backbone is the identity transform
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
+
+ # pre-omega-frame to backbone (currently dummy identity matrix)
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
+
+ # phi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['N'] - atom_positions['CA'],
+ ey=np.array([1., 0., 0.]),
+ translation=atom_positions['N'])
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
+
+ # psi-frame to backbone
+ mat = _make_rigid_transformation_4x4(
+ ex=atom_positions['C'] - atom_positions['CA'],
+ ey=atom_positions['CA'] - atom_positions['N'],
+ translation=atom_positions['C'])
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+ # chi1-frame to backbone
+ if chi_angles_mask[restype][0]:
+ base_atom_names = chi_angles_atoms[resname][0]
+ base_atom_positions = [atom_positions[name]
+ for name in base_atom_names]
+ mat = _make_rigid_transformation_4x4(
+ ex=base_atom_positions[2] - base_atom_positions[1],
+ ey=base_atom_positions[0] - base_atom_positions[1],
+ translation=base_atom_positions[2])
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+ # chi2-frame to chi1-frame
+ # chi3-frame to chi2-frame
+ # chi4-frame to chi3-frame
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
+ # previous frame
+ for chi_idx in range(1, 4):
+ if chi_angles_mask[restype][chi_idx]:
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
+ mat = _make_rigid_transformation_4x4(
+ ex=axis_end_atom_position,
+ ey=np.array([-1., 0., 0.]),
+ translation=axis_end_atom_position)
+ restype_rigid_group_default_frame[restype,
+ 4 + chi_idx, :, :] = mat
+
+
+_make_rigid_group_constants()
diff --git a/reproduce/AlphaFold2-Chinese/commons/utils.py b/reproduce/AlphaFold2-Chinese/commons/utils.py
new file mode 100644
index 0000000..f8cffe3
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/commons/utils.py
@@ -0,0 +1,1038 @@
+"""utils module"""
+
+import numpy as np
+from scipy.special import softmax
+
+from mindspore.ops import operations as P
+import mindspore.numpy as mnp
+import mindspore.nn as nn
+from mindspore.common.tensor import Tensor
+
+from commons import residue_constants
+import commons.r3 as r3
+
+
+QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
+
+QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr
+QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii
+QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj
+QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk
+
+QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij
+QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik
+QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk
+
+QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir
+QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr
+QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr
+
+QUAT_TO_ROT = Tensor(QUAT_TO_ROT)
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+ """Create pseudo beta features."""
+
+ is_gly = mnp.equal(aatype, residue_constants.restype_order['G'])
+ ca_idx = residue_constants.atom_order['CA']
+ cb_idx = residue_constants.atom_order['CB']
+ pseudo_beta = mnp.where(
+ mnp.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"),
+ all_atom_positions[..., ca_idx, :],
+ all_atom_positions[..., cb_idx, :])
+ if all_atom_masks is not None:
+ pseudo_beta_mask = mnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
+ pseudo_beta_mask = pseudo_beta_mask.astype(mnp.float32)
+ return pseudo_beta, pseudo_beta_mask
+ return pseudo_beta
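+# Note (added for clarity): the pseudo-beta position is the CB atom for every
+# residue except glycine, which has no CB; for glycine the CA position is used
+# instead, which is what the is_gly mask above selects.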
+
+
+def dgram_from_positions(positions, num_bins, min_bin, max_bin):
+ """Compute distogram from amino acid positions.
+
+ Arguments:
+ positions: [N_res, 3] Position coordinates.
+ num_bins: The number of bins in the distogram.
+ min_bin: The left edge of the first bin.
+ max_bin: The left edge of the final bin. The final bin catches
+ everything larger than `max_bin`.
+
+ Returns:
+ Distogram with the specified number of bins.
+ """
+
+ def squared_difference(x, y):
+ return mnp.square(x - y)
+
+ lower_breaks = mnp.linspace(min_bin, max_bin, num_bins)
+ lower_breaks = mnp.square(lower_breaks)
+ upper_breaks = mnp.concatenate([lower_breaks[1:], mnp.array([1e8], dtype=mnp.float32)], axis=-1)
+ dist2 = mnp.sum(squared_difference(mnp.expand_dims(positions, axis=-2),
+ mnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True)
+ dgram = ((dist2 > lower_breaks).astype(mnp.float32) * (dist2 < upper_breaks).astype(mnp.float32))
+ return dgram
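+# Shape note (added for clarity): for positions of shape [N_res, 3], the
+# returned dgram has shape [N_res, N_res, num_bins]; each pairwise squared
+# distance turns on exactly the bin whose (lower, upper) squared-break
+# interval contains it.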
+
+
+def _multiply(a, b):
+ return mnp.stack([mnp.concatenate([(a[0][0] * b[0][0] + a[0][1] * b[1][0] + a[0][2] * b[2][0])[None, ...],
+ (a[0][0] * b[0][1] + a[0][1] * b[1][1] + a[0][2] * b[2][1])[None, ...],
+ (a[0][0] * b[0][2] + a[0][1] * b[1][2] + a[0][2] * b[2][2])[None, ...]], axis=0),
+ mnp.concatenate([(a[1][0] * b[0][0] + a[1][1] * b[1][0] + a[1][2] * b[2][0])[None, ...],
+ (a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1])[None, ...],
+ (a[1][0] * b[0][2] + a[1][1] * b[1][2] + a[1][2] * b[2][2])[None, ...]], axis=0),
+ mnp.concatenate([(a[2][0] * b[0][0] + a[2][1] * b[1][0] + a[2][2] * b[2][0])[None, ...],
+ (a[2][0] * b[0][1] + a[2][1] * b[1][1] + a[2][2] * b[2][1])[None, ...],
+ (a[2][0] * b[0][2] + a[2][1] * b[1][2] + a[2][2] * b[2][2])[None, ...]],
+ axis=0)])
+
+
+def apply_rot_to_vec(rot, vec, unstack=False):
+ """Multiply rotation matrix by a vector."""
+ if unstack:
+ x, y, z = vec[:, 0], vec[:, 1], vec[:, 2]
+ else:
+ x, y, z = vec
+ return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
+ rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
+ rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
+
+
+def make_canonical_transform(n_xyz, ca_xyz, c_xyz):
+ """Returns translation and rotation matrices to canonicalize residue atoms.
+
+ Note that this method does not take care of symmetries. If you provide the
+ atom positions in the non-standard way, the N atom will end up not at
+ [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+ need to take care of such cases in your code.
+
+ Args:
+ n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
+ ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
+ c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+ Returns:
+ A tuple (translation, rotation) where:
+ translation is an array of shape [batch, 3] defining the translation.
+ rotation is an array of shape [batch, 3, 3] defining the rotation.
+ After applying the translation and rotation to all atoms in a residue:
+ * All atoms will be shifted so that CA is at the origin,
+ * All atoms will be rotated so that C is at the x-axis,
+ * All atoms will be shifted so that N is in the xy plane.
+ """
+
+ # Place CA at the origin.
+ translation = -ca_xyz
+ n_xyz = n_xyz + translation
+ c_xyz = c_xyz + translation
+
+ # Place C on the x-axis.
+ c_x, c_y, c_z = c_xyz[:, 0], c_xyz[:, 1], c_xyz[:, 2]
+ # Rotate by angle c1 in the x-y plane (around the z-axis).
+ sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
+ cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
+ zeros = mnp.zeros_like(sin_c1).astype("float32")
+ ones = mnp.ones_like(sin_c1).astype("float32")
+    # pylint: disable=bad-whitespace
+ c1_rot_matrix = mnp.stack([mnp.concatenate((cos_c1[None, ...], (-sin_c1)[None, ...], zeros[None, ...]), axis=0),
+ mnp.concatenate((sin_c1[None, ...], cos_c1[None, ...], zeros[None, ...]), axis=0),
+ mnp.concatenate((zeros[None, ...], zeros[None, ...], ones[None, ...]), axis=0)])
+    # Rotate by angle c2 in the x-z plane (around the y-axis).
+ sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
+ cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
+ c2_rot_matrix = mnp.stack([mnp.concatenate((cos_c2[None, ...], zeros[None, ...], sin_c2[None, ...]), axis=0),
+ mnp.concatenate((zeros[None, ...], ones[None, ...], zeros[None, ...]), axis=0),
+ mnp.concatenate(((-sin_c2)[None, ...], zeros[None, ...], cos_c2[None, ...]), axis=0)])
+ c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
+ n_xyz = mnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
+ # Place N in the x-y plane.
+ _, n_y, n_z = n_xyz[:, 0], n_xyz[:, 1], n_xyz[:, 2]
+ # Rotate by angle alpha in the y-z plane (around the x-axis).
+ sin_n = -n_z / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
+ cos_n = n_y / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
+ n_rot_matrix = mnp.stack([mnp.concatenate([ones[None, ...], zeros[None, ...], zeros[None, ...]], axis=0),
+ mnp.concatenate([zeros[None, ...], cos_n[None, ...], (-sin_n)[None, ...]], axis=0),
+ mnp.concatenate([zeros[None, ...], sin_n[None, ...], cos_n[None, ...]], axis=0)])
+ return translation, mnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1])
+
+
+def make_transform_from_reference(n_xyz, ca_xyz, c_xyz):
+ """Returns rotation and translation matrices to convert from reference.
+
+ Note that this method does not take care of symmetries. If you provide the
+ atom positions in the non-standard way, the N atom will end up not at
+ [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+ need to take care of such cases in your code.
+
+ Args:
+ n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
+ ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
+ c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+ Returns:
+ A tuple (rotation, translation) where:
+ rotation is an array of shape [batch, 3, 3] defining the rotation.
+ translation is an array of shape [batch, 3] defining the translation.
+ After applying the translation and rotation to the reference backbone,
+      the coordinates will be approximately equal to the input coordinates.
+
+ The order of translation and rotation differs from make_canonical_transform
+ because the rotation from this function should be applied before the
+ translation, unlike make_canonical_transform.
+ """
+ translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
+ return mnp.transpose(rotation, (0, 2, 1)), -translation
+
+
+def rot_to_quat(rot, unstack_inputs=False):
+ """Convert rotation matrix to quaternion.
+
+    Note that the original AlphaFold implementation recovers the quaternion via
+    an eigendecomposition (self_adjoint_eig), which is expensive on GPU. This
+    port instead takes the first column of the scaled K matrix as a cheap
+    stand-in for its dominant eigenvector, so the result is an unnormalized,
+    approximate quaternion.
+
+ Args:
+ rot: rotation matrix (see below for format).
+ unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
+ otherwise the rotation matrix should be a list of lists of tensors.
+
+ Returns:
+ Quaternion as (..., 4) tensor.
+ """
+
+ if unstack_inputs:
+ rot = mnp.transpose(rot, [2, 1, 0])
+ xx, xy, xz = rot[0][0], rot[0][1], rot[0][2]
+ yx, yy, yz = rot[1][0], rot[1][1], rot[1][2]
+ zx, zy, zz = rot[2][0], rot[2][1], rot[2][2]
+ k = mnp.stack((mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1),
+ mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1),
+ mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1),
+ mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)), axis=-2)
+ k = (1. / 3.) * k
+
+ # Take the first column of K as an approximate (unnormalized) quaternion.
+ k = k[:, :, 0]
+ return k
+
+
+def quat_to_rot(normalized_quat):
+ """Convert a normalized quaternion to a rotation matrix."""
+ rot_tensor = mnp.sum(mnp.reshape(QUAT_TO_ROT, (4, 4, 9)) * normalized_quat[..., :, None, None] *
+ normalized_quat[..., None, :, None], axis=(-3, -2))
+ rot = mnp.moveaxis(rot_tensor, -1, 0) # Unstack.
+ return [[rot[0], rot[1], rot[2]],
+ [rot[3], rot[4], rot[5]],
+ [rot[6], rot[7], rot[8]]]
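+
+# Usage sketch (illustrative only): round trip through the two conversions.
+# `rots` is a hypothetical (N, 3, 3) tensor of rotation matrices.
+#
+#   quat = rot_to_quat(rots, unstack_inputs=True)        # (N, 4), approximate
+#   quat = quat / mnp.norm(quat, axis=-1, keepdims=True)
+#   rot_list = quat_to_rot(quat)  # nested 3x3 list of (N,) tensors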
+
+
+def quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False):
+ """create quat affine representations"""
+
+ if unstack_inputs and rotation is not None:
+ rotation = mnp.transpose(rotation, [2, 1, 0])
+ translation = mnp.moveaxis(translation, -1, 0) # Unstack.
+ if normalize and quaternion is not None:
+ quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True)
+
+ if rotation is None:
+ rotation = quat_to_rot(quaternion)
+
+ return quaternion, rotation, translation
+
+
+def apply_inverse_rot_to_vec(rot, vec):
+ """Multiply the inverse of a rotation matrix by a vector."""
+ # Inverse rotation is just transpose
+ return mnp.concatenate(((rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2])[None, ...],
+ (rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2])[None, ...],
+ (rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2])[None, ...]), axis=0)
+
+
+def invert_point(transformed_point, rotation, translation, extra_dims=0):
+ """Apply inverse of transformation to a point.
+
+ Args:
+ transformed_point: (3, ...) tensor holding the point to which the
+ inverse affine is applied.
+ rotation: (3, 3, ...) rotation tensor.
+ translation: (3, ...) translation tensor.
+ extra_dims: Number of dimensions at the end of the transformed_point
+ shape that are not present in the rotation and translation. The most
+ common use is rotating N points at once with extra_dims=1 for use in a
+ network.
+
+ Returns:
+ Point after applying the inverse affine transform.
+ """
+ for _ in range(extra_dims):
+ rotation = mnp.expand_dims(rotation, axis=-1)
+ translation = mnp.expand_dims(translation, axis=-1)
+ rot_point = transformed_point - translation
+ return apply_inverse_rot_to_vec(rotation, rot_point)
+
+
+def _invert_point(transformed_point, rotation, translation):
+ """Apply inverse of transformation to a point.
+
+ Args:
+ transformed_point: List of 3 tensors to apply affine
+ extra_dims: Number of dimensions at the end of the transformed_point
+ shape that are not present in the rotation and translation. The most
+ common use is rotation N points at once with extra_dims=1 for use in a
+ network.
+
+ Returns:
+ Transformed point after applying affine.
+ """
+ r00 = mnp.expand_dims(rotation[0][0], axis=-1)
+ r01 = mnp.expand_dims(rotation[0][1], axis=-1)
+ r02 = mnp.expand_dims(rotation[0][2], axis=-1)
+ r10 = mnp.expand_dims(rotation[1][0], axis=-1)
+ r11 = mnp.expand_dims(rotation[1][1], axis=-1)
+ r12 = mnp.expand_dims(rotation[1][2], axis=-1)
+ r20 = mnp.expand_dims(rotation[2][0], axis=-1)
+ r21 = mnp.expand_dims(rotation[2][1], axis=-1)
+ r22 = mnp.expand_dims(rotation[2][2], axis=-1)
+
+ t0 = mnp.expand_dims(translation[0], axis=-1)
+ t1 = mnp.expand_dims(translation[1], axis=-1)
+ t2 = mnp.expand_dims(translation[2], axis=-1)
+
+ rot_point = [transformed_point[0] - t0, transformed_point[1] - t1, transformed_point[2] - t2]
+
+ result = [r00 * rot_point[0] + r10 * rot_point[1] + r20 * rot_point[2],
+ r01 * rot_point[0] + r11 * rot_point[1] + r21 * rot_point[2],
+ r02 * rot_point[0] + r12 * rot_point[1] + r22 * rot_point[2]]
+ return result
+
+
+def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
+ """Masked mean of `value` along `axis`; note that `axis` must be an int here."""
+ if drop_mask_channel:
+ mask = mask[..., 0]
+ mask_shape = mask.shape
+ value_shape = value.shape
+ broadcast_factor = 1.
+ value_size = value_shape[axis]
+ mask_size = mask_shape[axis]
+ if mask_size == 1:
+ broadcast_factor *= value_size
+ return mnp.sum(mask * value, axis=axis) / (mnp.sum(mask, axis=axis) * broadcast_factor + eps)
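+
+# Usage sketch (illustrative only): mean over residues, ignoring masked ones.
+#
+#   value = mnp.ones((8, 128))              # hypothetical per-residue scores
+#   mask = mnp.ones((8, 128))               # 1 = residue present
+#   mean = mask_mean(mask, value, axis=-1)  # shape (8,)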
+
+
+def atom37_to_torsion_angles(
+ aatype, # (B, N)
+ all_atom_pos, # (B, N, 37, 3)
+ all_atom_mask, # (B, N, 37)
+ chi_atom_indices,
+ chi_angles_mask,
+ mirror_psi_mask,
+ chi_pi_periodic,
+ indices0,
+ indices1
+):
+ """Computes the 7 torsion angles (in sin, cos encoding) for each residue.
+
+ The 7 torsion angles are in the order
+ '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
+ here pre_omega denotes the omega torsion angle between the given amino acid
+ and the previous amino acid.
+
+ Args:
+ aatype: Amino acid type, given as array with integers.
+ all_atom_pos: atom37 representation of all atom coordinates.
+ all_atom_mask: atom37 representation of mask on all atom coordinates.
+ chi_atom_indices: Indices of the 4 atoms defining each chi angle,
+ shape (restypes=21, chis=4, atoms=4).
+ chi_angles_mask: Mask of which chi angles exist per residue type.
+ mirror_psi_mask: Constant that flips the sign of the psi angle.
+ chi_pi_periodic: Mask of chi angles affected by naming ambiguities.
+ indices0, indices1: Precomputed index grids used by GatherNd.
+ Returns:
+ Dict containing:
+ * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
+ 2 dimensions denote sin and cos respectively
+ * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
+ with the angle shifted by pi for all chi angles affected by the naming
+ ambiguities.
+ * 'torsion_angles_mask': Mask for which chi angles are present.
+ """
+
+ # Map aatype > 20 to 'Unknown' (20).
+ aatype = mnp.minimum(aatype, 20)
+
+ # Compute the backbone angles.
+ num_batch, num_res = aatype.shape
+
+ pad = mnp.zeros([num_batch, 1, 37, 3], mnp.float32)
+ prev_all_atom_pos = mnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)
+
+ pad = mnp.zeros([num_batch, 1, 37], mnp.float32)
+ prev_all_atom_mask = mnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)
+
+ # For each torsion angle collect the 4 atom positions that define this angle.
+ # shape (B, N, atoms=4, xyz=3)
+ pre_omega_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2)
+ phi_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2)
+ psi_atom_pos = mnp.concatenate([all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2)
+ # Collect the masks from these atoms.
+ # Shape [batch, num_res]
+ pre_omega_mask = (P.ReduceProd()(prev_all_atom_mask[:, :, 1:3], -1) # prev CA, C
+ * P.ReduceProd()(all_atom_mask[:, :, 0:2], -1)) # this N, CA
+ phi_mask = (prev_all_atom_mask[:, :, 2] # prev C
+ * P.ReduceProd()(all_atom_mask[:, :, 0:3], -1)) # this N, CA, C
+ psi_mask = (P.ReduceProd()(all_atom_mask[:, :, 0:3], -1) * # this N, CA, C
+ all_atom_mask[:, :, 4]) # this O
+ # Collect the atoms for the chi-angles.
+ # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+ # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
+ atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)
+
+ # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
+ # atom_indices has shape [batch=4, seq_length, chis=4, atoms=4].
+ seq_length = all_atom_pos.shape[1]
+ atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int64")
+ new_indices = P.Concat(4)((indices0, indices1, atom_indices)) # 4, seq_length, 4, 4, 3
+ chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices)
+ chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
+ chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices)
+
+ # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
+ chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1)
+ chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32)
+
+ # Stack all torsion angle atom positions.
+ # Shape (B, N, torsions=7, atoms=4, xyz=3)
+ torsions_atom_pos = mnp.concatenate([pre_omega_atom_pos[:, :, None, :, :],
+ phi_atom_pos[:, :, None, :, :],
+ psi_atom_pos[:, :, None, :, :],
+ chis_atom_pos], axis=2)
+ # Stack up masks for all torsion angles.
+ # shape (B, N, torsions=7)
+ torsion_angles_mask = mnp.concatenate([pre_omega_mask[:, :, None],
+ phi_mask[:, :, None],
+ psi_mask[:, :, None],
+ chis_mask], axis=2)
+
+ torsion_frames_rots, torsion_frames_trans = r3.rigids_from_3_points(
+ torsions_atom_pos[:, :, :, 1, :],
+ torsions_atom_pos[:, :, :, 2, :],
+ torsions_atom_pos[:, :, :, 0, :])
+ inv_torsion_rots, inv_torsion_trans = r3.invert_rigids(torsion_frames_rots, torsion_frames_trans)
+ forth_atom_rel_pos = r3.rigids_mul_vecs(inv_torsion_rots, inv_torsion_trans, torsions_atom_pos[:, :, :, 3, :])
+
+ # Compute the position of the fourth atom in this frame (y and z coordinates).
+ torsion_angles_sin_cos = mnp.stack([forth_atom_rel_pos[..., 2], forth_atom_rel_pos[..., 1]], axis=-1)
+ torsion_angles_sin_cos /= mnp.sqrt(mnp.sum(mnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)
+ # Mirror psi, because we computed it from the Oxygen-atom.
+ torsion_angles_sin_cos *= mirror_psi_mask
+ chi_is_ambiguous = mnp.take(chi_pi_periodic, aatype, axis=0)
+ mirror_torsion_angles = mnp.concatenate([mnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1)
+ alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
+ return torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask
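+
+# Usage sketch (illustrative only): the index tables are built once from
+# residue_constants and passed in as tensors.
+#
+#   chi_atom_indices = mnp.asarray(get_chi_atom_indices())  # (21, 4, 4)
+#   sin_cos, alt_sin_cos, mask = atom37_to_torsion_angles(
+#       aatype, all_atom_pos, all_atom_mask, chi_atom_indices,
+#       chi_angles_mask, mirror_psi_mask, chi_pi_periodic, indices0, indices1)
+#   # sin_cos: (B, N, 7, 2); mask: (B, N, 7)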
+
+
+def get_chi_atom_indices():
+ """Returns atom indices needed to compute chi angles for all residue types.
+
+ Returns:
+ A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+ in the order specified in residue_constants.restypes + unknown residue type
+ at the end. For chi angles which are not defined on the residue, the
+ positions indices are by default set to 0.
+ """
+
+ chi_atom_indices = []
+ for residue_name in residue_constants.restypes:
+ residue_name = residue_constants.restype_1to3[residue_name]
+ residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
+ atom_indices = []
+ for chi_angle in residue_chi_angles:
+ atom_indices.append([residue_constants.atom_order[atom] for atom in chi_angle])
+ for _ in range(4 - len(atom_indices)):
+ atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
+ chi_atom_indices.append(atom_indices)
+ chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
+ return np.asarray(chi_atom_indices)
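+
+# Usage sketch (illustrative only):
+#
+#   indices = get_chi_atom_indices()  # np.ndarray, shape (21, 4, 4)
+#   # indices[restype, chi, k] is the atom37 index of the k-th atom defining
+#   # that chi angle, or 0 where the angle is undefined.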
+
+
+def to_tensor(quaternion, translation):
+ """Concatenate quaternion and translation into a single (..., 7) tensor."""
+ return mnp.concatenate([quaternion, translation], axis=-1)
+
+
+def from_tensor(tensor, normalize=False):
+ """Split a (..., 7) tensor back into a quat affine; inverse of to_tensor."""
+ quaternion, tx, ty, tz = mnp.split(tensor, [4, 5, 6], axis=-1)
+ return quat_affine(quaternion, mnp.stack([tx[..., 0], ty[..., 0], tz[..., 0]], axis=-1), normalize=normalize)
+
+
+def generate_new_affine(sequence_mask):
+ """Create an identity quat affine (unit quaternion, zero translation) per residue."""
+ num_residues, _ = sequence_mask.shape
+ quaternion = mnp.tile(mnp.reshape(mnp.asarray([1., 0., 0., 0.]), [1, 4]), [num_residues, 1])
+ translation = mnp.zeros([num_residues, 3])
+ return quat_affine(quaternion, translation, unstack_inputs=True)
+
+
+def pre_compose(quaternion, rotation, translation, update):
+ """Return a new QuatAffine which applies the transformation update first.
+
+ Args:
+ quaternion: Current (..., 4) quaternion.
+ rotation: Current rotation, as a nested 3x3 list of tensors.
+ translation: Current translation, as a list of 3 tensors.
+ update: Length-6 vector. The first 3 components x, y, z define the
+ quaternion update (1, x, y, z), where the zero vector is the identity
+ quaternion; the last 3 components are the translation update.
+
+ Returns:
+ New QuatAffine object.
+ """
+
+ vector_quaternion_update, x, y, z = mnp.split(update, [3, 4, 5], axis=-1)
+ trans_update = [mnp.squeeze(x, axis=-1), mnp.squeeze(y, axis=-1), mnp.squeeze(z, axis=-1)]
+ new_quaternion = (quaternion + quat_multiply_by_vec(quaternion, vector_quaternion_update))
+ trans_update = apply_rot_to_vec(rotation, trans_update)
+ new_translation = [translation[0] + trans_update[0],
+ translation[1] + trans_update[1],
+ translation[2] + trans_update[2]]
+ return quat_affine(new_quaternion, mnp.stack(new_translation, axis=-1))
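+
+# Usage sketch (illustrative only): a zero update is the identity, so the
+# affine is unchanged up to quaternion renormalization (`translation` is a
+# list of 3 tensors here).
+#
+#   update = mnp.zeros(quaternion.shape[:-1] + (6,))  # hypothetical
+#   new_q, new_rot, new_t = pre_compose(quaternion, rotation, translation, update)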
+
+
+def scale_translation(quaternion, translation, rotation, position_scale):
+ """Return a new quat affine with a different scale for translation."""
+
+ return quat_affine(quaternion,
+ mnp.stack([translation[0] * position_scale, translation[1] * position_scale,
+ translation[2] * position_scale], axis=-1),
+ rotation=rotation,
+ normalize=False)
+
+
+def rigids_from_tensor4x4(m):
+ """Construct Rigids object from an 4x4 array.
+
+ Here the 4x4 is representing the transformation in homogeneous coordinates.
+
+ Args:
+ m: Array representing transformations in homogeneous coordinates.
+ Returns:
+ Rigids object corresponding to transformations m
+ """
+ return m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], m[..., 2, 0], \
+ m[..., 2, 1], m[..., 2, 2], m[..., 0, 3], m[..., 1, 3], m[..., 2, 3]
+
+
+def apply_to_point(rotation, translation, point):
+ """apply to point func"""
+
+ r00 = mnp.expand_dims(rotation[0][0], axis=-1)
+ r01 = mnp.expand_dims(rotation[0][1], axis=-1)
+ r02 = mnp.expand_dims(rotation[0][2], axis=-1)
+ r10 = mnp.expand_dims(rotation[1][0], axis=-1)
+ r11 = mnp.expand_dims(rotation[1][1], axis=-1)
+ r12 = mnp.expand_dims(rotation[1][2], axis=-1)
+ r20 = mnp.expand_dims(rotation[2][0], axis=-1)
+ r21 = mnp.expand_dims(rotation[2][1], axis=-1)
+ r22 = mnp.expand_dims(rotation[2][2], axis=-1)
+
+ t0 = mnp.expand_dims(translation[0], axis=-1)
+ t1 = mnp.expand_dims(translation[1], axis=-1)
+ t2 = mnp.expand_dims(translation[2], axis=-1)
+
+ p0 = point[0]
+ p1 = point[1]
+ p2 = point[2]
+ rot_point = [r00 * p0 + r01 * p1 + r02 * p2,
+ r10 * p0 + r11 * p1 + r12 * p2,
+ r20 * p0 + r21 * p1 + r22 * p2]
+ result = [rot_point[0] + t0,
+ rot_point[1] + t1,
+ rot_point[2] + t2]
+ return result
+
+
+def frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, restype_atom14_to_rigid_group,
+ restype_atom14_rigid_group_positions, restype_atom14_mask): # (N, 14)
+ """Put atom literature positions (atom14 encoding) in each rigid group.
+
+ Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11
+
+ Args:
+ aatype: aatype for each residue.
+ all_frames_to_global: All per-residue coordinate frames.
+ restype_atom14_to_rigid_group: Rigid group index of each atom14 slot.
+ restype_atom14_rigid_group_positions: Literature atom positions in the
+ local frame of their rigid group.
+ restype_atom14_mask: Mask of which atom14 slots exist per residue type.
+ Returns:
+ Positions of all atom coordinates in the global frame.
+ """
+
+ # Pick the appropriate transform for every atom.
+ residx_to_group_idx = P.Gather()(restype_atom14_to_rigid_group, aatype, 0)
+ group_mask = nn.OneHot(depth=8, axis=-1)(residx_to_group_idx)
+
+ # # r3.Rigids with shape (N, 14)
+ map_atoms_to_global = map_atoms_to_global_func(all_frames_to_global, group_mask)
+
+ # Gather the literature atom positions for each residue.
+ # r3.Vecs with shape (N, 14)
+ lit_positions = vecs_from_tensor(P.Gather()(restype_atom14_rigid_group_positions, aatype, 0))
+
+ # Transform each atom from its local frame to the global frame.
+ # r3.Vecs with shape (N, 14)
+ pred_positions = rigids_mul_vecs(map_atoms_to_global, lit_positions)
+
+ # Mask out non-existing atoms.
+ mask = P.Gather()(restype_atom14_mask, aatype, 0)
+
+ pred_positions = pred_map_mul(pred_positions, mask)
+
+ return pred_positions
+
+
+def pred_map_mul(pred_positions, mask):
+ """Mask each coordinate component of the predicted positions."""
+ return [pred_positions[0] * mask,
+ pred_positions[1] * mask,
+ pred_positions[2] * mask]
+
+
+def rots_mul_vecs(m, v):
+ """Apply rotations 'm' to vectors 'v'."""
+
+ return [m[0] * v[0] + m[1] * v[1] + m[2] * v[2],
+ m[3] * v[0] + m[4] * v[1] + m[5] * v[2],
+ m[6] * v[0] + m[7] * v[1] + m[8] * v[2]]
+
+
+def rigids_mul_vecs(r, v):
+ """Apply rigid transforms 'r' to points 'v'."""
+
+ rots = rots_mul_vecs(r, v)
+ vecs_add_r = [rots[0] + r[9],
+ rots[1] + r[10],
+ rots[2] + r[11]]
+ return vecs_add_r
+
+
+def vecs_from_tensor(x):
+ """Converts a tensor of shape (..., 3) to a Vecs tuple; inverse of 'vecs_to_tensor'."""
+ return x[..., 0], x[..., 1], x[..., 2]
+
+
+def get_exp_atom_pos(atom_pos):
+ """Add a leading batch dimension to each component of an atom-position list."""
+ return [mnp.expand_dims(atom_pos[0], axis=0),
+ mnp.expand_dims(atom_pos[1], axis=0),
+ mnp.expand_dims(atom_pos[2], axis=0)
+ ]
+
+
+def to_tensor_new(quaternion, translation):
+ """Like to_tensor, but with the translation given as a list of 3 tensors."""
+ tr_new = [mnp.expand_dims(translation[0], axis=-1),
+ mnp.expand_dims(translation[1], axis=-1),
+ mnp.expand_dims(translation[2], axis=-1)]
+ return mnp.concatenate([quaternion, tr_new[0], tr_new[1], tr_new[2]], axis=-1)
+
+
+def quat_multiply_by_vec(quat, vec):
+ """Multiply a quaternion by a pure-vector quaternion."""
+
+ return mnp.sum(residue_constants.QUAT_MULTIPLY_BY_VEC * quat[..., :, None, None] * vec[..., None, :, None],
+ axis=(-3, -2))
+
+
+def rigids_mul_rots(xx, xy, xz, yx, yy, yz, zx, zy, zz, ones, zeros, cos_angles, sin_angles):
+ """Compose the given rotation components with per-angle rotations about the x-axis."""
+
+ c00 = xx * ones + xy * zeros + xz * zeros
+ c01 = yx * ones + yy * zeros + yz * zeros
+ c02 = zx * ones + zy * zeros + zz * zeros
+ c10 = xx * zeros + xy * cos_angles + xz * sin_angles
+ c11 = yx * zeros + yy * cos_angles + yz * sin_angles
+ c12 = zx * zeros + zy * cos_angles + zz * sin_angles
+ c20 = xx * zeros + xy * (-sin_angles) + xz * cos_angles
+ c21 = yx * zeros + yy * (-sin_angles) + yz * cos_angles
+ c22 = zx * zeros + zy * (-sin_angles) + zz * cos_angles
+ return c00, c10, c20, c01, c11, c21, c02, c12, c22
+
+
+def rigids_mul_rigids(a, b):
+ """Group composition of Rigids 'a' and 'b'."""
+
+ c00 = a[0] * b[0] + a[1] * b[3] + a[2] * b[6]
+ c01 = a[3] * b[0] + a[4] * b[3] + a[5] * b[6]
+ c02 = a[6] * b[0] + a[7] * b[3] + a[8] * b[6]
+
+ c10 = a[0] * b[1] + a[1] * b[4] + a[2] * b[7]
+ c11 = a[3] * b[1] + a[4] * b[4] + a[5] * b[7]
+ c12 = a[6] * b[1] + a[7] * b[4] + a[8] * b[7]
+
+ c20 = a[0] * b[2] + a[1] * b[5] + a[2] * b[8]
+ c21 = a[3] * b[2] + a[4] * b[5] + a[5] * b[8]
+ c22 = a[6] * b[2] + a[7] * b[5] + a[8] * b[8]
+
+ tr0 = a[0] * b[9] + a[1] * b[10] + a[2] * b[11]
+ tr1 = a[3] * b[9] + a[4] * b[10] + a[5] * b[11]
+ tr2 = a[6] * b[9] + a[7] * b[10] + a[8] * b[11]
+
+ new_tr0 = a[9] + tr0
+ new_tr1 = a[10] + tr1
+ new_tr2 = a[11] + tr2
+
+ return [c00, c10, c20, c01, c11, c21, c02, c12, c22, new_tr0, new_tr1, new_tr2]
+
+
+def rigits_concate_all(xall, x5, x6, x7):
+ """Concatenate chi2/chi3/chi4 frames onto the first five rigid-group frames."""
+ return [mnp.concatenate([xall[0][:, 0:5], x5[0][:, None], x6[0][:, None], x7[0][:, None]], axis=-1),
+ mnp.concatenate([xall[1][:, 0:5], x5[1][:, None], x6[1][:, None], x7[1][:, None]], axis=-1),
+ mnp.concatenate([xall[2][:, 0:5], x5[2][:, None], x6[2][:, None], x7[2][:, None]], axis=-1),
+ mnp.concatenate([xall[3][:, 0:5], x5[3][:, None], x6[3][:, None], x7[3][:, None]], axis=-1),
+ mnp.concatenate([xall[4][:, 0:5], x5[4][:, None], x6[4][:, None], x7[4][:, None]], axis=-1),
+ mnp.concatenate([xall[5][:, 0:5], x5[5][:, None], x6[5][:, None], x7[5][:, None]], axis=-1),
+ mnp.concatenate([xall[6][:, 0:5], x5[6][:, None], x6[6][:, None], x7[6][:, None]], axis=-1),
+ mnp.concatenate([xall[7][:, 0:5], x5[7][:, None], x6[7][:, None], x7[7][:, None]], axis=-1),
+ mnp.concatenate([xall[8][:, 0:5], x5[8][:, None], x6[8][:, None], x7[8][:, None]], axis=-1),
+ mnp.concatenate([xall[9][:, 0:5], x5[9][:, None], x6[9][:, None], x7[9][:, None]], axis=-1),
+ mnp.concatenate([xall[10][:, 0:5], x5[10][:, None], x6[10][:, None], x7[10][:, None]], axis=-1),
+ mnp.concatenate([xall[11][:, 0:5], x5[11][:, None], x6[11][:, None], x7[11][:, None]], axis=-1)
+ ]
+
+
+def reshape_back(backb):
+ """Add a trailing group axis to each component of the backbone transform."""
+ return [backb[0][:, None],
+ backb[1][:, None],
+ backb[2][:, None],
+ backb[3][:, None],
+ backb[4][:, None],
+ backb[5][:, None],
+ backb[6][:, None],
+ backb[7][:, None],
+ backb[8][:, None],
+ backb[9][:, None],
+ backb[10][:, None],
+ backb[11][:, None]
+ ]
+
+
+def l2_normalize(x, axis=-1):
+ """L2-normalize x along `axis` (no epsilon; the input must be nonzero)."""
+ return x / mnp.sqrt(mnp.sum(x ** 2, axis=axis, keepdims=True))
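+
+# Usage sketch (illustrative only):
+#
+#   l2_normalize(mnp.asarray([[3.0, 4.0]]))  # [[0.6, 0.8]]
+#   # Note: no epsilon is added, so an all-zero input would divide by zero.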
+
+
+def torsion_angles_to_frames(aatype, backb_to_global, torsion_angles_sin_cos, restype_rigid_group_default_frame):
+ """Compute rigid group frames from torsion angles."""
+
+ # Gather the default frames for all rigid groups.
+ m = P.Gather()(restype_rigid_group_default_frame, aatype, 0)
+
+ xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1, x1, y1, z1 = rigids_from_tensor4x4(m)
+
+ # Create the rotation matrices according to the given angles (each frame is
+ # defined such that its rotation is around the x-axis).
+ sin_angles = torsion_angles_sin_cos[..., 0]
+ cos_angles = torsion_angles_sin_cos[..., 1]
+
+ # insert zero rotation for backbone group.
+ num_residues, = aatype.shape
+ sin_angles = mnp.concatenate([mnp.zeros([num_residues, 1]), sin_angles], axis=-1)
+ cos_angles = mnp.concatenate([mnp.ones([num_residues, 1]), cos_angles], axis=-1)
+ zeros = mnp.zeros_like(sin_angles)
+ ones = mnp.ones_like(sin_angles)
+ # Apply rotations to the frames.
+ xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2 = rigids_mul_rots(xx1, xy1, xz1, yx1, yy1, yz1, zx1, zy1, zz1,
+ ones, zeros, cos_angles, sin_angles)
+ all_frames = [xx2, xy2, xz2, yx2, yy2, yz2, zx2, zy2, zz2, x1, y1, z1]
+ # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
+ # the previous frame. So chain them up accordingly.
+ chi2_frame_to_frame = [xx2[:, 5], xy2[:, 5], xz2[:, 5], yx2[:, 5], yy2[:, 5], yz2[:, 5], zx2[:, 5], zy2[:, 5],
+ zz2[:, 5], x1[:, 5], y1[:, 5], z1[:, 5]]
+ chi3_frame_to_frame = [xx2[:, 6], xy2[:, 6], xz2[:, 6], yx2[:, 6], yy2[:, 6], yz2[:, 6], zx2[:, 6], zy2[:, 6],
+ zz2[:, 6], x1[:, 6], y1[:, 6], z1[:, 6]]
+ chi4_frame_to_frame = [xx2[:, 7], xy2[:, 7], xz2[:, 7], yx2[:, 7], yy2[:, 7], yz2[:, 7], zx2[:, 7], zy2[:, 7],
+ zz2[:, 7], x1[:, 7], y1[:, 7], z1[:, 7]]
+ chi1_frame_to_backb = [xx2[:, 4], xy2[:, 4], xz2[:, 4], yx2[:, 4], yy2[:, 4], yz2[:, 4], zx2[:, 4], zy2[:, 4],
+ zz2[:, 4], x1[:, 4], y1[:, 4], z1[:, 4]]
+
+ chi2_frame_to_backb = rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame)
+ chi3_frame_to_backb = rigids_mul_rigids(chi2_frame_to_backb, chi3_frame_to_frame)
+ chi4_frame_to_backb = rigids_mul_rigids(chi3_frame_to_backb, chi4_frame_to_frame)
+
+ # Recombine them to a r3.Rigids with shape (N, 8).
+ all_frames_to_backb = rigits_concate_all(all_frames, chi2_frame_to_backb,
+ chi3_frame_to_backb, chi4_frame_to_backb)
+ backb_to_global_new = reshape_back(backb_to_global)
+ # Create the global frames.
+ # shape (N, 8)
+ all_frames_to_global = rigids_mul_rigids(backb_to_global_new, all_frames_to_backb)
+ return all_frames_to_global
+
+
+def map_atoms_to_global_func(all_frames, group_mask):
+ """Select each atom's rigid-group frame by weighting the frames with group_mask."""
+ return [mnp.sum(all_frames[0][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[1][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[2][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[3][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[4][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[5][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[6][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[7][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[8][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[9][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[10][:, None, :] * group_mask, axis=-1),
+ mnp.sum(all_frames[11][:, None, :] * group_mask, axis=-1)
+ ]
+
+
+def get_exp_frames(frames):
+ """Add a leading batch dimension to each component of a flat rigid transform."""
+ return [mnp.expand_dims(frames[0], axis=0),
+ mnp.expand_dims(frames[1], axis=0),
+ mnp.expand_dims(frames[2], axis=0),
+ mnp.expand_dims(frames[3], axis=0),
+ mnp.expand_dims(frames[4], axis=0),
+ mnp.expand_dims(frames[5], axis=0),
+ mnp.expand_dims(frames[6], axis=0),
+ mnp.expand_dims(frames[7], axis=0),
+ mnp.expand_dims(frames[8], axis=0),
+ mnp.expand_dims(frames[9], axis=0),
+ mnp.expand_dims(frames[10], axis=0),
+ mnp.expand_dims(frames[11], axis=0)
+ ]
+
+
+def vecs_to_tensor(v):
+ """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
+
+ return mnp.stack([v[0], v[1], v[2]], axis=-1)
+
+
+def atom14_to_atom37(atom14_data, residx_atom37_to_atom14, atom37_atom_exists, indices0):
+ """Convert atom14 to atom37 representation."""
+
+ seq_length = atom14_data.shape[0]
+ residx_atom37_to_atom14 = residx_atom37_to_atom14.reshape((seq_length, 37, 1))
+ new_indices = P.Concat(2)((indices0, residx_atom37_to_atom14))
+
+ atom37_data = P.GatherNd()(atom14_data, new_indices)
+
+ if len(atom14_data.shape) == 2:
+ atom37_data *= atom37_atom_exists
+ elif len(atom14_data.shape) == 3:
+ atom37_data *= atom37_atom_exists[:, :, None].astype(atom37_data.dtype)
+
+ return atom37_data
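+
+# Usage sketch (illustrative only): expanding predicted atom14 coordinates.
+# `indices0` here is a precomputed (seq_length, 37, 1) grid of residue indices
+# so that GatherNd can address (residue, atom14_slot) pairs.
+#
+#   atom37 = atom14_to_atom37(atom14_pred,              # (N, 14, 3)
+#                             residx_atom37_to_atom14,  # (N, 37)
+#                             atom37_atom_exists,       # (N, 37)
+#                             indices0)                 # result: (N, 37, 3)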
+
+
+def batch_apply_rot_to_vec(rot, vec, unstack=False):
+ """Multiply rotation matrix by a vector."""
+ if unstack:
+ x, y, z = vec[:, :, 0], vec[:, :, 1], vec[:, :, 2]
+ else:
+ x, y, z = vec
+ return [(rot[:, 0, 0, :] * x + rot[:, 0, 1, :] * y + rot[:, 0, 2, :] * z)[:, None, :],
+ (rot[:, 1, 0, :] * x + rot[:, 1, 1, :] * y + rot[:, 1, 2, :] * z)[:, None, :],
+ (rot[:, 2, 0, :] * x + rot[:, 2, 1, :] * y + rot[:, 2, 2, :] * z)[:, None, :]]
+
+
+def _batch_multiply(a, b):
+ """ batch multiply operation"""
+
+ x1 = mnp.concatenate(
+ [(a[:, 0, 0, :] * b[:, 0, 0, :] + a[:, 0, 1, :] * b[:, 1, 0, :] + a[:, 0, 2, :] * b[:, 2, 0, :])[:, None, :],
+ (a[:, 0, 0, :] * b[:, 0, 1, :] + a[:, 0, 1, :] * b[:, 1, 1, :] + a[:, 0, 2, :] * b[:, 2, 1, :])[:, None, :],
+ (a[:, 0, 0, :] * b[:, 0, 2, :] + a[:, 0, 1, :] * b[:, 1, 2, :] + a[:, 0, 2, :] * b[:, 2, 2, :])[:, None, :]],
+ axis=1)[:, None, :, :]
+ x2 = mnp.concatenate(
+ [(a[:, 1, 0, :] * b[:, 0, 0, :] + a[:, 1, 1, :] * b[:, 1, 0, :] + a[:, 1, 2, :] * b[:, 2, 0, :])[:, None, :],
+ (a[:, 1, 0, :] * b[:, 0, 1, :] + a[:, 1, 1, :] * b[:, 1, 1, :] + a[:, 1, 2, :] * b[:, 2, 1, :])[:, None, :],
+ (a[:, 1, 0, :] * b[:, 0, 2, :] + a[:, 1, 1, :] * b[:, 1, 2, :] + a[:, 1, 2, :] * b[:, 2, 2, :])[:, None, :]],
+ axis=1)[:, None, :, :]
+ x3 = mnp.concatenate(
+ [(a[:, 2, 0, :] * b[:, 0, 0, :] + a[:, 2, 1, :] * b[:, 1, 0, :] + a[:, 2, 2, :] * b[:, 2, 0, :])[:, None, :],
+ (a[:, 2, 0, :] * b[:, 0, 1, :] + a[:, 2, 1, :] * b[:, 1, 1, :] + a[:, 2, 2, :] * b[:, 2, 1, :])[:, None, :],
+ (a[:, 2, 0, :] * b[:, 0, 2, :] + a[:, 2, 1, :] * b[:, 1, 2, :] + a[:, 2, 2, :] * b[:, 2, 2, :])[:, None, :]],
+ axis=1)[:, None, :, :]
+ return mnp.concatenate([x1, x2, x3], axis=1)
+
+
+def batch_make_canonical_transform(n_xyz, ca_xyz, c_xyz):
+ """Returns translation and rotation matrices to canonicalize residue atoms.
+
+ Note that this method does not take care of symmetries. If you provide the
+ atom positions in the non-standard way, the N atom will end up not at
+ [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+ need to take care of such cases in your code.
+
+ Args:
+ n_xyz: An array of shape [batch, num_res, 3] of nitrogen xyz coordinates.
+ ca_xyz: An array of shape [batch, num_res, 3] of carbon alpha xyz coordinates.
+ c_xyz: An array of shape [batch, num_res, 3] of carbon xyz coordinates.
+
+ Returns:
+ A tuple (translation, rotation) where:
+ translation is an array of shape [batch, num_res, 3] defining the translation.
+ rotation is an array of shape [batch, num_res, 3, 3] defining the rotation.
+ After applying the translation and rotation to all atoms in a residue:
+ * All atoms will be shifted so that CA is at the origin,
+ * All atoms will be rotated so that C is at the x-axis,
+ * All atoms will be shifted so that N is in the xy plane.
+ """
+ # Place CA at the origin.
+ translation = -ca_xyz
+ n_xyz = n_xyz + translation
+ c_xyz = c_xyz + translation
+
+ # Place C on the x-axis.
+ c_x, c_y, c_z = c_xyz[:, :, 0], c_xyz[:, :, 1], c_xyz[:, :, 2]
+ # Rotate by angle c1 in the x-y plane (around the z-axis).
+ sin_c1 = -c_y / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
+ cos_c1 = c_x / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2)
+ zeros = mnp.zeros_like(sin_c1).astype("float32")
+ ones = mnp.ones_like(sin_c1).astype("float32")
+ # pylint: disable=bad-whitespace
+ c1_rot_matrix = mnp.concatenate(
+ [mnp.concatenate((cos_c1[:, None, ...], (-sin_c1)[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
+ mnp.concatenate((sin_c1[:, None, ...], cos_c1[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
+ mnp.concatenate((zeros[:, None, ...], zeros[:, None, ...], ones[:, None, ...]), axis=1)[:, None, :, :]],
+ axis=1)
+ # Rotate by angle c2 in the x-z plane (around the y-axis).
+ sin_c2 = c_z / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
+ cos_c2 = mnp.sqrt(c_x ** 2 + c_y ** 2) / mnp.sqrt(1e-20 + c_x ** 2 + c_y ** 2 + c_z ** 2)
+
+ c2_rot_matrix = mnp.concatenate(
+ [mnp.concatenate((cos_c2[:, None, ...], zeros[:, None, ...], sin_c2[:, None, ...]), axis=1)[:, None, :, :],
+ mnp.concatenate((zeros[:, None, ...], ones[:, None, ...], zeros[:, None, ...]), axis=1)[:, None, :, :],
+ mnp.concatenate(((-sin_c2)[:, None, ...], zeros[:, None, ...], cos_c2[:, None, ...]), axis=1)[:, None, :, :]],
+ axis=1)
+ c_rot_matrix = _batch_multiply(c2_rot_matrix, c1_rot_matrix)
+ n_xyz = mnp.transpose(mnp.concatenate(batch_apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True), axis=1), (0, 2, 1))
+ # Place N in the x-y plane.
+ _, n_y, n_z = n_xyz[:, :, 0], n_xyz[:, :, 1], n_xyz[:, :, 2]
+ # Rotate by angle alpha in the y-z plane (around the x-axis).
+ sin_n = -n_z / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
+ cos_n = n_y / mnp.sqrt(1e-20 + n_y ** 2 + n_z ** 2)
+ n_rot_matrix = mnp.concatenate(
+ [mnp.concatenate([ones[:, None, ...], zeros[:, None, ...], zeros[:, None, ...]], axis=1)[:, None, :, :],
+ mnp.concatenate([zeros[:, None, ...], cos_n[:, None, ...], (-sin_n)[:, None, ...]], axis=1)[:, None, :, :],
+ mnp.concatenate([zeros[:, None, ...], sin_n[:, None, ...], cos_n[:, None, ...]], axis=1)[:, None, :, :]],
+ axis=1)
+ return translation, mnp.transpose(_batch_multiply(n_rot_matrix, c_rot_matrix), [0, 3, 1, 2])
+
+
+def batch_make_transform_from_reference(n_xyz, ca_xyz, c_xyz):
+ """Returns rotation and translation matrices to convert from reference.
+
+ Note that this method does not take care of symmetries. If you provide the
+ atom positions in the non-standard way, the N atom will end up not at
+ [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+ need to take care of such cases in your code.
+
+ Args:
+ n_xyz: An array of shape [batch, num_res, 3] of nitrogen xyz coordinates.
+ ca_xyz: An array of shape [batch, num_res, 3] of carbon alpha xyz coordinates.
+ c_xyz: An array of shape [batch, num_res, 3] of carbon xyz coordinates.
+
+ Returns:
+ A tuple (rotation, translation) where:
+ rotation is an array of shape [batch, num_res, 3, 3] defining the rotation.
+ translation is an array of shape [batch, num_res, 3] defining the translation.
+ After applying the translation and rotation to the reference backbone,
+ the coordinates will approximately equal the input coordinates.
+
+ The order of translation and rotation differs from make_canonical_transform
+ because the rotation from this function should be applied before the
+ translation, unlike make_canonical_transform.
+ """
+ translation, rotation = batch_make_canonical_transform(n_xyz, ca_xyz, c_xyz)
+ return mnp.transpose(rotation, (0, 1, 3, 2)), -translation
+
+
+def batch_rot_to_quat(rot, unstack_inputs=False):
+ """Convert rotation matrix to quaternion.
+
+ Note: unlike the reference AlphaFold implementation, which solves an
+ eigenvalue problem here, this version takes the first column of the matrix
+ K as an approximate quaternion; the approximation is close for rotations
+ near the identity.
+
+ Args:
+ rot: rotation matrix (format controlled by unstack_inputs).
+ unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
+ otherwise the rotation matrix should be a list of lists of tensors.
+
+ Returns:
+ Quaternion as (..., 4) tensor.
+ """
+ if unstack_inputs:
+ rot = mnp.transpose(rot, [0, 3, 2, 1])
+
+ xx, xy, xz = rot[:, 0, 0, :], rot[:, 0, 1, :], rot[:, 0, 2, :]
+ yx, yy, yz = rot[:, 1, 0, :], rot[:, 1, 1, :], rot[:, 1, 2, :]
+ zx, zy, zz = rot[:, 2, 0, :], rot[:, 2, 1, :], rot[:, 2, 2, :]
+
+ k = mnp.stack((mnp.stack((xx + yy + zz, zy - yz, xz - zx, yx - xy), axis=-1),
+ mnp.stack((zy - yz, xx - yy - zz, xy + yx, xz + zx), axis=-1),
+ mnp.stack((xz - zx, xy + yx, yy - xx - zz, yz + zy), axis=-1),
+ mnp.stack((yx - xy, xz + zx, yz + zy, zz - xx - yy), axis=-1)), axis=-2)
+ k = (1. / 3.) * k
+
+ # Take the first column of K as an approximate (unnormalized) quaternion.
+ k = k[:, :, :, 0]
+ return k
+
+
+def batch_quat_affine(quaternion, translation, rotation=None, normalize=True, unstack_inputs=False):
+ """Create quat affine representations with a leading batch dimension."""
+ if unstack_inputs:
+ if rotation is not None:
+ rotation = mnp.transpose(rotation, [0, 3, 2, 1])
+ translation = mnp.moveaxis(translation, -1, 1) # Unstack.
+ if normalize and quaternion is not None:
+ quaternion = quaternion / mnp.norm(quaternion, axis=-1, keepdims=True)
+
+ return quaternion, rotation, translation
+
+
+def batch_apply_inverse_rot_to_vec(rot, vec):
+ """Multiply the inverse of a rotation matrix by a vector."""
+ # Inverse rotation is just transpose
+ return mnp.concatenate(
+ ((rot[:, 0, 0, :] * vec[:, 0] + rot[:, 1, 0, :] * vec[:, 1] + rot[:, 2, 0, :] * vec[:, 2])[:, None, ...],
+ (rot[:, 0, 1, :] * vec[:, 0] + rot[:, 1, 1, :] * vec[:, 1] + rot[:, 2, 1, :] * vec[:, 2])[:, None, ...],
+ (rot[:, 0, 2, :] * vec[:, 0] + rot[:, 1, 2, :] * vec[:, 1] + rot[:, 2, 2, :] * vec[:, 2])[:, None, ...]),
+ axis=1)
+
+
+def batch_invert_point(transformed_point, rotation, translation, extra_dims=0):
+ """Apply inverse of transformation to a point.
+
+ Args:
+ transformed_point: List of 3 tensors to apply affine
+ extra_dims: Number of dimensions at the end of the transformed_point
+ shape that are not present in the rotation and translation. The most
+ common use is rotation N points at once with extra_dims=1 for use in a
+ network.
+
+ Returns:
+ Transformed point after applying affine.
+ """
+ for _ in range(extra_dims):
+ rotation = mnp.expand_dims(rotation, axis=-1)
+ translation = mnp.expand_dims(translation, axis=-1)
+ rot_point = transformed_point - translation
+ return batch_apply_inverse_rot_to_vec(rotation, rot_point)
+
+
+def compute_confidence(predicted_lddt_logits):
+ """compute confidence"""
+
+ num_bins = predicted_lddt_logits.shape[-1]
+ bin_width = 1 / num_bins
+ start_n = bin_width / 2
+ plddt = compute_plddt(predicted_lddt_logits, start_n, bin_width)
+ confidence = np.mean(plddt)
+ return confidence
+
+
+def compute_plddt(logits, start_n, bin_width):
+ """Computes per-residue pLDDT from logits.
+
+ Args:
+ logits: [num_res, num_bins] output from the PredictedLDDTHead.
+ start_n: Center of the first pLDDT bin.
+ bin_width: Width of each pLDDT bin.
+
+ Returns:
+ plddt: [num_res] per-residue pLDDT.
+ """
+ bin_centers = np.arange(start=start_n, stop=1.0, step=bin_width)
+ probs = softmax(logits, axis=-1)
+ predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
+ return predicted_lddt_ca * 100
diff --git a/reproduce/AlphaFold2-Chinese/config/config.py b/reproduce/AlphaFold2-Chinese/config/config.py
new file mode 100644
index 0000000..de25fb9
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/config/config.py
@@ -0,0 +1,382 @@
+"""Model config."""
+
+import copy
+import ml_collections
+
+
+NUM_RES = 'num residues placeholder'
+NUM_MSA_SEQ = 'msa placeholder'
+NUM_EXTRA_SEQ = 'extra msa placeholder'
+NUM_TEMPLATES = 'num templates placeholder'
+
+
+def model_config(name: str) -> ml_collections.ConfigDict:
+ """Get the ConfigDict of a CASP14 model."""
+
+ if name not in CONFIG_DIFFS:
+ raise ValueError(f'Invalid model name {name}.')
+ cfg = copy.deepcopy(CONFIG)
+ cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
+ return cfg
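+
+# Usage sketch (illustrative only):
+#
+#   cfg = model_config('model_1')
+#   cfg.data.common.num_recycle                           # 3
+#   cfg.model.embeddings_and_evoformer.template.enabled   # True for model_1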
+
+
+CONFIG_DIFFS = {
+ 'model_1': {
+ # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
+ 'data.common.max_extra_msa': 5120,
+ 'data.common.reduce_msa_clusters_by_max_templates': True,
+ 'data.common.use_templates': True,
+ 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+ 'model.embeddings_and_evoformer.template.enabled': True
+ },
+ 'model_2': {
+ # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
+ 'data.common.reduce_msa_clusters_by_max_templates': True,
+ 'data.common.use_templates': True,
+ 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+ 'model.embeddings_and_evoformer.template.enabled': True
+ },
+ 'model_3': {
+ # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
+ 'data.common.max_extra_msa': 5120,
+ },
+ 'model_4': {
+ # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
+ 'data.common.max_extra_msa': 5120,
+ },
+ 'model_5': {
+ # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
+ },
+
+ # The following models are fine-tuned from the corresponding models above
+ # with an additional predicted_aligned_error head that can produce
+ # predicted TM-score (pTM) and predicted aligned errors.
+ 'model_1_ptm': {
+ 'data.common.max_extra_msa': 5120,
+ 'data.common.reduce_msa_clusters_by_max_templates': True,
+ 'data.common.use_templates': True,
+ 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+ 'model.embeddings_and_evoformer.template.enabled': True,
+ 'model.heads.predicted_aligned_error.weight': 0.1
+ },
+ 'model_2_ptm': {
+ 'data.common.reduce_msa_clusters_by_max_templates': True,
+ 'data.common.use_templates': True,
+ 'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+ 'model.embeddings_and_evoformer.template.enabled': True,
+ 'model.heads.predicted_aligned_error.weight': 0.1
+ },
+ 'model_3_ptm': {
+ 'data.common.max_extra_msa': 5120,
+ 'model.heads.predicted_aligned_error.weight': 0.1
+ },
+ 'model_4_ptm': {
+ 'data.common.max_extra_msa': 5120,
+ 'model.heads.predicted_aligned_error.weight': 0.1
+ },
+ 'model_5_ptm': {
+ 'model.heads.predicted_aligned_error.weight': 0.1
+ }
+}
+
+CONFIG = ml_collections.ConfigDict({
+ 'data': {
+ 'common': {
+ 'masked_msa': {
+ 'profile_prob': 0.1,
+ 'same_prob': 0.1,
+ 'uniform_prob': 0.1
+ },
+ 'max_extra_msa': 1024,
+ 'msa_cluster_features': True,
+ 'num_recycle': 3,
+ 'reduce_msa_clusters_by_max_templates': False,
+ 'resample_msa_in_recycling': True,
+ 'template_features': [
+ 'template_all_atom_positions', 'template_sum_probs',
+ 'template_aatype', 'template_all_atom_masks',
+ 'template_domain_names'
+ ],
+ 'unsupervised_features': [
+ 'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
+ 'num_alignments', 'seq_length', 'between_segment_residues',
+ 'deletion_matrix'
+ ],
+ 'use_templates': False,
+ },
+ 'eval': {
+ 'feat': {
+ 'aatype': [NUM_RES],
+ 'all_atom_mask': [NUM_RES, None],
+ 'all_atom_positions': [NUM_RES, None, None],
+ 'alt_chi_angles': [NUM_RES, None],
+ 'atom14_alt_gt_exists': [NUM_RES, None],
+ 'atom14_alt_gt_positions': [NUM_RES, None, None],
+ 'atom14_atom_exists': [NUM_RES, None],
+ 'atom14_atom_is_ambiguous': [NUM_RES, None],
+ 'atom14_gt_exists': [NUM_RES, None],
+ 'atom14_gt_positions': [NUM_RES, None, None],
+ 'atom37_atom_exists': [NUM_RES, None],
+ 'backbone_affine_mask': [NUM_RES],
+ 'backbone_affine_tensor': [NUM_RES, None],
+ 'bert_mask': [NUM_MSA_SEQ, NUM_RES],
+ 'chi_angles': [NUM_RES, None],
+ 'chi_mask': [NUM_RES, None],
+ 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
+ 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
+ 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
+ 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
+ 'extra_msa_row_mask': [NUM_EXTRA_SEQ],
+ 'is_distillation': [],
+ 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
+ 'msa_mask': [NUM_MSA_SEQ, NUM_RES],
+ 'msa_row_mask': [NUM_MSA_SEQ],
+ 'pseudo_beta': [NUM_RES, None],
+ 'pseudo_beta_mask': [NUM_RES],
+ 'random_crop_to_size_seed': [None],
+ 'residue_index': [NUM_RES],
+ 'residx_atom14_to_atom37': [NUM_RES, None],
+ 'residx_atom37_to_atom14': [NUM_RES, None],
+ 'resolution': [],
+ 'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
+ 'rigidgroups_group_exists': [NUM_RES, None],
+ 'rigidgroups_group_is_ambiguous': [NUM_RES, None],
+ 'rigidgroups_gt_exists': [NUM_RES, None],
+ 'rigidgroups_gt_frames': [NUM_RES, None, None],
+ 'seq_length': [],
+ 'seq_mask': [NUM_RES],
+ 'target_feat': [NUM_RES, None],
+ 'template_aatype': [NUM_TEMPLATES, NUM_RES],
+ 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
+ 'template_all_atom_positions': [
+ NUM_TEMPLATES, NUM_RES, None, None],
+ 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
+ 'template_backbone_affine_tensor': [
+ NUM_TEMPLATES, NUM_RES, None],
+ 'template_mask': [NUM_TEMPLATES],
+ 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
+ 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
+ 'template_sum_probs': [NUM_TEMPLATES, None],
+ 'true_msa': [NUM_MSA_SEQ, NUM_RES]
+ },
+ 'fixed_size': True,
+ 'subsample_templates': False, # We want top templates.
+ 'masked_msa_replace_fraction': 0.15,
+ 'max_msa_clusters': 512,
+ 'max_templates': 4,
+ 'num_ensemble': 1,
+ },
+ },
+ 'model': {
+ 'embeddings_and_evoformer': {
+ 'evoformer_num_block': 48,
+ 'evoformer': {
+ 'msa_row_attention_with_pair_bias': {
+ 'dropout_rate': 0.15,
+ 'gating': True,
+ 'num_head': 8,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'msa_column_attention': {
+ 'dropout_rate': 0.0,
+ 'gating': True,
+ 'num_head': 8,
+ 'orientation': 'per_column',
+ 'shared_dropout': True
+ },
+ 'msa_transition': {
+ 'dropout_rate': 0.0,
+ 'num_intermediate_factor': 4,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'outer_product_mean': {
+ 'chunk_size': 128,
+ 'dropout_rate': 0.0,
+ 'num_outer_channel': 32,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'triangle_attention_starting_node': {
+ 'dropout_rate': 0.25,
+ 'gating': True,
+ 'num_head': 4,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'triangle_attention_ending_node': {
+ 'dropout_rate': 0.25,
+ 'gating': True,
+ 'num_head': 4,
+ 'orientation': 'per_column',
+ 'shared_dropout': True
+ },
+ 'triangle_multiplication_outgoing': {
+ 'dropout_rate': 0.25,
+ 'equation': 'ikc,jkc->ijc',
+ 'num_intermediate_channel': 128,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'triangle_multiplication_incoming': {
+ 'dropout_rate': 0.25,
+ 'equation': 'kjc,kic->ijc',
+ 'num_intermediate_channel': 128,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'pair_transition': {
+ 'dropout_rate': 0.0,
+ 'num_intermediate_factor': 4,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ }
+ },
+ 'extra_msa_channel': 64,
+ 'extra_msa_stack_num_block': 4,
+ 'max_relative_feature': 32,
+ 'msa_channel': 256,
+ 'pair_channel': 128,
+ 'prev_pos': {
+ 'min_bin': 3.25,
+ 'max_bin': 20.75,
+ 'num_bins': 15
+ },
+ 'recycle_features': True,
+ 'recycle_pos': True,
+ 'seq_channel': 384,
+ 'template': {
+ 'attention': {
+ 'gating': False,
+ 'key_dim': 64,
+ 'num_head': 4,
+ 'value_dim': 64
+ },
+ 'dgram_features': {
+ 'min_bin': 3.25,
+ 'max_bin': 50.75,
+ 'num_bins': 39
+ },
+ 'embed_torsion_angles': False,
+ 'enabled': False,
+ 'template_pair_stack': {
+ 'num_block': 2,
+ 'triangle_attention_starting_node': {
+ 'dropout_rate': 0.25,
+ 'gating': True,
+ 'key_dim': 64,
+ 'num_head': 4,
+ 'orientation': 'per_row',
+ 'shared_dropout': True,
+ 'value_dim': 64
+ },
+ 'triangle_attention_ending_node': {
+ 'dropout_rate': 0.25,
+ 'gating': True,
+ 'key_dim': 64,
+ 'num_head': 4,
+ 'orientation': 'per_column',
+ 'shared_dropout': True,
+ 'value_dim': 64
+ },
+ 'triangle_multiplication_outgoing': {
+ 'dropout_rate': 0.25,
+ 'equation': 'ikc,jkc->ijc',
+ 'num_intermediate_channel': 64,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'triangle_multiplication_incoming': {
+ 'dropout_rate': 0.25,
+ 'equation': 'kjc,kic->ijc',
+ 'num_intermediate_channel': 64,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ },
+ 'pair_transition': {
+ 'dropout_rate': 0.0,
+ 'num_intermediate_factor': 2,
+ 'orientation': 'per_row',
+ 'shared_dropout': True
+ }
+ },
+ 'max_templates': 4,
+ 'subbatch_size': 128,
+ 'use_template_unit_vector': False,
+ }
+ },
+ 'heads': {
+ 'distogram': {
+ 'first_break': 2.3125,
+ 'last_break': 21.6875,
+ 'num_bins': 64,
+ 'weight': 0.3
+ },
+ 'predicted_aligned_error': {
+ # `num_bins - 1` bins uniformly space the
+ # [0, max_error_bin A] range.
+ # The final bin covers [max_error_bin A, +infty]
+ # 31A gives bins with 0.5A width.
+ 'max_error_bin': 31.,
+ 'num_bins': 64,
+ 'num_channels': 128,
+ 'filter_by_resolution': True,
+ 'min_resolution': 0.1,
+ 'max_resolution': 3.0,
+ 'weight': 0.0,
+ },
+ 'experimentally_resolved': {
+ 'filter_by_resolution': True,
+ 'max_resolution': 3.0,
+ 'min_resolution': 0.1,
+ 'weight': 0.01
+ },
+ 'structure_module': {
+ 'num_layer': 8,
+ 'fape': {
+ 'clamp_distance': 10.0,
+ 'clamp_type': 'relu',
+ 'loss_unit_distance': 10.0
+ },
+ 'angle_norm_weight': 0.01,
+ 'chi_weight': 0.5,
+ 'clash_overlap_tolerance': 1.5,
+ 'compute_in_graph_metrics': True,
+ 'dropout': 0.1,
+ 'num_channel': 384,
+ 'num_head': 12,
+ 'num_layer_in_transition': 3,
+ 'num_point_qk': 4,
+ 'num_point_v': 8,
+ 'num_scalar_qk': 16,
+ 'num_scalar_v': 16,
+ 'position_scale': 10.0,
+ 'sidechain': {
+ 'atom_clamp_distance': 10.0,
+ 'num_channel': 128,
+ 'num_residual_block': 2,
+ 'weight_frac': 0.5,
+ 'length_scale': 10.,
+ },
+ 'structural_violation_loss_weight': 1.0,
+ 'violation_tolerance_factor': 12.0,
+ 'weight': 1.0
+ },
+ 'predicted_lddt': {
+ 'filter_by_resolution': True,
+ 'max_resolution': 3.0,
+ 'min_resolution': 0.1,
+ 'num_bins': 50,
+ 'num_channels': 128,
+ 'weight': 0.01
+ },
+ 'masked_msa': {
+ 'num_output': 23,
+ 'weight': 2.0
+ },
+ },
+ 'num_recycle': 3,
+ 'resample_msa_in_recycling': True
+ },
+})
diff --git a/reproduce/AlphaFold2-Chinese/config/global_config.py b/reproduce/AlphaFold2-Chinese/config/global_config.py
new file mode 100644
index 0000000..77fb538
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/config/global_config.py
@@ -0,0 +1,341 @@
+"""Model config."""
+
+import copy
+import ml_collections
+
+def global_config(length: int) -> ml_collections.ConfigDict:
+ """Get the global config."""
+ if str(length) not in GLOBAL_CONFIG:
+ raise ValueError(f'Invalid padding sequence length {length}.')
+ cfg = copy.deepcopy(GLOBAL_CONFIG[str(length)])
+ return cfg
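+
+# Usage sketch (illustrative only): pick the padded sequence-length bucket.
+#
+#   cfg = global_config(256)
+#   cfg.seq_length                                      # 256
+#   cfg.evoformer_iteration.msa_transition.slice_num    # 0, i.e. no slicing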
+
+GLOBAL_CONFIG = ml_collections.ConfigDict({
+ "256": {
+ 'zero_init': True,
+ 'seq_length': 256,
+ 'extra_msa_length': 5120,
+ 'template_embedding': {
+ 'slice_num': 0,
+ },
+ 'template_pair_stack': {
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'extra_msa_stack': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 0,
+ },
+ 'msa_column_global_attention': {
+ 'slice_num': 0,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'evoformer_iteration': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 0,
+ },
+ 'msa_column_attention': {
+ 'slice_num': 0,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ },
+ "512": {
+ 'zero_init': True,
+ 'seq_length': 512,
+ 'extra_msa_length': 5120,
+ 'template_embedding': {
+ 'slice_num': 0,
+ },
+ 'template_pair_stack': {
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'extra_msa_stack': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 4,
+ },
+ 'msa_column_global_attention': {
+ 'slice_num': 0,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'evoformer_iteration': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 0,
+ },
+ 'msa_column_attention': {
+ 'slice_num': 0,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 0,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ },
+ "1024": {
+ 'zero_init': True,
+ 'seq_length': 1024,
+ 'extra_msa_length': 5120,
+ 'template_embedding': {
+ 'slice_num': 4,
+ },
+ 'template_pair_stack': {
+ 'triangle_attention_starting_node': {
+ 'slice_num': 4,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 4,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'extra_msa_stack': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 16,
+ },
+ 'msa_column_global_attention': {
+ 'slice_num': 4,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 4,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 4,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ 'evoformer_iteration': {
+ 'msa_transition': {
+ 'slice_num': 0,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 4,
+ },
+ 'msa_column_attention': {
+ 'slice_num': 4,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 0,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 4,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 4,
+ },
+ 'pair_transition': {
+ 'slice_num': 0,
+ },
+ },
+ },
+ "2048": {
+ 'zero_init': True,
+ 'seq_length': 2048,
+ 'extra_msa_length': 5120,
+ 'template_embedding': {
+ 'slice_num': 32,
+ },
+ 'template_pair_stack': {
+ 'triangle_attention_starting_node': {
+ 'slice_num': 32,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 32,
+ },
+ 'pair_transition': {
+ 'slice_num': 16,
+ },
+ },
+
+ 'extra_msa_stack': {
+ 'msa_transition': {
+ 'slice_num': 16,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 128,
+ },
+ 'msa_column_global_attention': {
+ 'slice_num': 32,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 16,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 32,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 32,
+ },
+ 'pair_transition': {
+ 'slice_num': 16,
+ },
+ },
+ 'evoformer_iteration': {
+ 'msa_transition': {
+ 'slice_num': 16,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 32,
+ },
+ 'msa_column_attention': {
+ 'slice_num': 32,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 16,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 32,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 32,
+ },
+ 'pair_transition': {
+ 'slice_num': 16,
+ },
+ },
+ },
+ "2304": {
+ 'zero_init': True,
+ 'seq_length': 2304,
+ 'extra_msa_length': 5120,
+ 'template_embedding': {
+ 'slice_num': 64,
+ },
+ 'template_pair_stack': {
+ 'triangle_attention_starting_node': {
+ 'slice_num': 64,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 64,
+ },
+ 'pair_transition': {
+ 'slice_num': 2,
+ },
+ },
+
+ 'extra_msa_stack': {
+ 'msa_transition': {
+ 'slice_num': 2,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 64,
+ },
+ 'msa_column_global_attention': {
+ 'slice_num': 64,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 8,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 64,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 64,
+ },
+ 'pair_transition': {
+ 'slice_num': 4,
+ },
+ },
+ 'evoformer_iteration': {
+ 'msa_transition': {
+ 'slice_num': 2,
+ },
+ 'msa_row_attention_with_pair_bias': {
+ 'slice_num': 64,
+ },
+ 'msa_column_attention': {
+ 'slice_num': 64,
+ },
+ 'outer_product_mean': {
+ 'slice_num': 8,
+ },
+ 'triangle_attention_starting_node': {
+ 'slice_num': 64,
+ },
+ 'triangle_attention_ending_node': {
+ 'slice_num': 64,
+ },
+ 'pair_transition': {
+ 'slice_num': 4,
+ },
+ },
+ },
+})
diff --git a/reproduce/AlphaFold2-Chinese/data/feature/data_transforms.py b/reproduce/AlphaFold2-Chinese/data/feature/data_transforms.py
new file mode 100644
index 0000000..acd8a7c
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/feature/data_transforms.py
@@ -0,0 +1,517 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""data transforms"""
+import numpy as np
+
+from commons import residue_constants
+
+NUM_RES = 'num residues placeholder'
+NUM_MSA_SEQ = 'msa placeholder'
+NUM_EXTRA_SEQ = 'extra msa placeholder'
+NUM_TEMPLATES = 'num templates placeholder'
+MS_MIN32 = -2147483648
+MS_MAX32 = 2147483647
+_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa']
+
+
+class SeedMaker:
+ """Return unique seeds."""
+
+ def __init__(self, initial_seed=0):
+ self.next_seed = initial_seed
+
+ def __call__(self):
+ i = self.next_seed
+ self.next_seed += 1
+ return i
+
+
+seed_maker = SeedMaker()
+
+
+def one_hot(depth, indices):
+ res = np.eye(depth)[indices.reshape(-1)]
+ return res.reshape(list(indices.shape) + [depth])
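+
+# Usage sketch (illustrative only):
+#
+#   one_hot(4, np.array([0, 2]))
+#   # -> [[1., 0., 0., 0.],
+#   #     [0., 0., 1., 0.]]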
+
+
+def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32):
+ np.random.seed(seed_maker_t)
+ return np.random.uniform(size=size, low=low, high=high)
+
+
+def curry1(f):
+ """Supply all arguments except the first."""
+
+ def fc(*args, **kwargs):
+ return lambda x: f(x, *args, **kwargs)
+
+ return fc
+
+
+@curry1
+def compose(x, fs):
+ for f in fs:
+ x = f(x)
+ return x
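+
+# Usage sketch (illustrative only): curried transforms chain into a pipeline.
+#
+#   pipeline = compose([
+#       sample_msa(max_seq=512, keep_extra=True),
+#       crop_extra_msa(1024),
+#   ])
+#   protein = pipeline(protein)  # `protein` is a feature dict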
+
+
+@curry1
+def randomly_replace_msa_with_unknown(protein, replace_proportion):
+ """Replace a proportion of the MSA with 'X'."""
+ msa_mask = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_proportion
+ x_idx = 20
+ gap_idx = 21
+ msa_mask = np.logical_and(msa_mask, protein['msa'] != gap_idx)
+ protein['msa'] = np.where(msa_mask, np.ones_like(protein['msa']) * x_idx, protein['msa'])
+
+ aatype_mask = np.random.uniform(size=shape_list(protein['aatype']), low=0, high=1) < replace_proportion
+ protein['aatype'] = np.where(aatype_mask, np.ones_like(protein['aatype']) * x_idx, protein['aatype'])
+
+ return protein
+
+
+@curry1
+def sample_msa(protein, max_seq, keep_extra):
+ """Sample MSA randomly, remaining sequences are stored as `extra_*`."""
+ num_seq = protein['msa'].shape[0]
+
+ shuffled = list(range(1, num_seq))
+ np.random.shuffle(shuffled)
+ shuffled.insert(0, 0)
+ index_order = np.array(shuffled, np.int32)
+ num_sel = min(max_seq, num_seq)
+
+ sel_seq = index_order[:num_sel]
+ not_sel_seq = index_order[num_sel:]
+ is_sel = num_seq - num_sel # number of sequences left over for the `extra_*` features
+
+ for k in _MSA_FEATURE_NAMES:
+ if k in protein:
+ if keep_extra and not is_sel:
+ new_shape = list(protein[k].shape)
+ new_shape[0] = 1
+ protein['extra_' + k] = np.zeros(new_shape)
+ elif keep_extra and is_sel:
+ protein['extra_' + k] = protein[k][not_sel_seq]
+ if k == 'msa':
+ protein['extra_msa'] = protein['extra_msa'].astype(np.int32)
+ protein[k] = protein[k][sel_seq]
+
+ return protein
+
+
+def shaped_categorical(probs):
+ ds = shape_list(probs)
+ num_classes = ds[-1]
+ probs = np.reshape(probs, (-1, num_classes))
+ nums = list(range(num_classes))
+ counts = []
+ for prob in probs:
+ counts.append(np.random.choice(nums, p=prob))
+ return np.reshape(np.array(counts, np.int32), ds[:-1])
+
+
+@curry1
+def make_masked_msa(protein, config, replace_fraction):
+ """Create data for BERT on raw MSA."""
+ random_aa = np.array([0.05] * 20 + [0., 0.], dtype=np.float32)
+
+ categorical_probs = config.uniform_prob * random_aa + config.profile_prob * protein['hhblits_profile'] + \
+ config.same_prob * one_hot(22, protein['msa'])
+
+ pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
+ pad_shapes[-1][1] = 1
+ mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
+ assert mask_prob >= 0.
+ categorical_probs = np.pad(categorical_probs, pad_shapes, constant_values=(mask_prob,))
+
+ mask_position = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_fraction
+
+ bert_msa = shaped_categorical(categorical_probs)
+ bert_msa = np.where(mask_position, bert_msa, protein['msa'])
+
+ protein['bert_mask'] = mask_position.astype(np.int32)
+ protein['true_msa'] = protein['msa']
+ protein['msa'] = bert_msa
+
+ return protein
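+
+# Worked example (comments only): with settings like
+# profile_prob = same_prob = uniform_prob = 0.1 (typical AlphaFold values,
+# assumed here for illustration), the padded mask-token probability is
+# 1 - 0.1 - 0.1 - 0.1 = 0.7, i.e. 70% of the positions chosen by
+# `replace_fraction` are replaced by the mask token (index 22).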
+
+
+@curry1
+def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
+ """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
+ weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0)
+
+ sample_one_hot = protein['msa_mask'][:, :, None] * one_hot(23, protein['msa'])
+ num_seq, num_res, _ = shape_list(sample_one_hot)
+
+ array_extra_msa_mask = protein['extra_msa_mask']
+ if array_extra_msa_mask.any():
+ extra_one_hot = protein['extra_msa_mask'][:, :, None] * one_hot(23, protein['extra_msa'])
+ extra_num_seq, _, _ = shape_list(extra_one_hot)
+
+ agreement = np.matmul(
+ np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
+ np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T)
+ protein['extra_cluster_assignment'] = np.argmax(agreement, axis=1)
+ else:
+ protein['extra_cluster_assignment'] = np.array([])
+
+ return protein
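+
+# Shape sketch (comments only): reshaping gives extra_one_hot (E, R*23) and
+# sample_one_hot (N, R*23); their product is agreement (E, N), and the argmax
+# over axis 1 assigns each of the E extra sequences to one of the N sampled
+# cluster centres.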
+
+
+@curry1
+def summarize_clusters(protein):
+ """Produce profile and deletion_matrix_mean within each cluster."""
+ num_seq = shape_list(protein['msa'])[0]
+
+ def csum(x):
+ result = []
+ for i in range(num_seq):
+ result.append(np.sum(x[np.where(protein['extra_cluster_assignment'] == i)], axis=0))
+ return np.array(result)
+
+ mask = protein['extra_msa_mask']
+ mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
+
+    # Sum the masked one-hot extra MSA within each cluster.
+    msa_sum = csum(mask[:, :, None] * one_hot(23, protein['extra_msa'].astype(np.int32)))
+ msa_sum += one_hot(23, protein['msa']) # Original sequence
+ protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
+
+ del msa_sum
+
+ del_sum = csum(mask * protein['extra_deletion_matrix'])
+ del_sum += protein['deletion_matrix'] # Original sequence
+ protein['cluster_deletion_mean'] = del_sum / mask_counts
+ del del_sum
+
+ return protein
+
+
+@curry1
+def crop_extra_msa(protein, max_extra_msa):
+ """MSA features are cropped so only `max_extra_msa` sequences are kept."""
+ if protein['extra_msa'].any():
+ num_seq = protein['extra_msa'].shape[0]
+ num_sel = np.minimum(max_extra_msa, num_seq)
+ shuffled = list(range(num_seq))
+ np.random.shuffle(shuffled)
+ select_indices = shuffled[:num_sel]
+ for k in _MSA_FEATURE_NAMES:
+ if 'extra_' + k in protein:
+ protein['extra_' + k] = protein['extra_' + k][select_indices]
+
+ return protein
+
+
+def delete_extra_msa(protein):
+ for k in _MSA_FEATURE_NAMES:
+ if 'extra_' + k in protein:
+ del protein['extra_' + k]
+ return protein
+
+
+@curry1
+def make_msa_feat(protein):
+ """Create and concatenate MSA features."""
+ has_break = np.clip(protein['between_segment_residues'].astype(np.float32), np.array(0), np.array(1))
+ aatype_1hot = one_hot(21, protein['aatype'])
+
+ target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot]
+
+ msa_1hot = one_hot(23, protein['msa'])
+ has_deletion = np.clip(protein['deletion_matrix'], np.array(0), np.array(1))
+ deletion_value = np.arctan(protein['deletion_matrix'] / 3.) * (2. / np.pi)
+
+ msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)]
+
+ if 'cluster_profile' in protein:
+ deletion_mean_value = (np.arctan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
+ msa_feat.extend([protein['cluster_profile'], np.expand_dims(deletion_mean_value, axis=-1)])
+
+ if 'extra_deletion_matrix' in protein:
+ protein['extra_has_deletion'] = np.clip(protein['extra_deletion_matrix'], np.array(0), np.array(1))
+ protein['extra_deletion_value'] = np.arctan(protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)
+
+ protein['msa_feat'] = np.concatenate(msa_feat, axis=-1)
+ protein['target_feat'] = np.concatenate(target_feat, axis=-1)
+ return protein
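+
+# Dimension check (comments only): msa_feat concatenates 23 (one-hot) + 1
+# (has_deletion) + 1 (deletion_value) + 23 (cluster_profile) + 1
+# (cluster_deletion_mean) = 49 channels when cluster features are present;
+# target_feat is 1 (has_break) + 21 (aatype one-hot) = 22 channels.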
+
+
+@curry1
+def select_feat(protein, feature_list):
+ return {k: v for k, v in protein.items() if k in feature_list}
+
+
+@curry1
+def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
+ subsample_templates=False):
+ """Crop randomly to `crop_size`, or keep as is if shorter than that."""
+ seq_length = protein['seq_length']
+ seq_length_int = int(seq_length)
+ if 'template_mask' in protein:
+ num_templates = np.array(shape_list(protein['template_mask'])[0], np.int32)
+ else:
+ num_templates = np.array(0, np.int32)
+ num_res_crop_size = np.minimum(seq_length, crop_size)
+ num_res_crop_size_int = int(num_res_crop_size)
+
+ if subsample_templates:
+        # Cast to int so the value can be used directly as a slice start below.
+        templates_crop_start = int(make_random_seed(size=(), seed_maker_t=seed_maker(), low=0,
+                                                    high=num_templates + 1))
+ else:
+ templates_crop_start = 0
+
+ num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates)
+ num_templates_crop_size_int = int(num_templates_crop_size)
+
+ num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed_maker(), low=0,
+ high=seq_length_int - num_res_crop_size_int + 1))
+
+ templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed_maker()))
+
+ for k, v in protein.items():
+ if k not in shape_schema or ('template' not in k and NUM_RES not in shape_schema[k]):
+ continue
+
+ if k.startswith('template') and subsample_templates:
+ v = v[templates_select_indices]
+
+ crop_sizes = []
+ crop_starts = []
+ for i, (dim_size, dim) in enumerate(zip(shape_schema[k], shape_list(v))):
+ is_num_res = (dim_size == NUM_RES)
+ if i == 0 and k.startswith('template'):
+ crop_size = num_templates_crop_size_int
+ crop_start = templates_crop_start
+ else:
+ crop_start = num_res_crop_start if is_num_res else 0
+ crop_size = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim))
+ crop_sizes.append(crop_size)
+ crop_starts.append(crop_start)
+ if len(v.shape) == 1:
+ protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0]]
+ elif len(v.shape) == 2:
+ protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1]]
+ elif len(v.shape) == 3:
+ protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1],
+ crop_starts[2]:crop_starts[2] + crop_sizes[2]]
+ else:
+ protein[k] = v[crop_starts[0]:crop_starts[0] + crop_sizes[0], crop_starts[1]:crop_starts[1] + crop_sizes[1],
+ crop_starts[2]:crop_starts[2] + crop_sizes[2], crop_starts[3]:crop_starts[3] + crop_sizes[3]]
+
+ protein['seq_length'] = num_res_crop_size
+ return protein
+
+
+@curry1
+def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
+ num_res, num_templates=0):
+ """Guess at the MSA and sequence dimensions to make fixed size."""
+
+ pad_size_map = {
+ NUM_RES: num_res,
+ NUM_MSA_SEQ: msa_cluster_size,
+ NUM_EXTRA_SEQ: extra_msa_size,
+ NUM_TEMPLATES: num_templates,
+ }
+
+ for k, v in protein.items():
+ if k == 'extra_cluster_assignment':
+ continue
+ shape = list(v.shape)
+ schema = shape_schema[k]
+ assert len(shape) == len(schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}'
+ pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)]
+ padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
+ if padding:
+ protein[k] = np.pad(v, padding)
+            protein[k] = protein[k].reshape(pad_size)
+
+ return protein
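+
+# Worked example (comments only): for a feature with schema
+# [NUM_MSA_SEQ, NUM_RES], msa_cluster_size=128 and num_res=256, an (87, 100)
+# array gets pad_size = [128, 256] and zero padding = [(0, 41), (0, 156)].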
+
+
+@curry1
+def crop_templates(protein, max_templates):
+ for k, v in protein.items():
+ if k.startswith('template_'):
+ protein[k] = v[:max_templates]
+ return protein
+
+
+def correct_msa_restypes(protein):
+ """Correct MSA restype to have the same order as residue_constants."""
+ new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = np.array(new_order_list, dtype=protein['msa'].dtype)
+ protein['msa'] = new_order[protein['msa']]
+
+ return protein
+
+
+@curry1
+def add_distillation_flag(protein, distillation):
+ protein['is_distillation'] = np.array(float(distillation), dtype=np.float32)
+ return protein
+
+
+def squeeze_features(protein):
+ """Remove singleton and repeated dimensions in protein features."""
+ protein['aatype'] = np.argmax(protein['aatype'], axis=-1)
+ for k in ['msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix',
+ 'resolution', 'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
+ if k in protein:
+ final_dim = shape_list(protein[k])[-1]
+ if isinstance(final_dim, int) and final_dim == 1:
+ protein[k] = np.squeeze(protein[k], axis=-1)
+
+ for k in ['seq_length', 'num_alignments']:
+ if k in protein:
+ protein[k] = protein[k][0] # Remove fake sequence dimension
+ return protein
+
+
+def cast_64bit_ints(protein):
+ for k, v in protein.items():
+ if v.dtype == np.int64:
+ protein[k] = v.astype(np.int32)
+ return protein
+
+
+def make_seq_mask(protein):
+ protein['seq_mask'] = np.ones(shape_list(protein['aatype']), dtype=np.float32)
+ return protein
+
+
+def make_msa_mask(protein):
+ """Mask features are all ones, but will later be zero-padded."""
+ protein['msa_mask'] = np.ones(shape_list(protein['msa']), dtype=np.float32)
+ protein['msa_row_mask'] = np.ones(shape_list(protein['msa'])[0], dtype=np.float32)
+ return protein
+
+
+def make_hhblits_profile(protein):
+ """Compute the HHblits MSA profile if not already present."""
+ if 'hhblits_profile' in protein:
+ return protein
+
+ protein['hhblits_profile'] = np.mean(one_hot(22, protein['msa']), axis=0)
+ return protein
+
+
+def make_random_crop_to_size_seed(protein):
+ """Random seed for cropping residues and templates."""
+ protein['random_crop_to_size_seed'] = np.array(make_random_seed([2], seed_maker_t=seed_maker()), np.int32)
+ return protein
+
+
+def fix_templates_aatype(protein):
+ """Fixes aatype encoding of templates."""
+ protein['template_aatype'] = np.argmax(protein['template_aatype'], axis=-1).astype(np.int32)
+ new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+ new_order = np.array(new_order_list, np.int32)
+ protein['template_aatype'] = new_order[protein['template_aatype']]
+ return protein
+
+
+def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
+ """Create pseudo beta features."""
+ is_gly = np.equal(aatype, residue_constants.restype_order['G'])
+ ca_idx = residue_constants.atom_order['CA']
+ cb_idx = residue_constants.atom_order['CB']
+ pseudo_beta = np.where(
+ np.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"),
+ all_atom_positions[..., ca_idx, :],
+ all_atom_positions[..., cb_idx, :])
+ if all_atom_masks is not None:
+ pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
+ pseudo_beta_mask = pseudo_beta_mask.astype(np.float32)
+ return pseudo_beta, pseudo_beta_mask
+ return pseudo_beta
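+
+# Shape sketch (comments only): aatype (N,) and all_atom_positions (N, 37, 3)
+# yield pseudo_beta (N, 3); glycine rows (which have no CB atom) take the CA
+# coordinates, every other residue takes CB.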
+
+
+@curry1
+def make_pseudo_beta(protein, prefix=''):
+ """Create pseudo-beta (alpha for glycine) position and mask."""
+ assert prefix in ['', 'template_']
+ protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
+ pseudo_beta_fn(
+ protein['template_aatype' if prefix else 'all_atom_aatype'],
+ protein[prefix + 'all_atom_positions'],
+ protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
+ return protein
+
+
+def make_atom14_masks(protein):
+ """Construct denser atom positions (14 dimensions instead of 37)."""
+ restype_atom14_to_atom37 = []
+ restype_atom37_to_atom14 = []
+ restype_atom14_mask = []
+
+ for rt in residue_constants.restypes:
+ atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]]
+
+ restype_atom14_to_atom37.append([(residue_constants.atom_order[name] if name else 0) for name in atom_names])
+
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+ restype_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+ for name in residue_constants.atom_types])
+
+ restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
+
+ restype_atom14_to_atom37.append([0] * 14)
+ restype_atom37_to_atom14.append([0] * 37)
+ restype_atom14_mask.append([0.] * 14)
+
+ restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, np.int32)
+ restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, np.int32)
+ restype_atom14_mask = np.array(restype_atom14_mask, np.float32)
+
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']]
+ residx_atom14_mask = restype_atom14_mask[protein['aatype']]
+
+ protein['atom14_atom_exists'] = residx_atom14_mask
+ protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
+
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']]
+ protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
+
+ restype_atom37_mask = np.zeros([21, 37], np.float32)
+ for restype, restype_letter in enumerate(residue_constants.restypes):
+ restype_name = residue_constants.restype_1to3[restype_letter]
+ atom_names = residue_constants.residue_atoms[restype_name]
+ for atom_name in atom_names:
+ atom_type = residue_constants.atom_order[atom_name]
+ restype_atom37_mask[restype, atom_type] = 1
+
+ residx_atom37_mask = restype_atom37_mask[protein['aatype']]
+ protein['atom37_atom_exists'] = residx_atom37_mask
+
+ return protein
+
+
+def shape_list(x):
+    """Return the shape of an array as a plain Python list."""
+    return list(np.asarray(x).shape)
diff --git a/reproduce/AlphaFold2-Chinese/data/feature/feature_extraction.py b/reproduce/AlphaFold2-Chinese/data/feature/feature_extraction.py
new file mode 100644
index 0000000..ad43d6b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/feature/feature_extraction.py
@@ -0,0 +1,294 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""feature extraction"""
+import copy
+
+import numpy as np
+
+from commons import residue_constants
+from data.feature import data_transforms
+
+NUM_RES = "num residues placeholder"
+NUM_SEQ = "length msa placeholder"
+NUM_TEMPLATES = "num templates placeholder"
+
+FEATURES = {
+ "aatype": (np.float32, [NUM_RES, 21]),
+ "between_segment_residues": (np.int64, [NUM_RES, 1]),
+ "deletion_matrix": (np.float32, [NUM_SEQ, NUM_RES, 1]),
+ "msa": (np.int64, [NUM_SEQ, NUM_RES, 1]),
+ "num_alignments": (np.int64, [NUM_RES, 1]),
+ "residue_index": (np.int64, [NUM_RES, 1]),
+ "seq_length": (np.int64, [NUM_RES, 1]),
+ "all_atom_positions": (np.float32, [NUM_RES, residue_constants.atom_type_num, 3]),
+ "all_atom_mask": (np.int64, [NUM_RES, residue_constants.atom_type_num]),
+ "resolution": (np.float32, [1]),
+ "template_domain_names": (str, [NUM_TEMPLATES]),
+ "template_sum_probs": (np.float32, [NUM_TEMPLATES, 1]),
+ "template_aatype": (np.float32, [NUM_TEMPLATES, NUM_RES, 22]),
+ "template_all_atom_positions": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3]),
+ "template_all_atom_masks": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1]),
+}
+
+
+def nonensembled_map_fns(data_config):
+ """Input pipeline functions which are not ensembled."""
+ common_cfg = data_config.common
+
+ map_fns = [
+ data_transforms.correct_msa_restypes,
+ data_transforms.add_distillation_flag(False),
+ data_transforms.cast_64bit_ints,
+ data_transforms.squeeze_features,
+ data_transforms.randomly_replace_msa_with_unknown(0.0),
+ data_transforms.make_seq_mask,
+ data_transforms.make_msa_mask,
+ data_transforms.make_hhblits_profile,
+ data_transforms.make_random_crop_to_size_seed,
+ ]
+ if common_cfg.use_templates:
+ map_fns.extend([data_transforms.fix_templates_aatype, data_transforms.make_pseudo_beta('template_')])
+ map_fns.extend([data_transforms.make_atom14_masks,])
+
+ return map_fns
+
+
+def ensembled_map_fns(data_config):
+ """Input pipeline functions that can be ensembled and averaged."""
+ common_cfg = data_config.common
+ eval_cfg = data_config.eval
+
+ map_fns = []
+
+ if common_cfg.reduce_msa_clusters_by_max_templates:
+ pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
+ else:
+ pad_msa_clusters = eval_cfg.max_msa_clusters
+
+ max_msa_clusters = pad_msa_clusters
+ max_extra_msa = common_cfg.max_extra_msa
+
+ map_fns.append(data_transforms.sample_msa(max_msa_clusters, keep_extra=True))
+
+ if 'masked_msa' in common_cfg:
+ map_fns.append(data_transforms.make_masked_msa(common_cfg.masked_msa, eval_cfg.masked_msa_replace_fraction))
+
+ if common_cfg.msa_cluster_features:
+ map_fns.append(data_transforms.nearest_neighbor_clusters())
+ map_fns.append(data_transforms.summarize_clusters())
+
+ if max_extra_msa:
+ map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
+ else:
+ map_fns.append(data_transforms.delete_extra_msa)
+
+ map_fns.append(data_transforms.make_msa_feat())
+
+ crop_feats = dict(eval_cfg.feat)
+
+ if eval_cfg.fixed_size:
+ map_fns.append(data_transforms.select_feat(list(crop_feats)))
+ map_fns.append(data_transforms.random_crop_to_size(
+ eval_cfg.crop_size,
+ eval_cfg.max_templates,
+ crop_feats,
+ eval_cfg.subsample_templates))
+ map_fns.append(data_transforms.make_fixed_size(
+ crop_feats,
+ pad_msa_clusters,
+ common_cfg.max_extra_msa,
+ eval_cfg.crop_size,
+ eval_cfg.max_templates))
+ else:
+ map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
+
+ return map_fns
+
+
+def process_arrays_from_config(arrays, data_config):
+ """Apply filters and maps to an existing dataset, based on the config."""
+
+ def wrap_ensemble_fn(data, i):
+ """Function to be mapped over the ensemble dimension."""
+ d = data.copy()
+ fns = ensembled_map_fns(data_config)
+ fn = data_transforms.compose(fns)
+ d['ensemble_index'] = i
+ return fn(d)
+
+ eval_cfg = data_config.eval
+ arrays = data_transforms.compose(nonensembled_map_fns(data_config))(arrays)
+ arrays_0 = wrap_ensemble_fn(arrays, np.array(0, np.int32))
+ num_ensemble = eval_cfg.num_ensemble
+ if data_config.common.resample_msa_in_recycling:
+ num_ensemble *= data_config.common.num_recycle + 1
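+        # e.g. num_ensemble = 1 with num_recycle = 3 gives 1 * (3 + 1) = 4
+        # independently resampled ensemble members, stacked on a new axis 0.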
+
+ result_array = {x: () for x in arrays_0.keys()}
+ if num_ensemble > 1:
+ for i in range(num_ensemble):
+ arrays_t = wrap_ensemble_fn(arrays, np.array(i, np.int32))
+ for key in arrays_0.keys():
+ result_array[key] += (arrays_t[key][None],)
+ for key in arrays_0.keys():
+ result_array[key] = np.concatenate(result_array[key], axis=0)
+ else:
+ result_array = {key: arrays_0[key][None] for key in arrays_0.keys()}
+ return result_array
+
+
+def feature_shape(feature_name,
+ num_residues,
+ msa_length,
+ num_templates,
+ features=None):
+ """Get the shape for the given feature name."""
+ features = features or FEATURES
+ if feature_name.endswith("_unnormalized"):
+ feature_name = feature_name[:-13]
+
+ unused_dtype, raw_sizes = features[feature_name]
+ replacements = {NUM_RES: num_residues,
+ NUM_SEQ: msa_length}
+
+ if num_templates is not None:
+ replacements[NUM_TEMPLATES] = num_templates
+
+ sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
+ for dimension in sizes:
+ if isinstance(dimension, str):
+ raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
+ feature_name, raw_sizes, replacements))
+ size_r = [int(x) for x in sizes]
+ return size_r
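+
+# Worked example (comments only): for the FEATURES entry
+# "msa": (np.int64, [NUM_SEQ, NUM_RES, 1]), the call
+#   feature_shape("msa", num_residues=100, msa_length=512, num_templates=4)
+# resolves the placeholders to [512, 100, 1].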
+
+
+def parse_reshape_logic(parsed_features, features, num_template, key=None):
+ """Transforms parsed serial features to the correct shape."""
+ num_residues = np.reshape(parsed_features['seq_length'].astype(np.int32), (-1,))[0]
+
+ if "num_alignments" in parsed_features:
+ num_msa = np.reshape(parsed_features["num_alignments"].astype(np.int32), (-1,))[0]
+ else:
+ num_msa = 0
+
+ if key is not None and "key" in features:
+ parsed_features["key"] = [key] # Expand dims from () to (1,).
+
+ for k, v in parsed_features.items():
+ new_shape = feature_shape(
+ feature_name=k,
+ num_residues=num_residues,
+ msa_length=num_msa,
+ num_templates=num_template,
+ features=features)
+ new_shape_size = 1
+ for dim in new_shape:
+ new_shape_size *= dim
+
+ if np.size(v) != new_shape_size:
+ raise ValueError("the size of feature {} ({}) could not be reshaped into {}"
+ "".format(k, np.size(v), new_shape))
+
+ if "template" not in k:
+ if np.size(v) <= 0:
+                raise ValueError("The feature {} must not be empty.".format(k))
+ parsed_features[k] = np.reshape(v, new_shape)
+
+ return parsed_features
+
+
+def _make_features_metadata(feature_names):
+ """Makes a feature name to type and shape mapping from a list of names."""
+ required_features = ["sequence", "domain_name", "template_domain_names"]
+ feature_names = list(set(feature_names) - set(required_features))
+
+ features_metadata = {name: FEATURES[name] for name in feature_names}
+ return features_metadata
+
+
+def np_to_array_dict(np_example, features):
+ """Creates dict of arrays."""
+ features_metadata = _make_features_metadata(features)
+ array_dict = {k: v for k, v in np_example.items() if k in features_metadata}
+ if "template_domain_names" in np_example:
+ num_template = len(np_example["template_domain_names"])
+ else:
+ num_template = 0
+
+ array_dict = parse_reshape_logic(array_dict, features_metadata, num_template)
+ array_dict['template_mask'] = np.ones([num_template], np.float32)
+ return array_dict
+
+
+def make_data_config(config, num_res):
+ """Makes a data config for the input pipeline."""
+ cfg = copy.deepcopy(config.data)
+
+ feature_names = cfg.common.unsupervised_features
+ if cfg.common.use_templates:
+ feature_names += cfg.common.template_features
+
+ with cfg.unlocked():
+ cfg.eval.crop_size = num_res
+
+ return cfg, feature_names
+
+
+def custom_padding(config, arrays, dims):
+ """Pad array to fixed size."""
+ step_size = config.seq_length
+
+ res_length = arrays[0].shape[dims[0]]
+ padding_size = step_size - res_length
+ for i, arr in enumerate(arrays):
+ if dims[i] == -1:
+ continue
+ extra_array_shape = list(arr.shape)
+ extra_array_shape[dims[i]] = padding_size
+ extra_array = np.zeros(extra_array_shape, dtype=arr.dtype)
+ arrays[i] = np.concatenate((arr, extra_array), axis=dims[i])
+ return arrays
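+
+# Worked example (comments only): with config.seq_length = 256 and a first
+# array of residue length 100, padding_size = 156; every array whose entry in
+# `dims` is not -1 is right-padded with zeros along that axis to length 256.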
+
+
+def process_features(raw_features, config, global_config):
+ """Preprocesses NumPy feature dict using pipeline."""
+ num_res = int(raw_features['seq_length'][0])
+ cfg, feature_names = make_data_config(config, num_res=num_res)
+
+ if 'deletion_matrix_int' in raw_features:
+ raw_features['deletion_matrix'] = (raw_features.pop('deletion_matrix_int').astype(np.float32))
+
+ array_dict = np_to_array_dict(np_example=raw_features, features=feature_names)
+
+ features = process_arrays_from_config(array_dict, cfg)
+ features = {k: v for k, v in features.items() if v.dtype != 'O'}
+
+ extra_msa_length = global_config.extra_msa_length
+ ori_res_length = features["target_feat"].shape[1]
+ aatype = features["aatype"]
+ residue_index = features["residue_index"]
+ for key in ["extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask"]:
+ features[key] = features[key][:, :extra_msa_length]
+ input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype',
+ 'template_all_atom_masks', 'template_all_atom_positions', 'template_mask',
+ 'template_pseudo_beta_mask', 'template_pseudo_beta', 'template_sum_probs',
+ 'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask',
+ 'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index']
+ arrays = [features[key] for key in input_keys]
+ dims = [1, 2, 2, 1, 1, 2, 2, 2, -1, 2, 2, -1, 2, 2, 2, 2, 1, 1, 1]
+ arrays = custom_padding(global_config, arrays, dims)
+    arrays = [array.astype(np.float16) if array.dtype in (np.float64, np.float32) else array for array in arrays]
+ return arrays, aatype, residue_index, ori_res_length
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/data_process.py b/reproduce/AlphaFold2-Chinese/data/tools/data_process.py
new file mode 100644
index 0000000..18c0bfc
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/data_process.py
@@ -0,0 +1,205 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''data process'''
+import os
+import hashlib
+import re
+import numpy as np
+from commons import residue_constants
+from data.tools.parsers import parse_fasta, parse_hhr, parse_a3m
+from data.tools.templates import TemplateHitFeaturizer
+from data.tools.data_tools import HHSearch
+
+def get_hash(x):
+ return hashlib.sha1(x.encode()).hexdigest()
+
+def run_mmseqs2(x, path, use_env=False):
+ '''run mmseqs2'''
+ a3m_files = [f"{path}/uniref.a3m"]
+    if use_env:
+        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
+
+ # gather a3m lines
+ a3m_lines = {}
+ for a3m_file in a3m_files:
+ update_m, m = True, None
+ for line in open(a3m_file, "r"):
+ if line:
+ if "\x00" in line:
+ line = line.replace("\x00", "")
+ update_m = True
+ if line.startswith(">") and update_m:
+
+ m = int(line[2:6].rstrip())
+ update_m = False
+ if m not in a3m_lines: a3m_lines[m] = []
+ a3m_lines[m].append(line)
+
+ # return results
+ a3m_lines = ["".join(a3m_lines[key]) for key in a3m_lines]
+
+ if isinstance(x, str):
+ return a3m_lines[0]
+ return a3m_lines
+
+
+def make_sequence_features(
+ sequence: str, description: str, num_res: int):
+ """Constructs a feature dict of sequence features."""
+ features = {'aatype': residue_constants.sequence_to_onehot(sequence=sequence,
+ mapping=residue_constants.restype_order_with_x,
+ map_unknown_to_x=True),
+ 'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
+ 'domain_name': np.array([description.encode('utf-8')], dtype=np.object_),
+ 'residue_index': np.array(range(num_res), dtype=np.int32),
+ 'seq_length': np.array([num_res] * num_res, dtype=np.int32),
+ 'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_)}
+ return features
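+
+# Usage sketch (comments only): for a 3-residue query,
+#   feats = make_sequence_features('ACD', 'query', 3)
+# gives feats['aatype'] of shape (3, 21), residue_index [0, 1, 2], and
+# seq_length [3, 3, 3].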
+
+
+def make_msa_features(
+ msas,
+ deletion_matrices):
+ """Constructs a feature dict of MSA features."""
+ if not msas:
+ raise ValueError('At least one MSA must be provided.')
+
+ int_msa = []
+ deletion_matrix = []
+ seen_sequences = set()
+ for msa_index, msa in enumerate(msas):
+ if not msa:
+ raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
+ for sequence_index, sequence in enumerate(msa):
+ if sequence in seen_sequences:
+ continue
+ seen_sequences.add(sequence)
+ int_msa.append(
+ [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
+ deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
+
+ num_res = len(msas[0][0])
+ num_alignments = len(int_msa)
+ features = {'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32),
+ 'msa': np.array(int_msa, dtype=np.int32),
+ 'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32)}
+ return features
+
+
+class DataPipeline:
+ """Runs the alignment tools and assembles the input features."""
+
+ def __init__(self,
+ hhsearch_binary_path: str,
+ pdb70_database_path: str,
+ template_featurizer: TemplateHitFeaturizer,
+ result_path,
+ use_env=False):
+ """Constructs a feature dict for a given FASTA file."""
+
+ self.hhsearch_pdb70_runner = HHSearch(
+ binary_path=hhsearch_binary_path,
+ databases=[pdb70_database_path])
+ self.template_featurizer = template_featurizer
+ self.result_path = result_path
+ self.use_env = use_env
+
+ def process(self, input_fasta_path):
+ """Runs alignment tools on the input sequence and creates features."""
+ with open(input_fasta_path) as f:
+ input_fasta_str = f.read()
+ input_seqs, input_descs = parse_fasta(input_fasta_str)
+        if len(input_seqs) != 1:
+            raise ValueError(f'Exactly one input sequence expected in {input_fasta_path}, '
+                             f'found {len(input_seqs)}.')
+ input_sequence = input_seqs[0]
+ input_description = input_descs[0]
+
+ num_res = len(input_sequence)
+
+ # mmseq2
+ sequence = input_sequence
+ sequence = re.sub("[^A-Z:/]", "", sequence.upper())
+ sequence = re.sub(":+", ":", sequence)
+ sequence = re.sub("/+", "/", sequence)
+ sequence = re.sub("^[:/]+", "", sequence)
+ sequence = re.sub("[:/]+$", "", sequence)
+ ori_sequence = sequence
+ seqs = ori_sequence.replace("/", "").split(":")
+
+ a3m_lines = run_mmseqs2(seqs, path=self.result_path, use_env=self.use_env)
+
+ hhsearch_result = self.hhsearch_pdb70_runner.query(a3m_lines[0])
+ hhsearch_hits = parse_hhr(hhsearch_result)
+
+ msas, deletion_matrices = parse_a3m(a3m_lines[0])
+ templates_result = self.template_featurizer.get_templates(
+ query_sequence=input_sequence,
+ query_pdb_code=None,
+ query_release_date=None,
+ hhr_hits=hhsearch_hits)
+ sequence_features = make_sequence_features(
+ sequence=input_sequence,
+ description=input_description,
+ num_res=num_res)
+        msa_features = make_msa_features(
+            msas=(msas,),
+            deletion_matrices=(deletion_matrices,))
+ return {**sequence_features, **msa_features, **templates_result.features}
+
+
+def data_process(seq_name, args):
+ """data_process"""
+
+ fasta_path = os.path.join(args.input_fasta_path, seq_name + '.fasta')
+    result_path = os.path.join(args.msa_result_path, "result_" + str(seq_name))
+ if args.database_envdb_dir:
+ use_env = True
+ command = "sh ./data/tools/msa_search.sh mmseqs " + fasta_path + " " + result_path + " " + \
+ args.database_dir + " " + "\"\"" + " " + args.database_envdb_dir + " \"1\" \"0\" \"1\""
+ else:
+ use_env = False
+ command = "sh ./data/tools/msa_search.sh mmseqs " + fasta_path + " " + result_path + " " + \
+ args.database_dir + " " + "\"\"" + " \"\"" + " \"0\" \"0\" \"1\""
+ print('start mmseqs2 MSA')
+ print('command: ', command)
+ os.system(command)
+ print('mmseqs2 MSA successful')
+ print('use_env: ', use_env)
+ hhsearch_binary_path = args.hhsearch_binary_path
+
+ pdb70_database_path = args.pdb70_database_path
+ template_mmcif_dir = args.template_mmcif_dir
+ max_template_date = args.max_template_date
+ kalign_binary_path = args.kalign_binary_path
+ obsolete_pdbs_path = args.obsolete_pdbs_path
+
+ template_featurizer = TemplateHitFeaturizer(
+ mmcif_dir=template_mmcif_dir,
+ max_template_date=max_template_date,
+ max_hits=20,
+ kalign_binary_path=kalign_binary_path,
+ release_dates_path=None,
+ obsolete_pdbs_path=obsolete_pdbs_path)
+
+    data_pipeline = DataPipeline(
+        hhsearch_binary_path=hhsearch_binary_path,
+ pdb70_database_path=pdb70_database_path,
+ template_featurizer=template_featurizer,
+ result_path=result_path,
+ use_env=use_env)
+
+ feature_dict = data_pipeline.process(fasta_path)
+ return feature_dict
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/data_tools.py b/reproduce/AlphaFold2-Chinese/data/tools/data_tools.py
new file mode 100644
index 0000000..3b8593d
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/data_tools.py
@@ -0,0 +1,428 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''data tools'''
+import glob
+import os
+import subprocess
+import contextlib
+import shutil
+import tempfile
+import time
+
+from typing import Any, Mapping, Optional, Sequence
+
+from absl import logging
+
+_HHBLITS_DEFAULT_P = 20
+_HHBLITS_DEFAULT_Z = 500
+
+
+def _to_a3m(sequences: Sequence[str]) -> str:
+ """Converts sequences to an a3m file."""
+ names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
+ a3m = []
+ for sequence, name in zip(sequences, names):
+ a3m.append(u'>' + name + u'\n')
+ a3m.append(sequence + u'\n')
+ return ''.join(a3m)
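+
+# Example (comments only):
+#   _to_a3m(['ACDE', 'ACDF'])  # -> '>sequence 1\nACDE\n>sequence 2\nACDF\n'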
+
+
+class Kalign:
+ """Python wrapper of the Kalign binary."""
+
+ def __init__(self, *, binary_path: str):
+ """Initializes the Python Kalign wrapper.
+
+ Args:
+ binary_path: The path to the Kalign binary.
+ """
+ self.binary_path = binary_path
+
+ def align(self, sequences: Sequence[str]) -> str:
+ """Aligns the sequences and returns the alignment in A3M string.
+
+ Args:
+ sequences: A list of query sequence strings. The sequences have to be at
+ least 6 residues long (Kalign requires this). Note that the order in
+        which you give the sequences might alter the output slightly, as a
+        different alignment tree might get constructed.
+
+ Returns:
+ A string with the alignment in a3m format.
+
+ Raises:
+ RuntimeError: If Kalign fails.
+ ValueError: If any of the sequences is less than 6 residues long.
+ """
+ logging.info('Aligning %d sequences', len(sequences))
+
+ for s in sequences:
+ if len(s) < 6:
+ raise ValueError('Kalign requires all sequences to be at least 6 '
+ 'residues long. Got %s (%d residues).' % (s, len(s)))
+
+ with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
+ input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
+ output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+
+ with open(input_fasta_path, 'w') as f:
+ f.write(_to_a3m(sequences))
+
+ cmd = [self.binary_path, '-i', input_fasta_path, '-o', output_a3m_path, '-format', 'fasta',]
+
+ logging.info('Launching subprocess "%s"', ' '.join(cmd))
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ with timing('Kalign query'):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+ logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', stdout.decode('utf-8'), stderr.decode('utf-8'))
+
+ if retcode:
+ raise RuntimeError(
+ 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (stdout.decode('utf-8'), stderr.decode('utf-8')))
+
+ with open(output_a3m_path) as f:
+ a3m = f.read()
+
+ return a3m
+
+
+@contextlib.contextmanager
+def tmpdir_manager(base_dir: Optional[str] = None):
+ """Context manager that deletes a temporary directory on exit."""
+ tmpdir = tempfile.mkdtemp(dir=base_dir)
+ try:
+ yield tmpdir
+ finally:
+ shutil.rmtree(tmpdir, ignore_errors=True)
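+
+# Usage sketch (comments only): the directory is removed on exit even if the
+# body raises.
+#   with tmpdir_manager(base_dir='/tmp') as tmp_dir:
+#       with open(os.path.join(tmp_dir, 'query.fasta'), 'w') as f:
+#           f.write('>query\nACDE\n')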
+
+
+@contextlib.contextmanager
+def timing(msg: str):
+ logging.info('Started %s', msg)
+ tic = time.time()
+ yield
+ toc = time.time()
+ logging.info('Finished %s in %.3f seconds', msg, toc - tic)
+
+
+class HHBlits:
+ """Python wrapper of the HHblits binary."""
+
+ def __init__(self,
+ *,
+ binary_path: str,
+ databases: Sequence[str],
+ n_cpu: int = 4,
+ n_iter: int = 3,
+ e_value: float = 0.001,
+ maxseq: int = 1_000_000,
+ realign_max: int = 100_000,
+ maxfilt: int = 100_000,
+ min_prefilter_hits: int = 1000,
+ all_seqs: bool = False,
+ alt: Optional[int] = None,
+ p: int = _HHBLITS_DEFAULT_P,
+ z: int = _HHBLITS_DEFAULT_Z):
+ """Initializes the Python HHblits wrapper.
+
+ Args:
+ binary_path: The path to the HHblits executable.
+ databases: A sequence of HHblits database paths. This should be the
+ common prefix for the database files (i.e. up to but not including
+ _hhm.ffindex etc.)
+ n_cpu: The number of CPUs to give HHblits.
+ n_iter: The number of HHblits iterations.
+ e_value: The E-value, see HHblits docs for more details.
+ maxseq: The maximum number of rows in an input alignment. Note that this
+ parameter is only supported in HHBlits version 3.1 and higher.
+ realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
+ maxfilt: Max number of hits allowed to pass the 2nd prefilter.
+ HHblits default: 20000.
+ min_prefilter_hits: Min number of hits to pass prefilter.
+ HHblits default: 100.
+ all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
+ HHblits default: False.
+ alt: Show up to this many alternative alignments.
+ p: Minimum Prob for a hit to be included in the output hhr file.
+ HHblits default: 20.
+ z: Hard cap on number of hits reported in the hhr file.
+ HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
+
+ Raises:
+ RuntimeError: If HHblits binary not found within the path.
+ """
+ self.binary_path = binary_path
+ self.databases = databases
+
+ for database_path in self.databases:
+ if not glob.glob(database_path + '_*'):
+ logging.error('Could not find HHBlits database %s', database_path)
+ raise ValueError(f'Could not find HHBlits database {database_path}')
+
+ self.n_cpu = n_cpu
+ self.n_iter = n_iter
+ self.e_value = e_value
+ self.maxseq = maxseq
+ self.realign_max = realign_max
+ self.maxfilt = maxfilt
+ self.min_prefilter_hits = min_prefilter_hits
+ self.all_seqs = all_seqs
+ self.alt = alt
+ self.p = p
+ self.z = z
+
+ def query(self, input_fasta_path: str) -> Mapping[str, Any]:
+ """Queries the database using HHblits."""
+ with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
+ a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+
+ db_cmd = []
+ for db_path in self.databases:
+ db_cmd.append('-d')
+ db_cmd.append(db_path)
+ cmd = [
+ self.binary_path,
+ '-i', input_fasta_path,
+ '-cpu', str(self.n_cpu),
+ '-oa3m', a3m_path,
+ '-o', '/dev/null',
+ '-n', str(self.n_iter),
+ '-e', str(self.e_value),
+ '-maxseq', str(self.maxseq),
+ '-realign_max', str(self.realign_max),
+ '-maxfilt', str(self.maxfilt),
+ '-min_prefilter_hits', str(self.min_prefilter_hits)]
+ if self.all_seqs:
+ cmd += ['-all']
+ if self.alt:
+ cmd += ['-alt', str(self.alt)]
+ if self.p != _HHBLITS_DEFAULT_P:
+ cmd += ['-p', str(self.p)]
+ if self.z != _HHBLITS_DEFAULT_Z:
+ cmd += ['-Z', str(self.z)]
+ cmd += db_cmd
+
+ logging.info('Launching subprocess "%s"', ' '.join(cmd))
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+ with timing('HHblits query'):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+
+ if retcode:
+ # Logs have a 15k character limit, so log HHblits error line by
+ # line.
+ logging.error('HHblits failed. HHblits stderr begin:')
+ for error_line in stderr.decode('utf-8').splitlines():
+ if error_line.strip():
+ logging.error(error_line.strip())
+ logging.error('HHblits stderr end')
+ raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
+ stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
+
+ with open(a3m_path) as f:
+ a3m = f.read()
+
+ raw_output = dict(
+ a3m=a3m,
+ output=stdout,
+ stderr=stderr,
+ n_iter=self.n_iter,
+ e_value=self.e_value)
+ return raw_output
+
+
+class HHSearch:
+ """Python wrapper of the HHsearch binary."""
+
+ def __init__(self,
+ *,
+ binary_path: str,
+ databases: Sequence[str],
+ maxseq: int = 1_000_000):
+ """Initializes the Python HHsearch wrapper.
+
+ Args:
+ binary_path: The path to the HHsearch executable.
+ databases: A sequence of HHsearch database paths. This should be the
+ common prefix for the database files (i.e. up to but not including
+ _hhm.ffindex etc.)
+ maxseq: The maximum number of rows in an input alignment. Note that this
+ parameter is only supported in HHBlits version 3.1 and higher.
+
+ Raises:
+ RuntimeError: If HHsearch binary not found within the path.
+ """
+ self.binary_path = binary_path
+ self.databases = databases
+ self.maxseq = maxseq
+
+ for database_path in self.databases:
+ if not glob.glob(database_path + '_*'):
+ logging.error(
+ 'Could not find HHsearch database %s',
+ database_path)
+ raise ValueError(
+ f'Could not find HHsearch database {database_path}')
+
+ def query(self, a3m: str) -> str:
+ """Queries the database using HHsearch using a given a3m."""
+ with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
+ input_path = os.path.join(query_tmp_dir, 'query.a3m')
+ hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
+ with open(input_path, 'w') as f:
+ f.write(a3m)
+
+ db_cmd = []
+ for db_path in self.databases:
+ db_cmd.append('-d')
+ db_cmd.append(db_path)
+ cmd = [self.binary_path,
+ '-i', input_path,
+ '-o', hhr_path,
+ '-maxseq', str(self.maxseq),
+ '-cpu', '8',] + db_cmd
+
+ logging.info('Launching subprocess "%s"', ' '.join(cmd))
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ with timing('HHsearch query'):
+ stdout, stderr = process.communicate()
+ retcode = process.wait()
+ if retcode:
+ # Stderr is truncated to prevent proto size errors in Beam.
+ raise RuntimeError('HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
+ stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
+ with open(hhr_path) as f:
+ hhr = f.read()
+ return hhr
+
+
+class Jackhmmer:
+ """Python wrapper of the Jackhmmer binary."""
+
+ def __init__(self,
+ *,
+ binary_path: str,
+ database_path: str,
+ n_cpu: int = 8,
+ n_iter: int = 1,
+ e_value: float = 0.0001,
+ z_value: Optional[int] = None,
+ get_tblout: bool = False,
+ filter_f1: float = 0.0005,
+ filter_f2: float = 0.00005,
+ filter_f3: float = 0.0000005,
+ incdom_e: Optional[float] = None,
+ dom_e: Optional[float] = None):
+ """Initializes the Python Jackhmmer wrapper.
+
+ Args:
+ binary_path: The path to the jackhmmer executable.
+ database_path: The path to the jackhmmer database (FASTA format).
+ n_cpu: The number of CPUs to give Jackhmmer.
+ n_iter: The number of Jackhmmer iterations.
+ e_value: The E-value, see Jackhmmer docs for more details.
+ z_value: The Z-value, see Jackhmmer docs for more details.
+ get_tblout: Whether to save tblout string.
+ filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
+ filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
+ filter_f3: Forward pre-filter, set to >1.0 to turn off.
+ incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
+ round.
+ dom_e: Domain e-value criteria for inclusion in tblout.
+ """
+ self.binary_path = binary_path
+ self.database_path = database_path
+
+ if not os.path.exists(self.database_path):
+ logging.error(
+ 'Could not find Jackhmmer database %s',
+ database_path)
+ raise ValueError(
+ f'Could not find Jackhmmer database {database_path}')
+
+ self.n_cpu = n_cpu
+ self.n_iter = n_iter
+ self.e_value = e_value
+ self.z_value = z_value
+ self.filter_f1 = filter_f1
+ self.filter_f2 = filter_f2
+ self.filter_f3 = filter_f3
+ self.incdom_e = incdom_e
+ self.dom_e = dom_e
+ self.get_tblout = get_tblout
+
+ def query(self, input_fasta_path: str) -> Mapping[str, Any]:
+ """Queries the database using Jackhmmer."""
+ with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
+ sto_path = os.path.join(query_tmp_dir, 'output.sto')
+
+ # The F1/F2/F3 are the expected proportion to pass each of the filtering
+ # stages (which get progressively more expensive), reducing these
+            # speeds up the pipeline at the expense of sensitivity. They are
+ # currently set very low to make querying Mgnify run in a reasonable
+ # amount of time.
+ cmd_flags = [
+ # Don't pollute stdout with Jackhmmer output.
+ '-o', '/dev/null',
+ '-A', sto_path,
+ '--noali',
+ '--F1', str(self.filter_f1),
+ '--F2', str(self.filter_f2),
+ '--F3', str(self.filter_f3),
+ '--incE', str(self.e_value),
+ # Report only sequences with E-values <= x in per-sequence
+ # output.
+ '-E', str(self.e_value),
+ '--cpu', str(self.n_cpu),
+ '-N', str(self.n_iter)
+ ]
+ if self.get_tblout:
+ tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
+ cmd_flags.extend(['--tblout', tblout_path])
+
+ if self.z_value:
+ cmd_flags.extend(['-Z', str(self.z_value)])
+
+ if self.dom_e is not None:
+ cmd_flags.extend(['--domE', str(self.dom_e)])
+
+ if self.incdom_e is not None:
+ cmd_flags.extend(['--incdomE', str(self.incdom_e)])
+
+ cmd = [self.binary_path] + cmd_flags + [input_fasta_path, self.database_path]
+
+ logging.info('Launching subprocess "%s"', ' '.join(cmd))
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ with timing(f'Jackhmmer ({os.path.basename(self.database_path)}) query'):
+ _, stderr = process.communicate()
+ retcode = process.wait()
+
+ if retcode:
+ raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
+
+ # Get e-values for each target name
+ tbl = ''
+ if self.get_tblout:
+ with open(tblout_path) as f:
+ tbl = f.read()
+
+ with open(sto_path) as f:
+ sto = f.read()
+
+ raw_output = dict(sto=sto, tbl=tbl, stderr=stderr, n_iter=self.n_iter, e_value=self.e_value)
+ return raw_output
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/mmcif_parsing.py b/reproduce/AlphaFold2-Chinese/data/tools/mmcif_parsing.py
new file mode 100644
index 0000000..f843704
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/mmcif_parsing.py
@@ -0,0 +1,393 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parses the mmCIF file format."""
+import collections
+import io
+import dataclasses
+from typing import Any, Mapping, Optional, Sequence, Tuple
+
+from absl import logging
+from Bio import PDB
+from Bio.Data import SCOPData
+
+
+# Type aliases:
+ChainId = str
+PdbHeader = Mapping[str, Any]
+PDBSTRUCTURE = PDB.Structure.Structure
+SeqRes = str
+MmCIFDict = Mapping[str, Sequence[str]]
+
+
+@dataclasses.dataclass(frozen=True)
+class Monomer:
+ id: str
+ num: int
+
+
+# Note - mmCIF format provides no guarantees on the type of author-assigned
+# sequence numbers. They need not be integers.
+@dataclasses.dataclass(frozen=True)
+class AtomSite:
+ residue_name: str
+ author_chain_id: str
+ mmcif_chain_id: str
+ author_seq_num: str
+ mmcif_seq_num: int
+ insertion_code: str
+ hetatm_atom: str
+ model_num: int
+
+
+# Used to map SEQRES index to a residue in the structure.
+@dataclasses.dataclass(frozen=True)
+class ResiduePosition:
+ chain_id: str
+ residue_number: int
+ insertion_code: str
+
+
+@dataclasses.dataclass(frozen=True)
+class ResidueAtPosition:
+ position: Optional[ResiduePosition]
+ name: str
+ is_missing: bool
+ hetflag: str
+
+
+@dataclasses.dataclass(frozen=True)
+class MmcifObject:
+ """Representation of a parsed mmCIF file.
+
+ Contains:
+ file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
+ files being processed.
+ header: Biopython header.
+ structure: Biopython structure.
+ chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
+ {'A': 'ABCDEFG'}
+ seqres_to_structure: Dict; for each chain_id contains a mapping between
+ SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
+ 1: ResidueAtPosition,
+ ...}}
+ raw_string: The raw string used to construct the MmcifObject.
+ """
+ file_id: str
+ header: PdbHeader
+ structure: PDBSTRUCTURE
+ chain_to_seqres: Mapping[ChainId, SeqRes]
+ seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
+ raw_string: Any
+
+
+@dataclasses.dataclass(frozen=True)
+class ParsingResult:
+ """Returned by the parse function.
+
+ Contains:
+ mmcif_object: A MmcifObject, may be None if no chain could be successfully
+ parsed.
+ errors: A dict mapping (file_id, chain_id) to any exception generated.
+ """
+ mmcif_object: Optional[MmcifObject]
+ errors: Mapping[Tuple[str, str], Any]
+
+
+class ParseError(Exception):
+ """An error indicating that an mmCIF file could not be parsed."""
+
+
+def mmcif_loop_to_list(prefix: str,
+ parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
+ """Extracts loop associated with a prefix from mmCIF data as a list.
+
+ Reference for loop_ in mmCIF:
+ http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
+
+ Args:
+ prefix: Prefix shared by each of the data items in the loop.
+ e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
+ _entity_poly_seq.mon_id. Should include the trailing period.
+ parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
+ parser.
+
+ Returns:
+ Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
+ """
+ cols = []
+ data = []
+ for key, value in parsed_info.items():
+ if key.startswith(prefix):
+ cols.append(key)
+ data.append(value)
+
+ assert all([len(xs) == len(data[0]) for xs in data]), (
+ 'mmCIF error: Not all loops are the same length: %s' % cols)
+
+ return [dict(zip(cols, xs)) for xs in zip(*data)]
+
+
+def mmcif_loop_to_dict(prefix: str,
+ index: str,
+ parsed_info: MmCIFDict,
+ ) -> Mapping[str, Mapping[str, str]]:
+ """Extracts loop associated with a prefix from mmCIF data as a dictionary.
+
+ Args:
+ prefix: Prefix shared by each of the data items in the loop.
+ e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
+ _entity_poly_seq.mon_id. Should include the trailing period.
+ index: Which item of loop data should serve as the key.
+ parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
+ parser.
+
+ Returns:
+ Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
+ indexed by the index column.
+ """
+ entries = mmcif_loop_to_list(prefix, parsed_info)
+ return {entry[index]: entry for entry in entries}
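+
+# Worked example (comments only): for parsed_info containing
+#   '_entity_poly_seq.entity_id': ['1', '1'],
+#   '_entity_poly_seq.num':       ['1', '2'],
+#   '_entity_poly_seq.mon_id':    ['MET', 'ALA'],
+# mmcif_loop_to_list('_entity_poly_seq.', parsed_info) returns one dict per
+# row, and mmcif_loop_to_dict(..., index='_entity_poly_seq.num', ...) keys
+# those dicts by residue number: {'1': {...'MET'...}, '2': {...'ALA'...}}.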
+
+
+def parse(*,
+ file_id: str,
+ mmcif_string: str,
+ catch_all_errors: bool = True) -> ParsingResult:
+ """Entry point, parses an mmcif_string.
+
+ Args:
+ file_id: A string identifier for this file. Should be unique within the
+ collection of files being processed.
+ mmcif_string: Contents of an mmCIF file.
+ catch_all_errors: If True, all exceptions are caught and error messages are
+ returned as part of the ParsingResult. If False exceptions will be allowed
+ to propagate.
+
+ Returns:
+ A ParsingResult.
+ """
+ errors = {}
+ try:
+ parser = PDB.MMCIFParser(QUIET=True)
+ handle = io.StringIO(mmcif_string)
+ full_structure = parser.get_structure('', handle)
+ first_model_structure = _get_first_model(full_structure)
+ # Extract the _mmcif_dict from the parser, which contains useful fields not
+ # reflected in the Biopython structure.
+ parsed_info = parser._mmcif_dict # pylint:disable=protected-access
+
+ # Ensure all values are lists, even if singletons.
+ for key, value in parsed_info.items():
+ if not isinstance(value, list):
+ parsed_info[key] = [value]
+
+ header = _get_header(parsed_info)
+
+ # Determine the protein chains, and their start numbers according to the
+ # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
+ valid_chains = _get_protein_chains(parsed_info=parsed_info)
+ if not valid_chains:
+ return ParsingResult(
+ None, {(file_id, ''): 'No protein chains found in this file.'})
+ seq_start_num = {chain_id: min([monomer.num for monomer in seq])
+ for chain_id, seq in valid_chains.items()}
+
+ # Loop over the atoms for which we have coordinates. Populate two mappings:
+ # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
+ # the authors / Biopython).
+ # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
+ mmcif_to_author_chain_id = {}
+ seq_to_structure_mappings = {}
+ for atom in _get_atom_site_list(parsed_info):
+ if atom.model_num != '1':
+ # We only process the first model at the moment.
+ continue
+
+ mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
+
+ if atom.mmcif_chain_id in valid_chains:
+ hetflag = ' '
+ if atom.hetatm_atom == 'HETATM':
+ # Water atoms are assigned a special hetflag of W in Biopython. We
+ # need to do the same, so that this hetflag can be used to fetch
+ # a residue from the Biopython structure by id.
+ if atom.residue_name in ('HOH', 'WAT'):
+ hetflag = 'W'
+ else:
+ hetflag = 'H_' + atom.residue_name
+ insertion_code = atom.insertion_code
+ if not _is_set(atom.insertion_code):
+ insertion_code = ' '
+ position = ResiduePosition(
+ chain_id=atom.author_chain_id, residue_number=int(
+ atom.author_seq_num), insertion_code=insertion_code)
+ seq_idx = int(atom.mmcif_seq_num) - \
+ seq_start_num[atom.mmcif_chain_id]
+ current = seq_to_structure_mappings.get(
+ atom.author_chain_id, {})
+ current[seq_idx] = ResidueAtPosition(position=position,
+ name=atom.residue_name,
+ is_missing=False,
+ hetflag=hetflag)
+ seq_to_structure_mappings[atom.author_chain_id] = current
+
+ # Add missing residue information to seq_to_structure_mappings.
+ for chain_id, seq_info in valid_chains.items():
+ author_chain = mmcif_to_author_chain_id[chain_id]
+ current_mapping = seq_to_structure_mappings[author_chain]
+ for idx, monomer in enumerate(seq_info):
+ if idx not in current_mapping:
+ current_mapping[idx] = ResidueAtPosition(position=None,
+ name=monomer.id,
+ is_missing=True,
+ hetflag=' ')
+
+ author_chain_to_sequence = {}
+ for chain_id, seq_info in valid_chains.items():
+ author_chain = mmcif_to_author_chain_id[chain_id]
+ seq = []
+ for monomer in seq_info:
+ code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
+ seq.append(code if len(code) == 1 else 'X')
+ seq = ''.join(seq)
+ author_chain_to_sequence[author_chain] = seq
+
+ mmcif_object = MmcifObject(
+ file_id=file_id,
+ header=header,
+ structure=first_model_structure,
+ chain_to_seqres=author_chain_to_sequence,
+ seqres_to_structure=seq_to_structure_mappings,
+ raw_string=parsed_info)
+
+ return ParsingResult(mmcif_object=mmcif_object, errors=errors)
+ except Exception as e: # pylint:disable=broad-except
+ errors[(file_id, '')] = e
+ if not catch_all_errors:
+ raise
+ return ParsingResult(mmcif_object=None, errors=errors)
+
+
+def _get_first_model(structure: PDBSTRUCTURE) -> PDBSTRUCTURE:
+ """Returns the first model in a Biopython structure."""
+ return next(structure.get_models())
+
+
+_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
+
+
+def get_release_date(parsed_info: MmCIFDict) -> str:
+ """Returns the oldest revision date."""
+ revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
+ return min(revision_dates)
+
+
+def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
+ """Returns a basic header containing method, release date and resolution."""
+ header = {}
+
+ experiments = mmcif_loop_to_list('_exptl.', parsed_info)
+ header['structure_method'] = ','.join([
+ experiment['_exptl.method'].lower() for experiment in experiments])
+
+ # Note: The release_date here corresponds to the oldest revision. We prefer to
+ # use this for dataset filtering over the deposition_date.
+ if '_pdbx_audit_revision_history.revision_date' in parsed_info:
+ header['release_date'] = get_release_date(parsed_info)
+ else:
+ logging.warning('Could not determine release_date: %s',
+ parsed_info['_entry.id'])
+
+ header['resolution'] = 0.00
+ for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', '_reflns.d_resolution_high'):
+ if res_key in parsed_info:
+ try:
+ raw_resolution = parsed_info[res_key][0]
+ header['resolution'] = float(raw_resolution)
+ except ValueError:
+ logging.warning(
+ 'Invalid resolution format: %s',
+ parsed_info[res_key])
+
+ return header
+
+
+def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
+ """Returns list of atom sites; contains data not present in the structure."""
+ return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
+ parsed_info['_atom_site.label_comp_id'],
+ parsed_info['_atom_site.auth_asym_id'],
+ parsed_info['_atom_site.label_asym_id'],
+ parsed_info['_atom_site.auth_seq_id'],
+ parsed_info['_atom_site.label_seq_id'],
+ parsed_info['_atom_site.pdbx_PDB_ins_code'],
+ parsed_info['_atom_site.group_PDB'],
+ parsed_info['_atom_site.pdbx_PDB_model_num'],
+ )]
+
+
+def _get_protein_chains(*,
+ parsed_info: Mapping[str,
+ Any]) -> Mapping[ChainId,
+ Sequence[Monomer]]:
+ """Extracts polymer information for protein chains only.
+
+ Args:
+ parsed_info: _mmcif_dict produced by the Biopython parser.
+
+ Returns:
+ A dict mapping mmcif chain id to a list of Monomers.
+ """
+ # Get polymer information for each entity in the structure.
+ entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
+
+ polymers = collections.defaultdict(list)
+ for entity_poly_seq in entity_poly_seqs:
+ polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
+ Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
+ num=int(entity_poly_seq['_entity_poly_seq.num'])))
+
+ # Get chemical compositions. Will allow us to identify which of these polymers
+ # are proteins.
+ chem_comps = mmcif_loop_to_dict(
+ '_chem_comp.', '_chem_comp.id', parsed_info)
+
+ # Get chains information for each entity. Necessary so that we can return a
+ # dict keyed on chain id rather than entity.
+ struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
+
+ entity_to_mmcif_chains = collections.defaultdict(list)
+ for struct_asym in struct_asyms:
+ chain_id = struct_asym['_struct_asym.id']
+ entity_id = struct_asym['_struct_asym.entity_id']
+ entity_to_mmcif_chains[entity_id].append(chain_id)
+
+ # Identify and return the valid protein chains.
+ valid_chains = {}
+ for entity_id, seq_info in polymers.items():
+ chain_ids = entity_to_mmcif_chains[entity_id]
+
+ # Reject polymers without any peptide-like components, such as DNA/RNA.
+ if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
+ for monomer in seq_info]):
+ for chain_id in chain_ids:
+ valid_chains[chain_id] = seq_info
+ return valid_chains
+
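+# Illustrative example (made-up values): for a structure whose entity 1 is a
+# protein mapped to mmCIF chains A and B, the result would look like
+#   {'A': [Monomer(id='MET', num=1), Monomer(id='LYS', num=2), ...],
+#    'B': [Monomer(id='MET', num=1), Monomer(id='LYS', num=2), ...]}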
+
+def _is_set(data: str) -> bool:
+ """Returns False if data is a special mmCIF character indicating 'unset'."""
+ return data not in ('.', '?')
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/msa_search.sh b/reproduce/AlphaFold2-Chinese/data/tools/msa_search.sh
new file mode 100644
index 0000000..2e2f63c
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/msa_search.sh
@@ -0,0 +1,61 @@
+#!/bin/bash -e
+MMSEQS="$1" #mmseqs
+QUERY="$2" #"/path/QUERY.fasta"
+BASE="$3" #"./result/"
+DB1="$4" #"uniref30_2103_db"
+DB2="$5" #""
+DB3="$6" #"colabfold_envdb_202108_db"
+USE_ENV="$7" #1
+USE_TEMPLATES="$8" #0
+FILTER="${9}" #1
+EXPAND_EVAL=inf
+ALIGN_EVAL=10
+DIFF=3000
+QSC=-20.0
+MAX_ACCEPT=1000000
+time=$(date)
+echo "${time}"
+if [ "${FILTER}" = "1" ]; then
+# 0.1 was not used in benchmarks due to POSIX shell bug in line above
+# EXPAND_EVAL=0.1
+ ALIGN_EVAL=10
+ QSC=0.8
+ MAX_ACCEPT=100000
+fi
+export MMSEQS_CALL_DEPTH=1
+SEARCH_PARAM="--num-iterations 3 --db-load-mode 2 -a -s 8 -e 0.1 --max-seqs 10000"
+FILTER_PARAM="--filter-msa ${FILTER} --filter-min-enable 1000 --diff ${DIFF} --qid 0.0,0.2,0.4,0.6,0.8,1.0 --qsc 0 --max-seq-id 0.95"
+EXPAND_PARAM="--expansion-mode 0 -e ${EXPAND_EVAL} --expand-filter-clusters ${FILTER} --max-seq-id 0.95"
+mkdir -p "${BASE}"
+"${MMSEQS}" createdb "${QUERY}" "${BASE}/qdb"
+"${MMSEQS}" search "${BASE}/qdb" "${DB1}" "${BASE}/res" "${BASE}/tmp" $SEARCH_PARAM
+"${MMSEQS}" expandaln "${BASE}/qdb" "${DB1}.idx" "${BASE}/res" "${DB1}.idx" "${BASE}/res_exp" --db-load-mode 2 ${EXPAND_PARAM}
+
+"${MMSEQS}" mvdb "${BASE}/tmp/latest/profile_1" "${BASE}/prof_res"
+"${MMSEQS}" lndb "${BASE}/qdb_h" "${BASE}/prof_res_h"
+"${MMSEQS}" align "${BASE}/prof_res" "${DB1}.idx" "${BASE}/res_exp" "${BASE}/res_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a
+"${MMSEQS}" filterresult "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign" "${BASE}/res_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100
+"${MMSEQS}" result2msa "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign_filter" "${BASE}/uniref.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM}
+"${MMSEQS}" rmdb "${BASE}/res_exp_realign"
+"${MMSEQS}" rmdb "${BASE}/res_exp"
+"${MMSEQS}" rmdb "${BASE}/res"
+"${MMSEQS}" rmdb "${BASE}/res_exp_realign_filter"
+if [ "${USE_TEMPLATES}" = "1" ]; then
+ "${MMSEQS}" search "${BASE}/prof_res" "${DB2}" "${BASE}/res_pdb" "${BASE}/tmp" --db-load-mode 2 -s 7.5 -a -e 0.1
+ echo "-----------------------in here"
+ echo "${BASE}/${DB2}.m8"
+ "${MMSEQS}" convertalis "${BASE}/prof_res" "${DB2}.idx" "${BASE}/res_pdb" "${BASE}/${DB2}.m8" --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar --db-load-mode 2
+ "${MMSEQS}" rmdb "${BASE}/res_pdb"
+fi
+if [ "${USE_ENV}" = "1" ]; then
+ "${MMSEQS}" search "${BASE}/prof_res" "${DB3}" "${BASE}/res_env" "${BASE}/tmp" $SEARCH_PARAM
+ "${MMSEQS}" expandaln "${BASE}/prof_res" "${DB3}.idx" "${BASE}/res_env" "${DB3}.idx" "${BASE}/res_env_exp" -e ${EXPAND_EVAL} --expansion-mode 0 --db-load-mode 2
+ "${MMSEQS}" align "${BASE}/tmp/latest/profile_1" "${DB3}.idx" "${BASE}/res_env_exp" "${BASE}/res_env_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a
+ "${MMSEQS}" filterresult "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign" "${BASE}/res_env_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100
+ "${MMSEQS}" result2msa "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign_filter" "${BASE}/bfd.mgnify30.metaeuk30.smag30.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM}
+ "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign_filter"
+ "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign"
+fi
+
+time=$(date)
+echo "${time}"
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/parsers.py b/reproduce/AlphaFold2-Chinese/data/tools/parsers.py
new file mode 100644
index 0000000..6d08fcb
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/parsers.py
@@ -0,0 +1,389 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''parsers'''
+import collections
+import re
+import string
+from typing import Iterable, List, Optional, Sequence, Tuple
+
+import dataclasses
+
+DeletionMatrix = Sequence[Sequence[int]]
+
+
+@dataclasses.dataclass(frozen=True)
+class HhrHit:
+ """Class representing a hit in an hhr file."""
+ index: int
+ name: str
+ prob_true: float
+ e_value: float
+ score: float
+ aligned_cols: int
+ identity: float
+ similarity: float
+ sum_probs: float
+ neff: float
+ query: str
+ hit_sequence: str
+ hit_dssp: str
+ column_score_code: str
+ confidence_scores: str
+ indices_query: List[int]
+ indices_hit: List[int]
+
+
+def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
+ """Parses FASTA string and returns list of strings with amino-acid sequences.
+
+ Arguments:
+ fasta_string: The string contents of a FASTA file.
+
+ Returns:
+ A tuple of two lists:
+ * A list of sequences.
+ * A list of sequence descriptions taken from the comment lines. In the
+ same order as the sequences.
+ """
+ sequences = []
+ descriptions = []
+ index = -1
+ for line in fasta_string.splitlines():
+ line = line.strip()
+ if line.startswith('>'):
+ index += 1
+ descriptions.append(line[1:]) # Remove the '>' at the beginning.
+ sequences.append('')
+ continue
+ elif not line:
+ continue # Skip blank lines.
+ sequences[index] += line
+
+ return sequences, descriptions
+
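+# Illustrative example:
+#   parse_fasta('>a desc\nMKV\n>b\nGG')
+# returns (['MKV', 'GG'], ['a desc', 'b']).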
+
+def parse_stockholm(
+ stockholm_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+ """Parses sequences and deletion matrix from stockholm format alignment.
+
+ Args:
+ stockholm_string: The string contents of a stockholm file. The first
+ sequence in the file should be the query sequence.
+
+ Returns:
+ A tuple of:
+ * A list of sequences that have been aligned to the query. These
+ might contain duplicates.
+ * The deletion matrix for the alignment as a list of lists. The element
+ at `deletion_matrix[i][j]` is the number of residues deleted from
+ the aligned sequence i at residue position j.
+ """
+ name_to_sequence = collections.OrderedDict()
+ for line in stockholm_string.splitlines():
+ line = line.strip()
+ if not line or line.startswith(('#', '//')):
+ continue
+ name, sequence = line.split()
+ if name not in name_to_sequence:
+ name_to_sequence[name] = ''
+ name_to_sequence[name] += sequence
+
+ msa = []
+ deletion_matrix = []
+
+ query = ''
+ keep_columns = []
+ for seq_index, sequence in enumerate(name_to_sequence.values()):
+ if seq_index == 0:
+ # Gather the columns with gaps from the query
+ query = sequence
+ keep_columns = [i for i, res in enumerate(query) if res != '-']
+
+ # Remove the columns with gaps in the query from all sequences.
+ aligned_sequence = ''.join([sequence[c] for c in keep_columns])
+
+ msa.append(aligned_sequence)
+
+ # Count the number of deletions w.r.t. query.
+ deletion_vec = []
+ deletion_count = 0
+ for seq_res, query_res in zip(sequence, query):
+ if seq_res != '-' or query_res != '-':
+ if query_res == '-':
+ deletion_count += 1
+ else:
+ deletion_vec.append(deletion_count)
+ deletion_count = 0
+ deletion_matrix.append(deletion_vec)
+
+ return msa, deletion_matrix
+
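+# Illustrative example: for a Stockholm alignment whose query row is 'M-KV'
+# and whose only other row is 'MAKV', the 'A' falls in a query-gap column, so
+# the result is msa == ['MKV', 'MKV'] and
+# deletion_matrix == [[0, 0, 0], [0, 1, 0]].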
+
+def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+ """Parses sequences and deletion matrix from a3m format alignment.
+
+ Args:
+ a3m_string: The string contents of an a3m file. The first sequence in the
+ file should be the query sequence.
+
+ Returns:
+ A tuple of:
+ * A list of sequences that have been aligned to the query. These
+ might contain duplicates.
+ * The deletion matrix for the alignment as a list of lists. The element
+ at `deletion_matrix[i][j]` is the number of residues deleted from
+ the aligned sequence i at residue position j.
+ """
+ sequences, _ = parse_fasta(a3m_string)
+ deletion_matrix = []
+ for msa_sequence in sequences:
+ deletion_vec = []
+ deletion_count = 0
+ for j in msa_sequence:
+ if j.islower():
+ deletion_count += 1
+ else:
+ deletion_vec.append(deletion_count)
+ deletion_count = 0
+ deletion_matrix.append(deletion_vec)
+
+ # Make the MSA matrix out of aligned (deletion-free) sequences.
+ deletion_table = str.maketrans('', '', string.ascii_lowercase)
+ aligned_sequences = [s.translate(deletion_table) for s in sequences]
+ return aligned_sequences, deletion_matrix
+
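+# Illustrative example: lowercase letters mark insertions relative to the
+# query, so
+#   parse_a3m('>q\nMKV\n>h\nMaKV')
+# returns (['MKV', 'MKV'], [[0, 0, 0], [0, 1, 0]]).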
+
+def _convert_sto_seq_to_a3m(
+ query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
+ for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
+ if is_query_res_non_gap:
+ yield sequence_res
+ elif sequence_res != '-':
+ yield sequence_res.lower()
+
+
+def convert_stockholm_to_a3m(stockholm_format: str,
+ max_sequences: Optional[int] = None) -> str:
+ """Converts MSA in Stockholm format to the A3M format."""
+ descriptions = {}
+ sequences = {}
+ reached_max_sequences = False
+
+ for line in stockholm_format.splitlines():
+ reached_max_sequences = max_sequences and len(sequences) >= max_sequences
+
+ if line.strip() and not line.startswith(('#', '//')):
+ # Ignore blank lines, markup and end symbols - remainder are alignment
+ # sequence parts.
+ seqname, aligned_seq = line.split(maxsplit=1)
+ if seqname not in sequences:
+ if reached_max_sequences:
+ continue
+ sequences[seqname] = ''
+ sequences[seqname] += aligned_seq
+
+ for line in stockholm_format.splitlines():
+ if line[:4] == '#=GS':
+ # Description row - example format is:
+ # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
+ columns = line.split(maxsplit=3)
+ seqname, feature = columns[1:3]
+ value = columns[3] if len(columns) == 4 else ''
+ if feature != 'DE':
+ continue
+ if reached_max_sequences and seqname not in sequences:
+ continue
+ descriptions[seqname] = value
+ if len(descriptions) == len(sequences):
+ break
+
+ # Convert sto format to a3m line by line
+ a3m_sequences = {}
+ # query_sequence is assumed to be the first sequence
+ query_sequence = next(iter(sequences.values()))
+ query_non_gaps = [res != '-' for res in query_sequence]
+ for seqname, sto_sequence in sequences.items():
+ a3m_sequences[seqname] = ''.join(
+ _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
+
+ fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
+ for k in a3m_sequences)
+ return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
+
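+# Illustrative example: columns that are gaps in the query become lowercase
+# insertion characters in A3M, so a Stockholm row 'MAKV' aligned to the query
+# 'M-KV' is emitted as the A3M sequence 'MaKV'.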
+
+def _get_hhr_line_regex_groups(
+ regex_pattern: str, line: str) -> Sequence[Optional[str]]:
+ match = re.match(regex_pattern, line)
+ if match is None:
+ raise RuntimeError(f'Could not parse query line {line}')
+ return match.groups()
+
+
+def _update_hhr_residue_indices_list(
+ sequence: str, start_index: int, indices_list: List[int]):
+ """Computes the relative indices for each residue with respect to the original sequence."""
+ counter = start_index
+ for symbol in sequence:
+ if symbol == '-':
+ indices_list.append(-1)
+ else:
+ indices_list.append(counter)
+ counter += 1
+
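+# Illustrative example: _update_hhr_residue_indices_list('AB-C', 5, out)
+# appends [5, 6, -1, 7] to out, with -1 marking the gap.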
+
+def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit:
+ """Parses the detailed HMM HMM comparison section for a single Hit.
+
+ This works on .hhr files generated from both HHBlits and HHSearch.
+
+ Args:
+ detailed_lines: A list of lines from a single comparison section between 2
+ sequences (each of which has its own HMM).
+
+ Returns:
+ An HhrHit with the information from that detailed comparison section.
+
+ Raises:
+ RuntimeError: If a certain line cannot be processed
+ """
+ # Parse first 2 lines.
+ number_of_hit = int(detailed_lines[0].split()[-1])
+ name_hit = detailed_lines[1][1:]
+
+ # Parse the summary line.
+ pattern = (
+ 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
+ ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
+ ']*Template_Neff=(.*)')
+ match = re.match(pattern, detailed_lines[2])
+ if match is None:
+ raise RuntimeError(
+ 'Could not parse section: %s. Expected this: \n%s to contain summary.' %
+ (detailed_lines, detailed_lines[2]))
+ (prob_true, e_value, score, aligned_cols, identity, similarity, sum_probs,
+ neff) = [float(x) for x in match.groups()]
+
+ # The next section reads the detailed comparisons. These are in a 'human
+ # readable' format which has a fixed length. The strategy employed is to
+ # assume that each block starts with the query sequence line, and to parse
+ # that with a regexp in order to deduce the fixed length used for that
+ # block.
+ query = ''
+ hit_sequence = ''
+ hit_dssp = ''
+ column_score_code = ''
+ confidence_scores = ''
+ indices_query = []
+ indices_hit = []
+ length_block = None
+
+ for line in detailed_lines[3:]:
+ # Parse the query sequence line
+ if (line.startswith('Q ')
+ and not line.startswith('Q ss_dssp')
+ and not line.startswith('Q ss_pred')
+ and not line.startswith('Q Consensus')):
+ # Thus the first 17 characters must be 'Q <query_name>', and we can parse
+ # everything after that.
+ # start sequence end total_sequence_length
+ patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
+ groups = _get_hhr_line_regex_groups(patt, line[17:])
+
+ # Get the length of the parsed block using the start and finish indices,
+ # and ensure it is the same as the actual block length.
+ start = int(groups[0]) - 1 # Make index zero based.
+ delta_query = groups[1]
+ end = int(groups[2])
+ num_insertions = len([x for x in delta_query if x == '-'])
+ length_block = end - start + num_insertions
+ assert length_block == len(delta_query)
+
+ # Update the query sequence and indices list.
+ query += delta_query
+ _update_hhr_residue_indices_list(delta_query, start, indices_query)
+
+ elif line.startswith('T '):
+ # Parse the hit dssp line.
+ if line.startswith('T ss_dssp'):
+ # T ss_dssp hit_dssp
+ patt = r'T ss_dssp[\t ]*([A-Z-]*)'
+ groups = _get_hhr_line_regex_groups(patt, line)
+ assert len(groups[0]) == length_block
+ hit_dssp += groups[0]
+
+ # Parse the hit sequence.
+ elif (not line.startswith('T ss_pred') and
+ not line.startswith('T Consensus')):
+ # Thus the first 17 characters must be 'T <hit_name>', and we can
+ # parse everything after that.
+ # start sequence end total_sequence_length
+ patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
+ groups = _get_hhr_line_regex_groups(patt, line[17:])
+ start = int(groups[0]) - 1 # Make index zero based.
+ delta_hit_sequence = groups[1]
+ assert length_block == len(delta_hit_sequence)
+
+ # Update the hit sequence and indices list.
+ hit_sequence += delta_hit_sequence
+ _update_hhr_residue_indices_list(
+ delta_hit_sequence, start, indices_hit)
+
+ # Parse the column score line.
+ elif line.startswith(' ' * 22):
+ assert length_block
+ column_score_code += line[22:length_block + 22]
+
+ # Update confidence score.
+ elif line.startswith('Confidence'):
+ assert length_block
+ confidence_scores += line[22:length_block + 22]
+
+ return HhrHit(
+ index=number_of_hit,
+ name=name_hit,
+ prob_true=prob_true,
+ e_value=e_value,
+ score=score,
+ aligned_cols=int(aligned_cols),
+ identity=identity,
+ similarity=similarity,
+ sum_probs=sum_probs,
+ neff=neff,
+ query=query,
+ hit_sequence=hit_sequence,
+ hit_dssp=hit_dssp,
+ column_score_code=column_score_code,
+ confidence_scores=confidence_scores,
+ indices_query=indices_query,
+ indices_hit=indices_hit,
+ )
+
+
+def parse_hhr(hhr_string: str) -> Sequence[HhrHit]:
+ """Parses the content of an entire HHR file."""
+ lines = hhr_string.splitlines()
+
+ # Each .hhr file starts with a results table, then has a sequence of hit
+ # "paragraphs", each paragraph starting with a line 'No '. We
+ # iterate through each paragraph to parse each hit.
+
+ block_starts = [i for i, line in enumerate(
+ lines) if line.startswith('No ')]
+
+ hits = []
+ if block_starts:
+ block_starts.append(len(lines)) # Add the end of the final block.
+ for i in range(len(block_starts) - 1):
+ hits.append(_parse_hhr_hit(
+ lines[block_starts[i]:block_starts[i + 1]]))
+ return hits
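+
+
+# Illustrative usage sketch ('result.hhr' is a placeholder path):
+#   with open('result.hhr') as f:
+#       hits = parse_hhr(f.read())
+#   best = max(hits, key=lambda h: h.sum_probs) if hits else None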
diff --git a/reproduce/AlphaFold2-Chinese/data/tools/templates.py b/reproduce/AlphaFold2-Chinese/data/tools/templates.py
new file mode 100644
index 0000000..b548787
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/data/tools/templates.py
@@ -0,0 +1,999 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''template'''
+import datetime
+import glob
+import os
+import re
+import dataclasses
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
+from absl import logging
+import numpy as np
+
+import commons.residue_constants as residue_constants
+import data.tools.mmcif_parsing as mmcif_parsing
+import data.tools.parsers as parsers
+from data.tools.data_tools import Kalign
+
+
+class Error(Exception):
+ """Base class for exceptions."""
+
+
+class NoChainsError(Error):
+ """An error indicating that template mmCIF didn't have any chains."""
+
+
+class SequenceNotInTemplateError(Error):
+ """An error indicating that template mmCIF didn't contain the sequence."""
+
+
+class NoAtomDataInTemplateError(Error):
+ """An error indicating that template mmCIF didn't contain atom positions."""
+
+
+class TemplateAtomMaskAllZerosError(Error):
+ """An error indicating that template mmCIF had all atom positions masked."""
+
+
+class QueryToTemplateAlignError(Error):
+ """An error indicating that the query can't be aligned to the template."""
+
+
+class CaDistanceError(Error):
+ """An error indicating that a CA atom distance exceeds a threshold."""
+
+
+class MultipleChainsError(Error):
+ """An error indicating that multiple chains were found for a given ID."""
+
+
+# Prefilter exceptions.
+class PrefilterError(Exception):
+ """A base class for template prefilter exceptions."""
+
+
+class DateError(PrefilterError):
+ """An error indicating that the hit date was after the max allowed date."""
+
+
+class PdbIdError(PrefilterError):
+ """An error indicating that the hit PDB ID was identical to the query."""
+
+
+class AlignRatioError(PrefilterError):
+ """An error indicating that the hit align ratio to the query was too small."""
+
+
+class DuplicateError(PrefilterError):
+ """An error indicating that the hit was an exact subsequence of the query."""
+
+
+class LengthError(PrefilterError):
+ """An error indicating that the hit was too short."""
+
+
+TEMPLATE_FEATURES = {
+ 'template_aatype': np.float32,
+ 'template_all_atom_masks': np.float32,
+ 'template_all_atom_positions': np.float32,
+ # The builtin `object` replaces the deprecated `np.object` alias, which was
+ # removed in NumPy 1.24.
+ 'template_domain_names': object,
+ 'template_e_value': np.float32,
+ 'template_neff': np.float32,
+ 'template_prob_true': np.float32,
+ 'template_release_date': object,
+ 'template_score': np.float32,
+ 'template_similarity': np.float32,
+ 'template_sequence': object,
+ 'template_sum_probs': np.float32,
+ 'template_confidence_scores': np.int64
+}
+
+
+def _get_pdb_id_and_chain(hit: parsers.HhrHit) -> Tuple[str, str]:
+ """Returns PDB id and chain id for an HHSearch Hit."""
+ # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
+ id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
+ if not id_match:
+ raise ValueError(
+ f'hit.name did not start with PDBID_chain: {hit.name}')
+ pdb_id, chain_id = id_match.group(0).split('_')
+ return pdb_id.lower(), chain_id
+
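+# Illustrative example: for a hit named '4GZX_B Some description',
+# _get_pdb_id_and_chain returns ('4gzx', 'B').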
+
+def _is_after_cutoff(
+ pdb_id: str,
+ release_dates: Mapping[str, datetime.datetime],
+ release_date_cutoff: Optional[datetime.datetime]) -> bool:
+ """Checks if the template date is after the release date cutoff.
+
+ Args:
+ pdb_id: 4 letter pdb code.
+ release_dates: Dictionary mapping PDB ids to their structure release dates.
+ release_date_cutoff: Max release date that is valid for this query.
+
+ Returns:
+ True if the template release date is after the cutoff, False otherwise.
+ """
+ if release_date_cutoff is None:
+ raise ValueError('The release_date_cutoff must not be None.')
+ if pdb_id in release_dates:
+ return release_dates[pdb_id] > release_date_cutoff
+ logging.warning('Template structure not in release dates dict: %s', pdb_id)
+ return False
+
+
+def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
+ """Parses the data file from PDB that lists which PDB ids are obsolete."""
+ with open(obsolete_file_path) as f:
+ result = {}
+ for line in f:
+ line = line.strip()
+ # We skip obsolete entries that don't contain a mapping to a new
+ # entry.
+ if line.startswith('OBSLTE') and len(line) > 30:
+ # Format: Date From To
+ # 'OBSLTE 31-JUL-94 116L 216L'
+ from_id = line[20:24].lower()
+ to_id = line[29:33].lower()
+ result[from_id] = to_id
+ return result
+
+
+def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
+ """Parses release dates file, returns a mapping from PDBs to release dates."""
+ if path.endswith('txt'):
+ release_dates = {}
+ with open(path, 'r') as f:
+ for line in f:
+ pdb_id, date = line.split(':')
+ date = date.strip()
+ # Python 3.6 doesn't have datetime.date.fromisoformat() which is about
+ # 90x faster than strptime. However, splitting the string manually is
+ # about 10x faster than strptime.
+ release_dates[pdb_id.strip()] = datetime.datetime(
+ year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10]))
+ return release_dates
+ raise ValueError('Invalid format of the release date file %s.' % path)
+
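+# Note: each line of the release-dates file is therefore expected to look
+# like '1abc:2000-01-12' (an illustrative entry, inferred from the parsing
+# logic above).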
+
+def _assess_hhsearch_hit(
+ hit: parsers.HhrHit,
+ hit_pdb_code: str,
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ release_dates: Mapping[str, datetime.datetime],
+ release_date_cutoff: datetime.datetime,
+ max_subsequence_ratio: float = 0.95,
+ min_align_ratio: float = 0.1) -> bool:
+ """Determines if template is valid (without parsing the template mmcif file).
+
+ Args:
+ hit: HhrHit for the template.
+ hit_pdb_code: The 4 letter pdb code of the template hit. This might be
+ different from the value in the actual hit since the original pdb might
+ have become obsolete.
+ query_sequence: Amino acid sequence of the query.
+ query_pdb_code: 4 letter pdb code of the query.
+ release_dates: Dictionary mapping pdb codes to their structure release
+ dates.
+ release_date_cutoff: Max release date that is valid for this query.
+ max_subsequence_ratio: Exclude any exact matches with this much overlap.
+ min_align_ratio: Minimum overlap between the template and query.
+
+ Returns:
+ True if the hit passed the prefilter. Raises an exception otherwise.
+
+ Raises:
+ DateError: If the hit date was after the max allowed date.
+ PdbIdError: If the hit PDB ID was identical to the query.
+ AlignRatioError: If the hit align ratio to the query was too small.
+ DuplicateError: If the hit was an exact subsequence of the query.
+ LengthError: If the hit was too short.
+ """
+ aligned_cols = hit.aligned_cols
+ align_ratio = aligned_cols / len(query_sequence)
+
+ template_sequence = hit.hit_sequence.replace('-', '')
+ length_ratio = float(len(template_sequence)) / len(query_sequence)
+
+ # Check whether the template is a large subsequence or duplicate of original
+ # query. This can happen due to duplicate entries in the PDB database.
+ duplicate = (template_sequence in query_sequence and
+ length_ratio > max_subsequence_ratio)
+
+ if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
+ raise DateError(
+ f'Date ({release_dates[hit_pdb_code]}) > max template date '
+ f'({release_date_cutoff}).')
+
+ if query_pdb_code is not None:
+ if query_pdb_code.lower() == hit_pdb_code.lower():
+ raise PdbIdError('PDB code identical to Query PDB code.')
+
+ if align_ratio <= min_align_ratio:
+ raise AlignRatioError(
+ 'Proportion of residues aligned to query too small. '
+ f'Align ratio: {align_ratio}.')
+
+ if duplicate:
+ raise DuplicateError(
+ 'Template is an exact subsequence of query with large '
+ f'coverage. Length ratio: {length_ratio}.')
+
+ if len(template_sequence) < 10:
+ raise LengthError(
+ f'Template too short. Length: {len(template_sequence)}.')
+
+ return True
+
+
+def _find_template_in_pdb(
+ template_chain_id: str,
+ template_sequence: str,
+ mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]:
+ """Tries to find the template chain in the given pdb file.
+
+ This method tries the three following things in order:
+ 1. Tries if there is an exact match in both the chain ID and the sequence.
+ If yes, the chain sequence is returned. Otherwise:
+ 2. Tries if there is an exact match only in the sequence.
+ If yes, the chain sequence is returned. Otherwise:
+ 3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
+ If yes, the chain sequence is returned.
+ If none of these succeed, a SequenceNotInTemplateError is thrown.
+
+ Args:
+ template_chain_id: The template chain ID.
+ template_sequence: The template chain sequence.
+ mmcif_object: The PDB object to search for the template in.
+
+ Returns:
+ A tuple with:
+ * The chain sequence that was found to match the template in the PDB object.
+ * The ID of the chain that is being returned.
+ * The offset where the template sequence starts in the chain sequence.
+
+ Raises:
+ SequenceNotInTemplateError: If no match is found after the steps described
+ above.
+ """
+ # Try if there is an exact match in both the chain ID and the
+ # (sub)sequence.
+ pdb_id = mmcif_object.file_id
+ chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
+ if chain_sequence and (template_sequence in chain_sequence):
+ logging.info(
+ 'Found an exact template match %s_%s.', pdb_id, template_chain_id)
+ mapping_offset = chain_sequence.find(template_sequence)
+ return chain_sequence, template_chain_id, mapping_offset
+
+ # Try if there is an exact match in the (sub)sequence only.
+ for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
+ if chain_sequence and (template_sequence in chain_sequence):
+ logging.info(
+ 'Found a sequence-only match %s_%s.',
+ pdb_id,
+ chain_id)
+ mapping_offset = chain_sequence.find(template_sequence)
+ return chain_sequence, chain_id, mapping_offset
+
+ # Return a chain sequence that fuzzy matches (X = wildcard) the template.
+ # Make parentheses unnamed groups (?:_) to avoid the 100 named groups
+ # limit.
+ regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence]
+ regex = re.compile(''.join(regex))
+ for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
+ match = re.search(regex, chain_sequence)
+ if match:
+ logging.info(
+ 'Found a fuzzy sequence-only match %s_%s.',
+ pdb_id,
+ chain_id)
+ mapping_offset = match.start()
+ return chain_sequence, chain_id, mapping_offset
+
+ # No hits, raise an error.
+ raise SequenceNotInTemplateError(
+ 'Could not find the template sequence in %s_%s. Template sequence: %s, '
+ 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
+ mmcif_object.chain_to_seqres))
+
+
+def _realign_pdb_template_to_query(
+ old_template_sequence: str,
+ template_chain_id: str,
+ mmcif_object: mmcif_parsing.MmcifObject,
+ old_mapping: Mapping[int, int],
+ kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]:
+ """Aligns template from the mmcif_object to the query.
+
+ In case PDB70 contains a different version of the template sequence, we need
+ to perform a realignment to the actual sequence that is in the mmCIF file.
+ This method performs such realignment, but returns the new sequence and
+ mapping only if the sequence in the mmCIF file is 90% identical to the old
+ sequence.
+
+ Note that the old_template_sequence comes from the hit, and contains only that
+ part of the chain that matches with the query while the new_template_sequence
+ is the full chain.
+
+ Args:
+ old_template_sequence: The template sequence that was returned by the PDB
+ template search (typically done using HHSearch).
+ template_chain_id: The template chain id was returned by the PDB template
+ search (typically done using HHSearch). This is used to find the right
+ chain in the mmcif_object chain_to_seqres mapping.
+ mmcif_object: A mmcif_object which holds the actual template data.
+ old_mapping: A mapping from the query sequence to the template sequence.
+ This mapping will be used to compute the new mapping from the query
+ sequence to the actual mmcif_object template sequence by aligning the
+ old_template_sequence and the actual template sequence.
+ kalign_binary_path: The path to a kalign executable.
+
+ Returns:
+ A tuple (new_template_sequence, new_query_to_template_mapping) where:
+ * new_template_sequence is the actual template sequence that was found in
+ the mmcif_object.
+ * new_query_to_template_mapping is the new mapping from the query to the
+ actual template found in the mmcif_object.
+
+ Raises:
+ QueryToTemplateAlignError:
+ * If there was an error thrown by the alignment tool.
+ * Or if the actual template sequence differs by more than 10% from the
+ old_template_sequence.
+ """
+ aligner = Kalign(binary_path=kalign_binary_path)
+ new_template_sequence = mmcif_object.chain_to_seqres.get(
+ template_chain_id, '')
+
+ # Sometimes the template chain id is unknown. But if there is only a single
+ # sequence within the mmcif_object, it is safe to assume it is that one.
+ if not new_template_sequence:
+ if len(mmcif_object.chain_to_seqres) == 1:
+ logging.info(
+ 'Could not find %s in %s, but there is only 1 sequence, so '
+ 'using that one.',
+ template_chain_id,
+ mmcif_object.file_id)
+ new_template_sequence = list(
+ mmcif_object.chain_to_seqres.values())[0]
+ else:
+ raise QueryToTemplateAlignError(
+ f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
+ 'If there are no mmCIF parsing errors, it is possible it was not a '
+ 'protein chain.')
+
+ try:
+ (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
+ aligner.align([old_template_sequence, new_template_sequence]))
+ except Exception as e:
+ raise QueryToTemplateAlignError(
+ 'Could not align old template %s to template %s (%s_%s). Error: %s' %
+ (old_template_sequence,
+ new_template_sequence,
+ mmcif_object.file_id,
+ template_chain_id,
+ str(e)))
+
+ logging.info('Old aligned template: %s\nNew aligned template: %s',
+ old_aligned_template, new_aligned_template)
+
+ old_to_new_template_mapping = {}
+ old_template_index = -1
+ new_template_index = -1
+ num_same = 0
+ for old_template_aa, new_template_aa in zip(
+ old_aligned_template, new_aligned_template):
+ if old_template_aa != '-':
+ old_template_index += 1
+ if new_template_aa != '-':
+ new_template_index += 1
+ if old_template_aa != '-' and new_template_aa != '-':
+ old_to_new_template_mapping[old_template_index] = new_template_index
+ if old_template_aa == new_template_aa:
+ num_same += 1
+
+ # Require at least 90% sequence identity w.r.t. the shorter of the two
+ # sequences.
+ if float(num_same) / min(
+ len(old_template_sequence), len(new_template_sequence)) < 0.9:
+ raise QueryToTemplateAlignError(
+ 'Insufficient similarity of the sequence in the database: %s to the '
+ 'actual sequence in the mmCIF file %s_%s: %s. We require at least '
+ '90 %% similarity w.r.t. the shorter of the sequences. This is not a '
+ 'problem unless you think this is a template that should be included.' %
+ (old_template_sequence, mmcif_object.file_id, template_chain_id,
+ new_template_sequence))
+
+ new_query_to_template_mapping = {}
+ for query_index, old_template_index in old_mapping.items():
+ new_query_to_template_mapping[query_index] = (
+ old_to_new_template_mapping.get(old_template_index, -1))
+
+ new_template_sequence = new_template_sequence.replace('-', '')
+
+ return new_template_sequence, new_query_to_template_mapping
+
+
+def _check_residue_distances(all_positions: np.ndarray,
+ all_positions_mask: np.ndarray,
+ max_ca_ca_distance: float):
+ """Checks if the distance between unmasked neighbor residues is ok."""
+ ca_position = residue_constants.atom_order['CA']
+ prev_is_unmasked = False
+ prev_calpha = None
+ for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
+ this_is_unmasked = bool(mask[ca_position])
+ if this_is_unmasked:
+ this_calpha = coords[ca_position]
+ if prev_is_unmasked:
+ distance = np.linalg.norm(this_calpha - prev_calpha)
+ if distance > max_ca_ca_distance:
+ raise CaDistanceError(
+ 'The distance between residues %d and %d is %f > limit %f.' %
+ (i, i + 1, distance, max_ca_ca_distance))
+ prev_calpha = this_calpha
+ prev_is_unmasked = this_is_unmasked
+
+
+def _get_atom_positions(
+ mmcif_object: mmcif_parsing.MmcifObject,
+ auth_chain_id: str,
+ max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
+ """Gets atom positions and mask from a list of Biopython Residues."""
+ num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
+
+ relevant_chains = [c for c in mmcif_object.structure.get_chains()
+ if c.id == auth_chain_id]
+ if len(relevant_chains) != 1:
+ raise MultipleChainsError(
+ f'Expected exactly one chain in structure with id {auth_chain_id}.')
+ chain = relevant_chains[0]
+
+ all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
+ all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
+ dtype=np.int64)
+ for res_index in range(num_res):
+ pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
+ mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
+ res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
+ if not res_at_position.is_missing:
+ res = chain[(res_at_position.hetflag,
+ res_at_position.position.residue_number,
+ res_at_position.position.insertion_code)]
+ for atom in res.get_atoms():
+ atom_name = atom.get_name()
+ x, y, z = atom.get_coord()
+ if atom_name in residue_constants.atom_order.keys():
+ pos[residue_constants.atom_order[atom_name]] = [x, y, z]
+ mask[residue_constants.atom_order[atom_name]] = 1.0
+ elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
+ # Put the coordinates of the selenium atom in the sulphur
+ # column.
+ pos[residue_constants.atom_order['SD']] = [x, y, z]
+ mask[residue_constants.atom_order['SD']] = 1.0
+
+ all_positions[res_index] = pos
+ all_positions_mask[res_index] = mask
+ _check_residue_distances(
+ all_positions, all_positions_mask, max_ca_ca_distance)
+ return all_positions, all_positions_mask
+
+
+def _extract_template_features(
+ mmcif_object: mmcif_parsing.MmcifObject,
+ pdb_id: str,
+ mapping: Mapping[int, int],
+ template_sequence: str,
+ query_sequence: str,
+ template_chain_id: str,
+ confidence_scores: str,
+ kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
+ """Parses atom positions in the target structure and aligns with the query.
+
+ Atoms for each residue in the template structure are indexed to coincide
+ with their corresponding residue in the query sequence, according to the
+ alignment mapping provided.
+
+ Note that we only extract at most 500 templates because of HHSearch settings.
+
+ We set missing/invalid confidence scores to the default value of -1.
+ Note: We now have 4 types of confidence scores:
+ 1. Valid scores
+ 2. Invalid scores of residues not in both the query sequence and template
+ sequence
+ 3. Missing scores because we don't have the secondary structure, and HHAlign
+ doesn't produce the posterior probabilities in this case.
+ 4. Missing scores because of a different template sequence in PDB70,
+ invalidating the previously computed confidence scores. (Though in theory
+ HHAlign can be run on these to recompute the correct confidence scores).
+ We handle invalid and missing scores by setting them to -1, but consider
+ adding masks for the different types.
+
+ Args:
+ mmcif_object: mmcif_parsing.MmcifObject representing the template.
+ pdb_id: PDB code for the template.
+ mapping: Dictionary mapping indices in the query sequence to indices in
+ the template sequence.
+ template_sequence: String describing the amino acid sequence for the
+ template protein.
+ query_sequence: String describing the amino acid sequence for the query
+ protein.
+ template_chain_id: String ID describing which chain in the structure proto
+ should be used.
+ confidence_scores: String containing per-residue confidence scores, where
+ each character represents the *TRUNCATED* posterior probability that the
+ corresponding template residue is correctly aligned with the query
+ residue, given the database match is correct (0 corresponds approximately
+ to 0-10%, 9 to 90-100%).
+ kalign_binary_path: The path to a kalign executable used for template
+ realignment.
+
+ Returns:
+ A tuple with:
+ * A dictionary containing the extra features derived from the template
+ protein structure.
+ * A warning message if the hit was realigned to the actual mmCIF sequence.
+ Otherwise None.
+
+ Raises:
+ NoChainsError: If the mmcif object doesn't contain any chains.
+ SequenceNotInTemplateError: If the given chain id / sequence can't
+ be found in the mmcif object.
+ QueryToTemplateAlignError: If the actual template in the mmCIF file
+ can't be aligned to the query.
+ NoAtomDataInTemplateError: If the mmcif object doesn't contain
+ atom positions.
+ TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
+ unmasked residues.
+ """
+ if mmcif_object is None or not mmcif_object.chain_to_seqres:
+ raise NoChainsError(
+ 'No chains in PDB: %s_%s' %
+ (pdb_id, template_chain_id))
+
+ warning = None
+ try:
+ seqres, chain_id, mapping_offset = _find_template_in_pdb(
+ template_chain_id=template_chain_id,
+ template_sequence=template_sequence,
+ mmcif_object=mmcif_object)
+ except SequenceNotInTemplateError:
+ # If PDB70 contains a different version of the template, we use the sequence
+ # from the mmcif_object.
+ chain_id = template_chain_id
+ warning = (
+ f'The exact sequence {template_sequence} was not found in '
+ f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
+ logging.warning(warning)
+ # This throws an exception if it fails to realign the hit.
+ seqres, mapping = _realign_pdb_template_to_query(
+ old_template_sequence=template_sequence,
+ template_chain_id=template_chain_id,
+ mmcif_object=mmcif_object,
+ old_mapping=mapping,
+ kalign_binary_path=kalign_binary_path)
+ logging.info('Sequence in %s_%s: %s successfully realigned to %s',
+ pdb_id, chain_id, template_sequence, seqres)
+ # The template sequence changed.
+ template_sequence = seqres
+ # No mapping offset, the query is aligned to the actual sequence.
+ mapping_offset = 0
+ # Confidence scores were based on the previous sequence, so they are
+ # invalid
+ confidence_scores = None
+
+ try:
+ # Essentially set to infinity - we don't want to reject templates unless
+ # they're really really bad.
+ all_atom_positions, all_atom_mask = _get_atom_positions(
+ mmcif_object, chain_id, max_ca_ca_distance=150.0)
+ except (CaDistanceError, KeyError) as ex:
+ raise NoAtomDataInTemplateError(
+ 'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
+ ) from ex
+
+ all_atom_positions = np.split(
+ all_atom_positions,
+ all_atom_positions.shape[0])
+ all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
+
+ output_templates_sequence = []
+ output_confidence_scores = []
+ templates_all_atom_positions = []
+ templates_all_atom_masks = []
+
+ for _ in query_sequence:
+ # Residues in the query_sequence that are not in the template_sequence:
+ templates_all_atom_positions.append(
+ np.zeros((residue_constants.atom_type_num, 3)))
+ templates_all_atom_masks.append(
+ np.zeros(residue_constants.atom_type_num))
+ output_templates_sequence.append('-')
+ output_confidence_scores.append(-1)
+
+ for k, v in mapping.items():
+ template_index = v + mapping_offset
+ templates_all_atom_positions[k] = all_atom_positions[template_index][0]
+ templates_all_atom_masks[k] = all_atom_masks[template_index][0]
+ output_templates_sequence[k] = template_sequence[v]
+ if confidence_scores and confidence_scores[v] != ' ':
+ output_confidence_scores[k] = int(confidence_scores[v])
+
+ # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N,
+ # O).
+ if np.sum(templates_all_atom_masks) < 5:
+ raise TemplateAtomMaskAllZerosError(
+ 'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' %
+ (pdb_id, chain_id, min(mapping.values()) + mapping_offset,
+ max(mapping.values()) + mapping_offset))
+
+ output_templates_sequence = ''.join(output_templates_sequence)
+
+ templates_aatype = residue_constants.sequence_to_onehot(
+ output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)
+
+ return (
+ {'template_all_atom_positions': np.array(templates_all_atom_positions),
+ 'template_all_atom_masks': np.array(templates_all_atom_masks),
+ 'template_sequence': output_templates_sequence.encode(),
+ 'template_aatype': np.array(templates_aatype),
+ 'template_confidence_scores': np.array(output_confidence_scores),
+ 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
+ 'template_release_date': mmcif_object.header['release_date'].encode()},
+ warning)
+
+
+def _build_query_to_hit_index_mapping(
+ hit_query_sequence: str,
+ hit_sequence: str,
+ indices_hit: Sequence[int],
+ indices_query: Sequence[int],
+ original_query_sequence: str) -> Mapping[int, int]:
+ """Gets mapping from indices in original query sequence to indices in the hit.
+
+ hit_query_sequence and hit_sequence are two aligned sequences containing gap
+ characters. hit_query_sequence contains only the part of the original query
+ sequence that matched the hit. When interpreting the indices from the .hhr, we
+ need to correct for this to recover a mapping from original query sequence to
+ the hit sequence.
+
+ Args:
+ hit_query_sequence: The portion of the query sequence that is in the .hhr
+ hit.
+ hit_sequence: The portion of the hit sequence that is in the .hhr hit.
+ indices_hit: The indices for each amino acid relative to the hit sequence.
+ indices_query: The indices for each amino acid relative to the original
+ query sequence.
+ original_query_sequence: String describing the original query sequence.
+
+ Returns:
+ Dictionary with indices in the original query sequence as keys and indices
+ in the hit sequence as values.
+ """
+ # If the hit is empty (no aligned residues), return empty mapping
+ if not hit_query_sequence:
+ return {}
+
+ # Remove gaps and find the offset of hit.query relative to original query.
+ hhsearch_query_sequence = hit_query_sequence.replace('-', '')
+ hit_sequence = hit_sequence.replace('-', '')
+ hhsearch_query_offset = original_query_sequence.find(
+ hhsearch_query_sequence)
+
+ # Index of -1 used for gap characters. Subtract the min index ignoring
+ # gaps.
+ min_idx = min(x for x in indices_hit if x > -1)
+ fixed_indices_hit = [
+ x - min_idx if x > -1 else -1 for x in indices_hit
+ ]
+
+ min_idx = min(x for x in indices_query if x > -1)
+ fixed_indices_query = [
+ x - min_idx if x > -1 else -1 for x in indices_query]
+
+ # Zip the corrected indices, ignore case where both seqs have gap
+ # characters.
+ mapping = {}
+ for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
+ if q_t != -1 and q_i != -1:
+ if (q_t >= len(hit_sequence) or q_i +
+ hhsearch_query_offset >= len(original_query_sequence)):
+ continue
+ mapping[q_i + hhsearch_query_offset] = q_t
+
+ return mapping
+
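+# Illustrative example (made-up alignment): if the hit covers query residues
+# 10-12 with no gaps, the returned mapping is {10: 0, 11: 1, 12: 2}, i.e.
+# query indices (keys) to hit-sequence indices (values).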
+
+@dataclasses.dataclass(frozen=True)
+class SingleHitResult:
+ features: Optional[Mapping[str, Any]]
+ error: Optional[str]
+ warning: Optional[str]
+
+
+def _process_single_hit(
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ hit: parsers.HhrHit,
+ mmcif_dir: str,
+ max_template_date: datetime.datetime,
+ release_dates: Mapping[str, datetime.datetime],
+ obsolete_pdbs: Mapping[str, str],
+ kalign_binary_path: str,
+ strict_error_check: bool = False) -> SingleHitResult:
+ """Tries to extract template features from a single HHSearch hit."""
+ # Fail hard if we can't get the PDB ID and chain name from the hit.
+ hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
+
+ if hit_pdb_code not in release_dates:
+ if hit_pdb_code in obsolete_pdbs:
+ hit_pdb_code = obsolete_pdbs[hit_pdb_code]
+
+ # Pass hit_pdb_code since it might have changed due to the pdb being
+ # obsolete.
+ try:
+ _assess_hhsearch_hit(
+ hit=hit,
+ hit_pdb_code=hit_pdb_code,
+ query_sequence=query_sequence,
+ query_pdb_code=query_pdb_code,
+ release_dates=release_dates,
+ release_date_cutoff=max_template_date)
+ except PrefilterError as e:
+ msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
+ logging.info('%s: %s', query_pdb_code, msg)
+ if strict_error_check and isinstance(
+ e, (DateError, PdbIdError, DuplicateError)):
+ # In strict mode we treat some prefilter cases as errors.
+ return SingleHitResult(features=None, error=msg, warning=None)
+
+ return SingleHitResult(features=None, error=None, warning=None)
+
+ mapping = _build_query_to_hit_index_mapping(
+ hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
+ query_sequence)
+
+ # The mapping is from the query to the actual hit sequence, so we need to
+ # remove gaps (which regardless have a missing confidence score).
+ template_sequence = hit.hit_sequence.replace('-', '')
+ confidence_scores = ''.join(
+ [cs for t, cs in zip(hit.hit_sequence, hit.confidence_scores)
+ if t != '-'])
+
+ cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
+ logging.info('Reading PDB entry from %s. Query: %s, template: %s',
+ cif_path, query_sequence, template_sequence)
+ # Fail if we can't find the mmCIF file.
+ with open(cif_path, 'r') as cif_file:
+ cif_string = cif_file.read()
+
+ parsing_result = mmcif_parsing.parse(
+ file_id=hit_pdb_code, mmcif_string=cif_string)
+
+ if parsing_result.mmcif_object is not None:
+ hit_release_date = datetime.datetime.strptime(
+ parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
+ if hit_release_date > max_template_date:
+ error = ('Template %s date (%s) > max template date (%s).' %
+ (hit_pdb_code, hit_release_date, max_template_date))
+ if strict_error_check:
+ return SingleHitResult(
+ features=None, error=error, warning=None)
+ logging.warning(error)
+ return SingleHitResult(features=None, error=None, warning=None)
+
+ try:
+ features, realign_warning = _extract_template_features(
+ mmcif_object=parsing_result.mmcif_object,
+ pdb_id=hit_pdb_code,
+ mapping=mapping,
+ template_sequence=template_sequence,
+ query_sequence=query_sequence,
+ template_chain_id=hit_chain_id,
+ confidence_scores=confidence_scores,
+ kalign_binary_path=kalign_binary_path)
+ features['template_e_value'] = [hit.e_value]
+ features['template_sum_probs'] = [hit.sum_probs]
+ features['template_prob_true'] = [hit.prob_true]
+ features['template_score'] = [hit.score]
+ features['template_neff'] = [hit.neff]
+ features['template_similarity'] = [hit.similarity]
+
+ # It is possible there were some errors when parsing the other chains in the
+ # mmCIF file, but the template features for the chain we want were still
+ # computed. In such case the mmCIF parsing errors are not relevant.
+ return SingleHitResult(
+ features=features, error=None, warning=realign_warning)
+ except (NoChainsError, NoAtomDataInTemplateError,
+ TemplateAtomMaskAllZerosError) as e:
+ # These 3 errors indicate missing mmCIF experimental data rather than a
+ # problem with the template search, so turn them into warnings.
+ warning = (
+ '%s_%s (sum_probs: %.2f, rank: %d): feature extraction errors: '
+ '%s, mmCIF parsing errors: %s' %
+ (hit_pdb_code,
+ hit_chain_id,
+ hit.sum_probs,
+ hit.index,
+ str(e),
+ parsing_result.errors))
+ if strict_error_check:
+ return SingleHitResult(features=None, error=warning, warning=None)
+ return SingleHitResult(features=None, error=None, warning=warning)
+ except Error as e:
+ error = (
+ '%s_%s (sum_probs: %.2f, rank: %d): feature extraction errors: '
+ '%s, mmCIF parsing errors: %s' %
+ (hit_pdb_code,
+ hit_chain_id,
+ hit.sum_probs,
+ hit.index,
+ str(e),
+ parsing_result.errors))
+ return SingleHitResult(features=None, error=error, warning=None)
+
+
+@dataclasses.dataclass(frozen=True)
+class TemplateSearchResult:
+ features: Mapping[str, Any]
+ errors: Sequence[str]
+ warnings: Sequence[str]
+
+
+class TemplateHitFeaturizer:
+ """A class for turning hhr hits to template features."""
+
+ def __init__(
+ self,
+ mmcif_dir: str,
+ max_template_date: str,
+ max_hits: int,
+ kalign_binary_path: str,
+ release_dates_path: Optional[str],
+ obsolete_pdbs_path: Optional[str],
+ strict_error_check: bool = False):
+ """Initializes the Template Search.
+
+ Args:
+ mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
+ is found by HHSearch, this directory is used to retrieve the template
+ data.
+ max_template_date: The maximum date permitted for template structures. No
+ template with date higher than this date will be returned. In ISO8601
+ date format, YYYY-MM-DD.
+ max_hits: The maximum number of templates that will be returned.
+ kalign_binary_path: The path to a kalign executable used for template
+ realignment.
+ release_dates_path: An optional path to a file with a mapping from PDB IDs
+ to their release dates. Thanks to this we don't have to redundantly
+ parse mmCIF files to get that information.
+ obsolete_pdbs_path: An optional path to a file containing a mapping from
+ obsolete PDB IDs to the PDB IDs of their replacements.
+ strict_error_check: If True, then the following will be treated as errors:
+ * If any template date is after the max_template_date.
+ * If any template has identical PDB ID to the query.
+ * If any template is a duplicate of the query.
+ * Any feature computation errors.
+ """
+ self._mmcif_dir = mmcif_dir
+ if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
+ logging.error('Could not find CIFs in %s', self._mmcif_dir)
+ raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')
+
+ try:
+ self._max_template_date = datetime.datetime.strptime(
+ max_template_date, '%Y-%m-%d')
+ except ValueError as e:
+ raise ValueError(
+ 'max_template_date must be set and have format YYYY-MM-DD.') from e
+ self._max_hits = max_hits
+ self._kalign_binary_path = kalign_binary_path
+ self._strict_error_check = strict_error_check
+
+ if release_dates_path:
+ logging.info(
+ 'Using precomputed release dates %s.',
+ release_dates_path)
+ self._release_dates = _parse_release_dates(release_dates_path)
+ else:
+ self._release_dates = {}
+
+ if obsolete_pdbs_path:
+ logging.info(
+ 'Using precomputed obsolete pdbs %s.',
+ obsolete_pdbs_path)
+ self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
+ else:
+ self._obsolete_pdbs = {}
+
+ def get_templates(
+ self,
+ query_sequence: str,
+ query_pdb_code: Optional[str],
+ query_release_date: Optional[datetime.datetime],
+ hhr_hits: Sequence[parsers.HhrHit]) -> TemplateSearchResult:
+ """Computes the templates for given query sequence (more details above)."""
+ logging.info('Searching for template for: %s', query_pdb_code)
+
+ template_features = {}
+ for template_feature_name in TEMPLATE_FEATURES:
+ template_features[template_feature_name] = []
+
+ # Always use a max_template_date. Set to query_release_date minus 60 days
+ # if that's earlier.
+ template_cutoff_date = self._max_template_date
+ if query_release_date:
+ delta = datetime.timedelta(days=60)
+ if query_release_date - delta < template_cutoff_date:
+ template_cutoff_date = query_release_date - delta
+ assert template_cutoff_date < query_release_date
+ assert template_cutoff_date <= self._max_template_date
+
+ num_hits = 0
+ errors = []
+ warnings = []
+
+ for hit in sorted(hhr_hits, key=lambda x: x.sum_probs, reverse=True):
+ # We got all the templates we wanted, stop processing HHSearch
+ # hits.
+ if num_hits >= self._max_hits:
+ break
+
+ result = _process_single_hit(
+ query_sequence=query_sequence,
+ query_pdb_code=query_pdb_code,
+ hit=hit,
+ mmcif_dir=self._mmcif_dir,
+ max_template_date=template_cutoff_date,
+ release_dates=self._release_dates,
+ obsolete_pdbs=self._obsolete_pdbs,
+ strict_error_check=self._strict_error_check,
+ kalign_binary_path=self._kalign_binary_path)
+
+ if result.error:
+ errors.append(result.error)
+
+ # There could be an error even if there are some results, e.g. thrown by
+ # other unparsable chains in the same mmCIF file.
+ if result.warning:
+ warnings.append(result.warning)
+
+ if result.features is None:
+ logging.info('Skipped invalid hit %s, error: %s, warning: %s',
+ hit.name, result.error, result.warning)
+ else:
+ # Increment the hit counter, since we got features out of this
+ # hit.
+ num_hits += 1
+ for k in template_features:
+ template_features[k].append(result.features[k])
+
+ for name in template_features:
+ if num_hits > 0:
+ template_features[name] = np.stack(
+ template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
+ else:
+ # Make sure the feature has correct dtype even if empty.
+ template_features[name] = np.array(
+ [], dtype=TEMPLATE_FEATURES[name])
+
+ return TemplateSearchResult(
+ features=template_features, errors=errors, warnings=warnings)
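+
+
+# Illustrative usage sketch (paths and the date below are placeholders, not
+# values taken from this repo):
+#   featurizer = TemplateHitFeaturizer(
+#       mmcif_dir='/data/pdb_mmcif', max_template_date='2021-11-01',
+#       max_hits=20, kalign_binary_path='kalign',
+#       release_dates_path=None, obsolete_pdbs_path=None)
+#   result = featurizer.get_templates(
+#       query_sequence=query_seq, query_pdb_code=None,
+#       query_release_date=None, hhr_hits=parsers.parse_hhr(hhr_string))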
diff --git a/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture.jpg b/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture.jpg
new file mode 100644
index 0000000..56993ce
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture.jpg differ
diff --git a/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture_en.jpg b/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture_en.jpg
new file mode 100644
index 0000000..79014d7
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/docs/MindScience_Architecture_en.jpg differ
diff --git a/reproduce/AlphaFold2-Chinese/docs/all_experiment_data.jpg b/reproduce/AlphaFold2-Chinese/docs/all_experiment_data.jpg
new file mode 100644
index 0000000..eace389
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/docs/all_experiment_data.jpg differ
diff --git a/reproduce/AlphaFold2-Chinese/docs/seq_21.jpg b/reproduce/AlphaFold2-Chinese/docs/seq_21.jpg
new file mode 100644
index 0000000..08f777a
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/docs/seq_21.jpg differ
diff --git a/reproduce/AlphaFold2-Chinese/docs/seq_64.gif b/reproduce/AlphaFold2-Chinese/docs/seq_64.gif
new file mode 100644
index 0000000..b39a706
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/docs/seq_64.gif differ
diff --git a/reproduce/AlphaFold2-Chinese/main.py b/reproduce/AlphaFold2-Chinese/main.py
new file mode 100644
index 0000000..2f1d530
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/main.py
@@ -0,0 +1,112 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""run script"""
+
+import time
+import os
+import json
+import argparse
+import numpy as np
+
+import mindspore.context as context
+from mindspore.common.tensor import Tensor
+from mindspore import load_checkpoint
+
+from data.feature.feature_extraction import process_features
+from data.tools.data_process import data_process
+from commons.generate_pdb import to_pdb, from_prediction
+from commons.utils import compute_confidence
+from model import AlphaFold
+from config import config, global_config
+
+parser = argparse.ArgumentParser(description='Inputs for run.py')
+parser.add_argument('--seq_length', help='Padded sequence length.')
+parser.add_argument('--input_fasta_path', help='Directory containing the FASTA files to be predicted.')
+parser.add_argument('--msa_result_path', help='Directory in which to save MSA results.')
+parser.add_argument('--database_dir', help='Path of the database used to generate the MSA.')
+parser.add_argument('--database_envdb_dir', help='Path of the expandable (environmental) database used to generate the MSA.')
+parser.add_argument('--hhsearch_binary_path', help='Path of the hhsearch executable.')
+parser.add_argument('--pdb70_database_path', help='Path to the pdb70 database.')
+parser.add_argument('--template_mmcif_dir', help='Directory of template mmCIF files.')
+parser.add_argument('--max_template_date', help='Maximum template release date.')
+parser.add_argument('--kalign_binary_path', help='Path of the kalign executable.')
+parser.add_argument('--obsolete_pdbs_path', help='Path of the obsolete PDBs file.')
+parser.add_argument('--checkpoint_path', help='Path of the model checkpoint.')
+parser.add_argument('--device_id', default=0, type=int, help='Device id to be used.')
+args = parser.parse_args()
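+
+# Example invocation (all paths below are placeholders, not shipped defaults):
+#   python main.py --seq_length 256 --input_fasta_path ./fasta \
+#       --msa_result_path ./msa --database_dir ./db --checkpoint_path ./ckpt/af2.ckpt
+# (the remaining template-search arguments follow the same pattern)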
+
+if __name__ == "__main__":
+ context.set_context(mode=context.GRAPH_MODE,
+ device_target="Ascend",
+ variable_memory_max_size="31GB",
+ device_id=args.device_id,
+ save_graphs=False)
+ model_name = "model_1"
+ model_config = config.model_config(model_name)
+ num_recycle = model_config.model.num_recycle
+ global_config = global_config.global_config(args.seq_length)
+ extra_msa_length = global_config.extra_msa_length
+ fold_net = AlphaFold(model_config, global_config)
+
+ load_checkpoint(args.checkpoint_path, fold_net)
+
+ seq_files = os.listdir(args.input_fasta_path)
+
+ for seq_file in seq_files:
+ t1 = time.time()
+ seq_name = seq_file.split('.')[0]
+ input_features = data_process(seq_name, args)
+ tensors, aatype, residue_index, ori_res_length = process_features(
+ raw_features=input_features, config=model_config, global_config=global_config)
+ prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
+ prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
+ prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
+ """
+ :param::@sequence_length
+ """
+ t2 = time.time()
+ for i in range(num_recycle+1):
+ tensors_i = [tensor[i] for tensor in tensors]
+ input_feats = [Tensor(tensor) for tensor in tensors_i]
+ final_atom_positions, final_atom_mask, predicted_lddt_logits,\
+ prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
+ prev_pos,
+ prev_msa_first_row,
+ prev_pair)
+
+ t3 = time.time()
+
+ final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
+ final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
+ predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
+
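+ # compute_confidence converts the predicted-lDDT logits into an overall
+ # confidence score (mean per-residue plDDT).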
+ confidence = compute_confidence(predicted_lddt_logits)
+ unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
+ pdb_file = to_pdb(unrelaxed_protein)
+
+ seq_length = aatype.shape[-1]
+ os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)
+
+ with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
+ f.write(pdb_file)
+ t4 = time.time()
+ timings = {"pre_process_time": round(t2 - t1, 2),
+ "model_time": round(t3 - t2, 2),
+ "pos_process_time": round(t4 - t3, 2),
+ "all_time": round(t4 - t1, 2),
+ "confidence": confidence}
+ print(timings)
+ with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
+ f.write(json.dumps(timings))
diff --git a/reproduce/AlphaFold2-Chinese/module/basic_module.py b/reproduce/AlphaFold2-Chinese/module/basic_module.py
new file mode 100644
index 0000000..882a118
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/module/basic_module.py
@@ -0,0 +1,936 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""basic module"""
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.tensor import Tensor
+from mindspore import Parameter
+import numpy as np
+from commons.utils import mask_mean
+
+
+class Attention(nn.Cell):
+ '''attention module'''
+ def __init__(self, config, q_data_dim, m_data_dim, output_dim, batch_size=None):
+ super(Attention, self).__init__()
+ self.config = config
+ self.q_data_dim = q_data_dim
+ self.m_data_dim = m_data_dim
+ self.output_dim = output_dim
+ self.num_head = self.config.num_head
+ self.gating = self.config.gating
+ self.key_dim = self.config.get('key_dim', int(q_data_dim))
+ self.value_dim = self.config.get('value_dim', int(m_data_dim))
+ self.key_dim = self.key_dim // self.num_head
+ self.value_dim = self.value_dim // self.num_head
+ self.batch_size = batch_size
+ self.matmul = P.MatMul(transpose_b=True)
+ self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+ self.softmax = nn.Softmax()
+ self.sigmoid = nn.Sigmoid()
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
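+ # When batch_size is set, the weights of all stacked layers are stored along
+ # a leading axis and the per-layer slice is selected with Gather(index) in
+ # construct, so one Cell can be shared across the stacked blocks.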
+ if self.batch_size:
+ self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim,
+ self.q_data_dim]), mstype.float32))
+ self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim,
+ self.m_data_dim]), mstype.float32))
+ self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim,
+ self.m_data_dim]), mstype.float32))
+ self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim,
+ self.num_head * self.value_dim]), mstype.float32))
+ self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), mstype.float32))
+ if self.gating:
+ self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim,
+ self.q_data_dim]), mstype.float32))
+ self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_head, self.value_dim)),
+ mstype.float32), name="gating_b")
+ else:
+ self.linear_q_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.q_data_dim]),
+ mstype.float32))
+ self.linear_k_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.m_data_dim]),
+ mstype.float32))
+ self.linear_v_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim, self.m_data_dim]),
+ mstype.float32))
+ self.linear_output_weights = Parameter(Tensor(np.zeros([self.output_dim, self.num_head * self.value_dim]),
+ mstype.float32))
+ self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32))
+ if self.gating:
+ self.linear_gating_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim,
+ self.q_data_dim]), mstype.float32))
+ self.gating_biases = Parameter(Tensor(np.zeros((self.num_head, self.value_dim)), mstype.float32),
+ name="gating_b")
+
+ def construct(self, q_data, m_data, bias, index=None, nonbatched_bias=None):
+ '''construct'''
+ if self.batch_size:
+ linear_q_weight = P.Gather()(self.linear_q_weights, index, 0)
+ linear_k_weight = P.Gather()(self.linear_k_weights, index, 0)
+ linear_v_weight = P.Gather()(self.linear_v_weights, index, 0)
+ linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
+ o_bias = P.Gather()(self.o_biases, index, 0)
+ linear_gating_weight = 0
+ gating_bias = 0
+ if self.gating:
+ linear_gating_weight = P.Gather()(self.linear_gating_weights, index, 0)
+ gating_bias = P.Gather()(self.gating_biases, index, 0)
+ else:
+ linear_q_weight = self.linear_q_weights
+ linear_k_weight = self.linear_k_weights
+ linear_v_weight = self.linear_v_weights
+ linear_output_weight = self.linear_output_weights
+ o_bias = self.o_biases
+ linear_gating_weight = 0
+ gating_bias = 0
+ if self.gating:
+ linear_gating_weight = self.linear_gating_weights
+ gating_bias = self.gating_biases
+
+ q_data, m_data, bias = q_data.astype(mstype.float32), m_data.astype(mstype.float32), bias.astype(mstype.float32)
+ dim_b, dim_q, dim_a = q_data.shape
+ _, dim_k, dim_c = m_data.shape
+ dim_h = self.num_head
+
+ q_data = P.Reshape()(q_data, (-1, dim_a))
+ m_data = P.Reshape()(m_data, (-1, dim_c))
+
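+ # Mixed precision: matmuls run in float16 for Ascend throughput, while the
+ # softmax (and layer norms elsewhere) are kept in float32 for stability.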
+ q = self.matmul(q_data.astype(mstype.float16), linear_q_weight.astype(mstype.float16)). \
+ astype(mstype.float32) * self.key_dim ** (-0.5)
+ k = self.matmul(m_data.astype(mstype.float16), linear_k_weight.astype(mstype.float16))
+ v = self.matmul(m_data.astype(mstype.float16), linear_v_weight.astype(mstype.float16))
+
+ q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1))
+ k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1))
+ v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1))
+
+ tmp_q = P.Reshape()(P.Transpose()(q.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_q, -1))
+ tmp_k = P.Reshape()(P.Transpose()(k.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_k, -1))
+ logits = P.Reshape()(self.batch_matmul_trans_b(tmp_q.astype(mstype.float16),
+ tmp_k.astype(mstype.float16)),
+ (dim_b, dim_h, dim_q, dim_k)) + bias.astype(mstype.float16)
+
+ if nonbatched_bias is not None:
+ logits += mnp.expand_dims(nonbatched_bias, axis=0)
+ weights = self.softmax(logits.astype(mstype.float32))
+ tmp_v = P.Reshape()(P.Transpose()(v, (0, 2, 3, 1)), (dim_b * dim_h, -1, dim_k))
+ tmp_weights = P.Reshape()(weights, (dim_b * dim_h, dim_q, -1))
+ weighted_avg = P.Transpose()(P.Reshape()(self.batch_matmul_trans_b(tmp_weights.astype(mstype.float16),
+ tmp_v.astype(mstype.float16)),
+ (dim_b, dim_h, dim_q, -1)),
+ (0, 2, 1, 3)).astype(mstype.float32)
+
+ if self.gating:
+ gate_values = P.Reshape()(
+ self.matmul(q_data.astype(mstype.float16), linear_gating_weight.astype(mstype.float16)),
+ (dim_b, dim_q, dim_h, -1)) + gating_bias[None, None, ...].astype(mstype.float16)
+ gate_values = self.sigmoid(gate_values.astype(mstype.float32))
+ weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1))
+
+ weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1))
+ output = P.Reshape()(
+ self.matmul(weighted_avg.astype(mstype.float16), linear_output_weight.astype(mstype.float16)),
+ (dim_b, dim_q, -1)) + o_bias[None, ...].astype(mstype.float16)
+ return output
+
+
+class MSARowAttentionWithPairBias(nn.Cell):
+ '''MSA row attention'''
+ def __init__(self, config, msa_act_dim, pair_act_dim, batch_size=None, slice_num=0):
+ super(MSARowAttentionWithPairBias, self).__init__()
+ self.config = config
+ self.num_head = self.config.num_head
+ self.batch_size = batch_size
+ self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.matmul = P.MatMul(transpose_b=True)
+ self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
+ self.msa_act_dim = msa_act_dim
+ self.pair_act_dim = pair_act_dim
+ self.slice_num = slice_num
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32))
+ self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32))
+ self.feat_2d_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32))
+ self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32))
+ self.feat_2d_weights = Parameter(
+ Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32))
+
+ def construct(self, msa_act, msa_mask, pair_act, index):
+ '''construct'''
+ query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+ query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
+ feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0)
+ feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0)
+ feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
+
+ q, k, _ = pair_act.shape
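+ # Convert the 0/1 MSA mask into an additive attention bias: masked positions
+ # receive -1e9 before the softmax.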
+ bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+ msa_act, _, _ = self.norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32),
+ query_norm_beta.astype(mstype.float32))
+ pair_act, _, _ = self.norm(pair_act.astype(mstype.float32), feat_2d_norm_gamma.astype(mstype.float32),
+ feat_2d_norm_beta.astype(mstype.float32))
+ pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1]))
+ nonbatched_bias = P.Transpose()(
+ P.Reshape()(self.matmul(pair_act.astype(mstype.float16), feat_2d_weight.astype(mstype.float16)),
+ (q, k, self.num_head)), (2, 0, 1))
+
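+ # When slice_num is set, the leading dimension is processed in chunks to
+ # bound peak memory; F.depend chains each chunk on the previous one so the
+ # graph executes them sequentially instead of materialising all attention
+ # maps at once.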
+ if self.slice_num:
+ msa_act_ori_shape = P.Shape()(msa_act)
+ slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
+ msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
+ bias_shape = P.Shape()(bias)
+ bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ msa_act_tuple = ()
+
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ while slice_idx < self.slice_num:
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ msa_act = P.Concat()(msa_act_tuple)
+ msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
+ return msa_act
+
+ msa_act = self.attn_mod(msa_act, msa_act, bias, index, nonbatched_bias)
+ return msa_act
+
+
+class MSAColumnAttention(nn.Cell):
+ '''MSA column attention'''
+ def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0):
+ super(MSAColumnAttention, self).__init__()
+ self.config = config
+ self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
+ self.batch_size = batch_size
+ self.slice_num = slice_num
+ self.msa_act_dim = msa_act_dim
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+ self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32))
+
+ def construct(self, msa_act, msa_mask, index):
+ '''construct'''
+ query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+ query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
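+ # Column attention: transpose the MSA so attention runs across sequences at
+ # each residue position.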
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ msa_mask = mnp.swapaxes(msa_mask, -1, -2)
+ bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+ msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32),
+ query_norm_beta.astype(mstype.float32))
+ if self.slice_num:
+ msa_act_ori_shape = P.Shape()(msa_act)
+ slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
+ msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
+ bias_shape = P.Shape()(bias)
+ bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
+
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ msa_act_tuple = ()
+
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ while slice_idx < self.slice_num:
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ msa_act = P.Concat()(msa_act_tuple)
+ msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ return msa_act
+
+ msa_act = self.attn_mod(msa_act, msa_act, bias, index)
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ return msa_act
+
+
+class GlobalAttention(nn.Cell):
+ '''global attention'''
+ def __init__(self, config, key_dim, value_dim, output_dim, batch_size=None):
+ super(GlobalAttention, self).__init__()
+ self.config = config
+ self.key_dim = key_dim
+ self.ori_key_dim = key_dim
+ self.value_dim = value_dim
+ self.ori_value_dim = value_dim
+ self.num_head = self.config.num_head
+ self.key_dim = self.key_dim // self.num_head
+ self.value_dim = self.value_dim // self.num_head
+ self.output_dim = output_dim
+ self.matmul_trans_b = P.MatMul(transpose_b=True)
+ self.batch_matmul = P.BatchMatMul()
+ self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+ self.matmul = P.MatMul()
+ self.softmax = nn.Softmax()
+ self.sigmoid = nn.Sigmoid()
+ self.gating = self.config.gating
+
+ self.batch_size = batch_size
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.linear_q_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.ori_key_dim, self.num_head, self.key_dim)), mstype.float32))
+ self.linear_k_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.key_dim)), mstype.float32))
+ self.linear_v_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.value_dim)), mstype.float32))
+ self.linear_output_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.output_dim, self.num_head * self.value_dim)), mstype.float32))
+ self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), mstype.float32))
+ if self.gating:
+ self.linear_gating_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.num_head * self.value_dim, self.ori_key_dim)), mstype.float32))
+ self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.ori_key_dim)), mstype.float32))
+
+ def construct(self, q_data, m_data, q_mask, bias, index):
+ '''construct'''
+ q_weights = P.Gather()(self.linear_q_weights, index, 0)
+ k_weights = P.Gather()(self.linear_k_weights, index, 0)
+ v_weights = P.Gather()(self.linear_v_weights, index, 0)
+ output_weights = P.Gather()(self.linear_output_weights, index, 0)
+ output_bias = P.Gather()(self.o_biases, index, 0)
+ gating_weights = 0
+ gating_bias = 0
+ if self.gating:
+ gating_weights = P.Gather()(self.linear_gating_weights, index, 0)
+ gating_bias = P.Gather()(self.gating_biases, index, 0)
+ b, _, _ = m_data.shape
+ v_weights = v_weights[None, ...]
+ v_weights = mnp.broadcast_to(v_weights, (b, self.value_dim * self.num_head, self.value_dim))
+ v = self.batch_matmul(m_data.astype(mstype.float16), v_weights.astype(mstype.float16))
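+ # Global attention: the query is the mask-weighted mean over the sequence
+ # axis, giving a single query per head and linear rather than quadratic cost.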
+ q_avg = mask_mean(q_mask, q_data, axis=1)
+ q_weights = P.Reshape()(q_weights, (-1, self.num_head * self.key_dim))
+ q = P.Reshape()(self.matmul(q_avg.astype(mstype.float16), q_weights.astype(mstype.float16)),
+ (-1, self.num_head, self.key_dim)) * (self.key_dim ** (-0.5))
+
+ k_weights = k_weights[None, ...]
+ k_weights = mnp.broadcast_to(k_weights, (b, self.value_dim * self.num_head, self.key_dim))
+
+ k = self.batch_matmul(m_data.astype(mstype.float16), k_weights.astype(mstype.float16))
+
+ bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
+
+ logits = self.batch_matmul_trans_b(q.astype(mstype.float16), k.astype(mstype.float16)) + bias.astype(
+ mstype.float16)
+ weights = self.softmax(logits.astype(mstype.float32))
+ weighted_avg = self.batch_matmul(weights.astype(mstype.float16), v.astype(mstype.float16))
+
+ if self.gating:
+ # gate_values = self.linear_gating(q_data).astype(mstype.float32)
+ q_data_shape = P.Shape()(q_data)
+ if len(q_data_shape) != 2:
+ q_data = P.Reshape()(q_data, (-1, q_data_shape[-1]))
+ out_shape = q_data_shape[:-1] + (-1,)
+ gate_values = P.Reshape()(self.matmul_trans_b(q_data.astype(mstype.float16),
+ gating_weights.astype(mstype.float16)) +
+ gating_bias.astype(mstype.float16), out_shape)
+
+ gate_values = P.Reshape()(self.sigmoid(gate_values.astype(mstype.float32)),
+ (b, -1, self.num_head, self.value_dim))
+ weighted_avg = P.Reshape()(weighted_avg[:, None] * gate_values, (-1, self.num_head * self.value_dim))
+ weighted_avg_shape = P.Shape()(weighted_avg)
+ if len(weighted_avg_shape) != 2:
+ weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
+ output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16),
+ output_weights.astype(mstype.float16))
+ + output_bias.astype(mstype.float16), (b, -1, self.output_dim))
+
+ else:
+ weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.value_dim))
+ # output = self.linear_gating(weighted_avg)
+ weighted_avg_shape = P.Shape()(weighted_avg)
+ if len(weighted_avg_shape) != 2:
+ weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1]))
+ out_shape = weighted_avg_shape[:-1] + (-1,)
+ output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16),
+ output_weights.astype(mstype.float16)) +
+ output_bias.astype(mstype.float16), out_shape)
+ output = output[:, None]
+ return output
+
+
+class MSAColumnGlobalAttention(nn.Cell):
+ '''MSA column global attention'''
+ def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0):
+ super(MSAColumnGlobalAttention, self).__init__()
+ self.config = config
+ self.attn_mod = GlobalAttention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size)
+ self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.batch_size = batch_size
+ self.slice_num = slice_num
+ self.msa_act_dim = msa_act_dim
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))
+ self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32))
+
+ def construct(self, msa_act, msa_mask, index):
+ '''construct'''
+ query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+ query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ msa_mask = mnp.swapaxes(msa_mask, -1, -2)
+ bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
+ msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32),
+ query_norm_gamma.astype(mstype.float32),
+ query_norm_beta.astype(mstype.float32))
+ msa_mask = mnp.expand_dims(msa_mask, axis=-1)
+
+ if self.slice_num:
+ msa_act_ori_shape = P.Shape()(msa_act)
+ slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:]
+ msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16)
+ bias_shape = P.Shape()(bias)
+ bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
+ msa_mask_shape = P.Shape()(msa_mask)
+ msa_mask = P.Reshape()(msa_mask, slice_shape[:2] + msa_mask_shape[1:])
+
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ msa_act_tuple = ()
+
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0)
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ while slice_idx < self.slice_num:
+ msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0)
+ msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1])
+ msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0)
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+
+ msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index)
+ msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice)))
+ msa_act_tuple = msa_act_tuple + (msa_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ msa_act = P.Concat()(msa_act_tuple)
+ msa_act = P.Reshape()(msa_act, msa_act_ori_shape)
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ return msa_act
+
+ msa_act = self.attn_mod(msa_act, msa_act, msa_mask, bias, index)
+ msa_act = mnp.swapaxes(msa_act, -2, -3)
+ return msa_act
+
+
+class Transition(nn.Cell):
+ '''transition'''
+ def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0):
+ super(Transition, self).__init__()
+ self.config = config
+ self.input_layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.matmul = P.MatMul(transpose_b=True)
+ self.layer_norm_dim = layer_norm_dim
+ self.num_intermediate = int(layer_norm_dim * self.config.num_intermediate_factor)
+ self.batch_size = batch_size
+ self.slice_num = slice_num
+ self.relu = nn.ReLU()
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.input_layer_norm_gammas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.input_layer_norm_betas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.transition1_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.num_intermediate, self.layer_norm_dim)), mstype.float32))
+ self.transition1_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32))
+ self.transition2_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.num_intermediate)), mstype.float32))
+ self.transition2_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+
+ def construct(self, act, index):
+ '''construct'''
+ input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0)
+ input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0)
+ transition1_weight = P.Gather()(self.transition1_weights, index, 0)
+ transition1_bias = P.Gather()(self.transition1_biases, index, 0)
+ transition2_weight = P.Gather()(self.transition2_weights, index, 0)
+ transition2_bias = P.Gather()(self.transition2_biases, index, 0)
+ act, _, _ = self.input_layer_norm(act.astype(mstype.float32), input_layer_norm_gamma.astype(mstype.float32),
+ input_layer_norm_beta.astype(mstype.float32))
+ if self.slice_num:
+ act_ori_shape = P.Shape()(act)
+ slice_shape = (self.slice_num, -1) + act_ori_shape[1:]
+ act = P.Reshape()(act, slice_shape).astype(mstype.float16)
+
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ act_tuple = ()
+
+ act_slice = P.Gather()(act, slice_idx_tensor, 0)
+ act_shape = P.Shape()(act_slice)
+ if len(act_shape) != 2:
+ act_slice = P.Reshape()(act_slice, (-1, act_shape[-1]))
+ act_slice = self.relu(
+ P.BiasAdd()(self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
+ transition1_bias.astype(mstype.float16)).astype(mstype.float32))
+ act_slice = P.BiasAdd()(
+ self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
+ transition2_bias.astype(mstype.float16))
+ act_slice = P.Reshape()(act_slice, act_shape)
+ act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
+ act_tuple = act_tuple + (act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ while slice_idx < self.slice_num:
+ act_slice = P.Gather()(act, slice_idx_tensor, 0)
+ act_slice = F.depend(act_slice, act_tuple[-1])
+ act_shape = P.Shape()(act_slice)
+ if len(act_shape) != 2:
+ act_slice = P.Reshape()(act_slice, (-1, act_shape[-1]))
+ act_slice = self.relu(P.BiasAdd()(
+ self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
+ transition1_bias.astype(mstype.float16)).astype(mstype.float32))
+ act_slice = P.BiasAdd()(
+ self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
+ transition2_bias.astype(mstype.float16))
+ act_slice = P.Reshape()(act_slice, act_shape)
+ act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
+ act_tuple = act_tuple + (act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ act = P.Concat()(act_tuple)
+ act = P.Reshape()(act, act_ori_shape)
+ return act
+
+ act_shape = P.Shape()(act)
+ if len(act_shape) != 2:
+ act = P.Reshape()(act, (-1, act_shape[-1]))
+ act = self.relu(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition1_weight.astype(mstype.float16)),
+ transition1_bias.astype(mstype.float16)).astype(mstype.float32))
+ act = P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition2_weight.astype(mstype.float16)),
+ transition2_bias.astype(mstype.float16))
+ act = P.Reshape()(act, act_shape)
+ return act
+
+
+class OuterProductMean(nn.Cell):
+ '''outer product mean'''
+ def __init__(self, config, act_dim, num_output_channel, batch_size=None, slice_num=0):
+ super(OuterProductMean, self).__init__()
+ self.num_output_channel = num_output_channel
+ self.config = config
+ self.layer_norm_input = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.matmul_trans_b = P.MatMul(transpose_b=True)
+ self.matmul = P.MatMul()
+ self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+ self.act_dim = act_dim
+ self.batch_size = batch_size
+ self.slice_num = slice_num
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.layer_norm_input_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32))
+ self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32))
+ self.left_projection_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32))
+ self.left_projection_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32))
+ self.right_projection_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32))
+ self.right_projection_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32))
+ self.linear_output_weights = Parameter(Tensor(np.zeros(
+ (self.batch_size, self.num_output_channel, self.config.num_outer_channel *
+ self.config.num_outer_channel)), mstype.float32))
+ self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_output_channel)), mstype.float32))
+
+ def construct(self, act, mask, index):
+ '''construct'''
+ layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0)
+ layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0)
+ left_projection_weight = P.Gather()(self.left_projection_weights, index, 0)
+ left_projection_bias = P.Gather()(self.left_projection_biases, index, 0)
+ right_projection_weight = P.Gather()(self.right_projection_weights, index, 0)
+ right_projection_bias = P.Gather()(self.right_projection_biases, index, 0)
+ linear_output_weight = P.Gather()(self.linear_output_weights, index, 0)
+ linear_output_bias = P.Gather()(self.o_biases, index, 0)
+ mask = mask[..., None]
+ act, _, _ = self.layer_norm_input(act.astype(mstype.float32),
+ layer_norm_input_gamma.astype(mstype.float32),
+ layer_norm_input_beta.astype(mstype.float32))
+ act_shape = P.Shape()(act)
+ if len(act_shape) != 2:
+ act = P.Reshape()(act, (-1, act_shape[-1]))
+ out_shape = act_shape[:-1] + (-1,)
+ left_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
+ left_projection_weight.astype(mstype.float16)),
+ left_projection_bias.astype(mstype.float16)), out_shape)
+ right_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
+ right_projection_weight.astype(mstype.float16)),
+ right_projection_bias.astype(mstype.float16)), out_shape)
+ _, d, e = right_act.shape
+ if self.slice_num:
+ left_act_shape = P.Shape()(left_act)
+ slice_shape = (left_act_shape[0],) + (self.slice_num, -1) + (left_act_shape[-1],)
+ left_act = P.Reshape()(left_act, slice_shape)
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ act_tuple = ()
+ left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1)
+ a, b, c = left_act_slice.shape
+ left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a))
+ right_act = P.Reshape()(right_act, (a, -1))
+ act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16),
+ right_act.astype(mstype.float16)),
+ (c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
+ act_slice_shape = P.Shape()(act_slice)
+ if len(act_slice_shape) != 2:
+ act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1]))
+ act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16),
+ linear_output_weight.astype(mstype.float16)),
+ linear_output_bias.astype(mstype.float16)), (d, b, -1))
+ act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2])
+ act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
+ act_tuple = act_tuple + (act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+ while slice_idx < self.slice_num:
+ left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1)
+ left_act_slice = F.depend(left_act_slice, act_tuple[-1])
+ a, b, c = left_act_slice.shape
+ left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a))
+ right_act = P.Reshape()(right_act, (a, -1))
+ act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16),
+ right_act.astype(mstype.float16)),
+ (c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
+ act_slice_shape = P.Shape()(act_slice)
+ if len(act_slice_shape) != 2:
+ act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1]))
+ act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16),
+ linear_output_weight.astype(mstype.float16)),
+ linear_output_bias.astype(mstype.float16)), (d, b, -1))
+ act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2])
+ act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice)))
+ act_tuple = act_tuple + (act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+ act = P.Concat()(act_tuple)
+ act_shape = P.Shape()(act)
+ act = P.Reshape()(act, (-1, act_shape[-2], act_shape[-1]))
+ epsilon = 1e-3
+ tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0))
+ norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16),
+ tmp_mask.astype(mstype.float16)),
+ (1, 2, 0)).astype(mstype.float32)
+ act /= epsilon + norm
+ return act
+
+ a, b, c = left_act.shape
+ left_act = P.Reshape()(mnp.transpose(left_act.astype(mstype.float16), [2, 1, 0]), (-1, a))
+ right_act = P.Reshape()(right_act, (a, -1))
+ act = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act.astype(mstype.float16),
+ right_act.astype(mstype.float16)),
+ (c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e))
+ act_shape = P.Shape()(act)
+ if len(act_shape) != 2:
+ act = P.Reshape()(act, (-1, act_shape[-1]))
+ act = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16),
+ linear_output_weight.astype(mstype.float16)),
+ linear_output_bias.astype(mstype.float16)), (d, b, -1))
+ act = mnp.transpose(act.astype(mstype.float16), [1, 0, 2])
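+ # Normalise the accumulated outer products by the number of unmasked
+ # sequence pairs (the outer product of the mask), turning the sum into a mean.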
+ epsilon = 1e-3
+ tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0))
+ norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16),
+ tmp_mask.astype(mstype.float16)),
+ (1, 2, 0)).astype(mstype.float32)
+ act /= epsilon + norm
+ return act
+
+
+class TriangleMultiplication(nn.Cell):
+ '''triangle multiplication'''
+ def __init__(self, config, layer_norm_dim, batch_size):
+ super(TriangleMultiplication, self).__init__()
+ self.config = config
+ self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.matmul = P.MatMul(transpose_b=True)
+ self.sigmoid = nn.Sigmoid()
+ self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+ equation = ["ikc,jkc->ijc", "kjc,kic->ijc"]
+ if self.config.equation not in equation:
+ print("TriangleMultiplication Not Suppl")
+ if self.config.equation == "ikc,jkc->ijc":
+ self.equation = True
+ elif self.config.equation == "kjc,kic->ijc":
+ self.equation = False
+ else:
+ self.equation = None
+ self.batch_size = batch_size
+ self.layer_norm_dim = layer_norm_dim
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.layer_norm_input_gammas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.layer_norm_input_betas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.left_projection_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
+ mstype.float32))
+ self.left_projection_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
+ self.right_projection_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
+ mstype.float32))
+ self.right_projection_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
+ self.left_gate_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
+ mstype.float32))
+ self.left_gate_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
+ self.right_gate_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)),
+ mstype.float32))
+ self.right_gate_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32))
+ self.center_layer_norm_gammas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.center_layer_norm_betas = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.output_projection_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32))
+ self.output_projection_biases = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.gating_linear_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32))
+ self.gating_linear_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+
+ def construct(self, act, mask, index):
+ '''construct'''
+ layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0)
+ layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0)
+ left_projection_weight = P.Gather()(self.left_projection_weights, index, 0)
+ left_projection_bias = P.Gather()(self.left_projection_biases, index, 0)
+ right_projection_weight = P.Gather()(self.right_projection_weights, index, 0)
+ right_projection_bias = P.Gather()(self.right_projection_biases, index, 0)
+ left_gate_weight = P.Gather()(self.left_gate_weights, index, 0)
+ left_gate_bias = P.Gather()(self.left_gate_biases, index, 0)
+ right_gate_weight = P.Gather()(self.right_gate_weights, index, 0)
+ right_gate_bias = P.Gather()(self.right_gate_biases, index, 0)
+ center_layer_norm_gamma = P.Gather()(self.center_layer_norm_gammas, index, 0)
+ center_layer_norm_beta = P.Gather()(self.center_layer_norm_betas, index, 0)
+ output_projection_weight = P.Gather()(self.output_projection_weights, index, 0)
+ output_projection_bias = P.Gather()(self.output_projection_biases, index, 0)
+ gating_linear_weight = P.Gather()(self.gating_linear_weights, index, 0)
+ gating_linear_bias = P.Gather()(self.gating_linear_biases, index, 0)
+
+ mask = mask[..., None]
+ act, _, _ = self.layer_norm(act.astype(mstype.float32),
+ layer_norm_input_gamma.astype(mstype.float32),
+ layer_norm_input_beta.astype(mstype.float32))
+ input_act = act
+ act_shape = P.Shape()(act)
+ if len(act_shape) != 2:
+ act = P.Reshape()(act, (-1, act_shape[-1]))
+ out_shape = act_shape[:-1] + (-1,)
+ left_projection = mask * P.Reshape()(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_projection_weight.astype(mstype.float16)),
+ left_projection_bias.astype(mstype.float16)), out_shape)
+ left_projection = left_projection.astype(mstype.float16)
+ act = F.depend(act, left_projection)
+
+ left_gate_values = self.sigmoid(P.Reshape()(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_gate_weight.astype(mstype.float16)),
+ left_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
+ left_proj_act = left_projection * left_gate_values
+ act = F.depend(act, left_proj_act)
+
+ right_projection = mask * P.Reshape()(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_projection_weight.astype(mstype.float16)),
+ right_projection_bias.astype(mstype.float16)), out_shape)
+ right_projection = right_projection.astype(mstype.float16)
+ act = F.depend(act, right_projection)
+
+ right_gate_values = self.sigmoid(P.Reshape()(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_gate_weight.astype(mstype.float16)),
+ right_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
+ right_proj_act = right_projection * right_gate_values
+ left_proj_act = F.depend(left_proj_act, right_proj_act)
+
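+ # "ikc,jkc->ijc" is the outgoing-edge update and "kjc,kic->ijc" the incoming
+ # one; both einsums are realised as a batched matmul over transposed
+ # per-channel activations.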
+ if self.equation is not None:
+ if self.equation:
+ left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 0, 1))
+ right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 0, 1))
+ act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp)
+ act = P.Transpose()(act, (1, 2, 0)).astype(mstype.float32)
+ else:
+ left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 1, 0))
+ right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 1, 0))
+ act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp)
+ act = P.Transpose()(act, (2, 1, 0)).astype(mstype.float32)
+ act, _, _ = self.layer_norm(act.astype(mstype.float32),
+ center_layer_norm_gamma.astype(mstype.float32),
+ center_layer_norm_beta.astype(mstype.float32))
+ act_shape = P.Shape()(act)
+ if len(act_shape) != 2:
+ act = P.Reshape()(act, (-1, act_shape[-1]))
+ out_shape = act_shape[:-1] + (-1,)
+ act = P.Reshape()(
+ P.BiasAdd()(self.matmul(act.astype(mstype.float16), output_projection_weight.astype(mstype.float16)),
+ output_projection_bias.astype(mstype.float16)), out_shape)
+ input_act_shape = P.Shape()(input_act)
+ if len(input_act_shape) != 2:
+ input_act = P.Reshape()(input_act, (-1, input_act_shape[-1]))
+ out_shape = input_act_shape[:-1] + (-1,)
+ gate_values = self.sigmoid(P.Reshape()(
+ P.BiasAdd()(self.matmul(input_act.astype(mstype.float16), gating_linear_weight.astype(mstype.float16)),
+ gating_linear_bias.astype(mstype.float16)), out_shape).astype(mstype.float32))
+ act = act * gate_values
+ return act
+
+
+class TriangleAttention(nn.Cell):
+ '''triangle attention'''
+ def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0):
+ super(TriangleAttention, self).__init__()
+ self.config = config
+ self.orientation_is_per_column = (self.config.orientation == 'per_column')
+ self.init_factor = Tensor(1. / np.sqrt(layer_norm_dim), mstype.float32)
+ self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5)
+ self.matmul = P.MatMul(transpose_b=True)
+ self.attn_mod = Attention(self.config, layer_norm_dim, layer_norm_dim, layer_norm_dim, batch_size)
+ self.batch_size = batch_size
+ self.slice_num = slice_num
+ self.layer_norm_dim = layer_norm_dim
+ self.idx = Tensor(0, mstype.int32)
+ self._init_parameter()
+
+ def _init_parameter(self):
+ '''init parameter'''
+ self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32))
+ self.feat_2d_weights = Parameter(
+ Tensor(np.zeros((self.batch_size, self.config.num_head, self.layer_norm_dim)), mstype.float32))
+
+ def construct(self, pair_act, pair_mask, index):
+ '''construct'''
+ query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0)
+ query_norm_beta = P.Gather()(self.query_norm_betas, index, 0)
+ feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0)
+ if self.orientation_is_per_column:
+ pair_act = mnp.swapaxes(pair_act, -2, -3)
+ pair_mask = mnp.swapaxes(pair_mask, -1, -2)
+ bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
+ pair_act, _, _ = self.query_norm(pair_act.astype(mstype.float32),
+ query_norm_gamma.astype(mstype.float32),
+ query_norm_beta.astype(mstype.float32))
+ q, k, _ = pair_act.shape
+ nonbatched_bias = self.matmul(P.Reshape()(pair_act.astype(mstype.float16), (-1, pair_act.shape[-1])),
+ feat_2d_weight.astype(mstype.float16))
+ nonbatched_bias = P.Transpose()(P.Reshape()(nonbatched_bias, (q, k, -1)), (2, 0, 1))
+ if self.slice_num:
+ pair_act_ori_shape = P.Shape()(pair_act)
+ slice_shape = (self.slice_num, -1) + pair_act_ori_shape[1:]
+ pair_act = P.Reshape()(pair_act, slice_shape).astype(mstype.float16)
+ bias_shape = P.Shape()(bias)
+ bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:])
+
+ slice_idx = 0
+ slice_idx_tensor = self.idx
+ pair_act_tuple = ()
+
+ pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0)
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias)
+ pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice)))
+ pair_act_tuple = pair_act_tuple + (pair_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+
+ while slice_idx < self.slice_num:
+ pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0)
+ pair_act_slice = F.depend(pair_act_slice, pair_act_tuple[-1])
+ bias_slice = P.Gather()(bias, slice_idx_tensor, 0)
+ pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias)
+ pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice)))
+ pair_act_tuple = pair_act_tuple + (pair_act_slice,)
+ slice_idx += 1
+ slice_idx_tensor += 1
+ pair_act = P.Concat()(pair_act_tuple)
+ pair_act = P.Reshape()(pair_act, pair_act_ori_shape)
+
+ if self.orientation_is_per_column:
+ pair_act = mnp.swapaxes(pair_act, -2, -3)
+ return pair_act
+
+ pair_act = self.attn_mod(pair_act, pair_act, bias, index, nonbatched_bias)
+ if self.orientation_is_per_column:
+ pair_act = mnp.swapaxes(pair_act, -2, -3)
+ return pair_act
diff --git a/reproduce/AlphaFold2-Chinese/module/evoformer_module.py b/reproduce/AlphaFold2-Chinese/module/evoformer_module.py
new file mode 100644
index 0000000..992fc52
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/module/evoformer_module.py
@@ -0,0 +1,304 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Evoformer module"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.tensor import Tensor
+from mindspore import Parameter
+
+from commons import residue_constants
+from commons.utils import dgram_from_positions, batch_make_transform_from_reference, batch_quat_affine, \
+ batch_invert_point, batch_rot_to_quat
+from module.basic_module import Attention, MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention, \
+ Transition, OuterProductMean, TriangleMultiplication, TriangleAttention
+
+
+class EvoformerIteration(nn.Cell):
+ '''evoformer iteration'''
+ def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size, global_config):
+ super(EvoformerIteration, self).__init__()
+ self.config = config
+ self.is_extra_msa = is_extra_msa
+ if not self.is_extra_msa:
+ self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias(
+ self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size,
+ global_config.evoformer_iteration.msa_row_attention_with_pair_bias.slice_num)
+ self.attn_mod = MSAColumnAttention(self.config.msa_column_attention, msa_act_dim, batch_size,
+ global_config.evoformer_iteration.msa_column_attention.slice_num)
+ self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size,
+ global_config.evoformer_iteration.msa_transition.slice_num)
+ self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim,
+ batch_size,
+ global_config.evoformer_iteration.outer_product_mean.slice_num)
+ self.triangle_attention_starting_node = \
+ TriangleAttention(self.config.triangle_attention_starting_node,
+ pair_act_dim, batch_size,
+ global_config.evoformer_iteration.triangle_attention_starting_node.slice_num)
+ self.triangle_attention_ending_node = \
+ TriangleAttention(self.config.triangle_attention_ending_node,
+ pair_act_dim, batch_size,
+ global_config.evoformer_iteration.triangle_attention_ending_node.slice_num)
+ self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size,
+ global_config.evoformer_iteration.pair_transition.slice_num)
+ else:
+ self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias(
+ self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size,
+ global_config.extra_msa_stack.msa_row_attention_with_pair_bias.slice_num)
+ self.attn_mod = \
+ MSAColumnGlobalAttention(self.config.msa_column_attention, msa_act_dim, batch_size,
+ global_config.extra_msa_stack.msa_column_global_attention.slice_num)
+ self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size,
+ global_config.extra_msa_stack.msa_transition.slice_num)
+ self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim,
+ batch_size,
+ global_config.extra_msa_stack.outer_product_mean.slice_num)
+ self.triangle_attention_starting_node = \
+ TriangleAttention(self.config.triangle_attention_starting_node,
+ pair_act_dim, batch_size,
+ global_config.extra_msa_stack.triangle_attention_starting_node.slice_num)
+ self.triangle_attention_ending_node = \
+ TriangleAttention(self.config.triangle_attention_ending_node,
+ pair_act_dim, batch_size,
+ global_config.extra_msa_stack.triangle_attention_ending_node.slice_num)
+ self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size,
+ global_config.extra_msa_stack.pair_transition.slice_num)
+
+ self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing,
+ pair_act_dim, batch_size)
+ self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming,
+ pair_act_dim, batch_size)
+
+ def construct(self, msa_act, pair_act, msa_mask, pair_mask, index):
+ '''construct'''
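+ # One Evoformer block: pair-biased MSA row attention, MSA column attention,
+ # MSA transition, outer-product mean into the pair stack, triangle
+ # multiplications and attentions, then the pair transition, each applied as
+ # a residual update.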
+ msa_act = msa_act + self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index)
+ msa_act = msa_act + self.attn_mod(msa_act, msa_mask, index)
+ msa_act = msa_act + self.msa_transition(msa_act, index)
+ pair_act = pair_act + self.outer_product_mean(msa_act, msa_mask, index)
+ pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index)
+ pair_act = pair_act + self.pair_transition(pair_act, index)
+ return msa_act.astype(mstype.float16), pair_act.astype(mstype.float16)
+
+
+class TemplatePairStack(nn.Cell):
+ '''template pair stack'''
+ def __init__(self, config, global_config=None):
+ super(TemplatePairStack, self).__init__()
+ self.config = config
+ self.global_config = global_config
+ self.num_block = self.config.num_block
+ # self.seq_length = global_config.step_size
+
+ self.triangle_attention_starting_node = \
+ TriangleAttention(self.config.triangle_attention_starting_node,
+ 64, self.num_block,
+ global_config.template_pair_stack.triangle_attention_starting_node.slice_num)
+ self.triangle_attention_ending_node = \
+ TriangleAttention(self.config.triangle_attention_ending_node,
+ 64, self.num_block,
+ global_config.template_pair_stack.triangle_attention_ending_node.slice_num)
+ # Hard-coded template pair channel dimension (64).
+ self.pair_transition = Transition(self.config.pair_transition, 64, self.num_block,
+ global_config.template_pair_stack.pair_transition.slice_num)
+ self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing,
+ layer_norm_dim=64, batch_size=self.num_block)
+ self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming,
+ layer_norm_dim=64, batch_size=self.num_block)
+
+ def construct(self, pair_act, pair_mask, index):
+ if not self.num_block:
+ return pair_act
+
+ pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index)
+ pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index)
+ pair_act = pair_act + self.pair_transition(pair_act, index)
+ return pair_act.astype(mstype.float16)
+
+
+class SingleTemplateEmbedding(nn.Cell):
+ '''single template embedding'''
+ def __init__(self, config, global_config=None):
+ super(SingleTemplateEmbedding, self).__init__()
+ self.config = config
+ self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim)
+ self.embedding2d = nn.Dense(88, self.num_channels).to_float(mstype.float16)
+ self.template_pair_stack = TemplatePairStack(self.config.template_pair_stack, global_config)
+ self.num_bins = self.config.dgram_features.num_bins
+ self.min_bin = self.config.dgram_features.min_bin
+ self.max_bin = self.config.dgram_features.max_bin
+
+ self.one_hot = nn.OneHot(depth=22, axis=-1)
+ self.n, self.ca, self.c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]
+
+ self.use_template_unit_vector = self.config.use_template_unit_vector
+ layer_norm_dim = 64
+ self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5)
+
+ self.idx_num_block = Parameter(Tensor(0, mstype.int32), requires_grad=False)
+ self.idx_batch_loop = Parameter(Tensor(0, mstype.int32), requires_grad=False)
+ # self.num_block = Tensor(self.template_pair_stack.num_block, mstype.int32)
+ # self.batch_block = Tensor(4, mstype.int32)
+ self.num_block = self.template_pair_stack.num_block
+ self.batch_block = 4
+ self._act = Parameter(
+ Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 64]).astype(np.float16)),
+ requires_grad=False)
+
+ def construct(self, query_embedding, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions,
+ template_pseudo_beta_mask, template_pseudo_beta):
+ '''construct'''
+ num_res = template_aatype[0, ...].shape[0]
+ template_mask_2d_temp = P.Cast()(template_pseudo_beta_mask[:, :, None] * template_pseudo_beta_mask[:, None, :],
+ query_embedding.dtype)
+ template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin, self.max_bin)
+ template_dgram_temp = P.Cast()(template_dgram_temp, query_embedding.dtype)
+
+ to_concat_temp = (template_dgram_temp, template_mask_2d_temp[:, :, :, None])
+ aatype_temp = self.one_hot(template_aatype)
+ to_concat_temp = to_concat_temp + (mnp.tile(aatype_temp[:, None, :, :], (1, num_res, 1, 1)),
+ mnp.tile(aatype_temp[:, :, None, :], (1, 1, num_res, 1)))
+ rot_temp, trans_temp = batch_make_transform_from_reference(template_all_atom_positions[:, :, self.n],
+ template_all_atom_positions[:, :, self.ca],
+ template_all_atom_positions[:, :, self.c])
+
+ _, rotation_tmp, translation_tmp = batch_quat_affine(
+ batch_rot_to_quat(rot_temp, unstack_inputs=True), translation=trans_temp, rotation=rot_temp,
+ unstack_inputs=True)
+ points_tmp = mnp.expand_dims(translation_tmp, axis=-2)
+ affine_vec_tmp = batch_invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1)
+ inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + mnp.sum(mnp.square(affine_vec_tmp), axis=1))
+ template_mask_tmp = (template_all_atom_masks[:, :, self.n] *
+ template_all_atom_masks[:, :, self.ca] *
+ template_all_atom_masks[:, :, self.c])
+ template_mask_2d_tmp = template_mask_tmp[:, :, None] * template_mask_tmp[:, None, :]
+
+ inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp.astype(inv_distance_scalar_tmp.dtype)
+ unit_vector_tmp = P.Transpose()((affine_vec_tmp * inv_distance_scalar_tmp[:, None, ...]), (0, 2, 3, 1))
+ template_mask_2d_tmp = P.Cast()(template_mask_2d_tmp, query_embedding.dtype)
+ if not self.use_template_unit_vector:
+ unit_vector_tmp = mnp.zeros_like(unit_vector_tmp)
+ to_concat_temp = to_concat_temp + (unit_vector_tmp, template_mask_2d_tmp[..., None],)
+ act_tmp = mnp.concatenate(to_concat_temp, axis=-1)
+ act_tmp = act_tmp * template_mask_2d_tmp[..., None]
+ act_tmp = self.embedding2d(act_tmp)
+
+ idx_batch_loop = self.idx_batch_loop
+ output = []
+ idx_batch_loop_int = 0
+ self.idx_num_block = 0
+
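+ # Process the four stacked templates one by one; the Parameter counters act
+ # as graph-mode-friendly loop indices into the per-block weights.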
+ while idx_batch_loop_int < self.batch_block:
+ self.idx_num_block = 0
+ idx_num_block_int = 0
+ self._act = P.Gather()(act_tmp, idx_batch_loop, 0)
+ while idx_num_block_int < self.num_block:
+ self._act = self.template_pair_stack(self._act, mask_2d, self.idx_num_block)
+ self.idx_num_block += 1
+ idx_num_block_int += 1
+ temp_act = P.Reshape()(self._act, ((1,) + P.Shape()(self._act)))
+ output.append(temp_act)
+ idx_batch_loop += 1
+ idx_batch_loop_int += 1
+
+ act_tmp_loop = P.Concat()(output)
+ act_tmp = self.output_layer_norm(act_tmp_loop.astype(mstype.float32))
+ return act_tmp
+
+
+class TemplateEmbedding(nn.Cell):
+ '''template embedding'''
+ def __init__(self, config, slice_num, global_config=None):
+ super(TemplateEmbedding, self).__init__()
+ self.config = config
+ self.global_config = global_config
+ self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim)
+ self.template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
+ self.template_pointwise_attention = Attention(self.config.attention, q_data_dim=128, m_data_dim=64,
+ output_dim=128)
+ self.slice_num = slice_num
+ if slice_num == 0:
+ slice_num = 1
+ self._flat_query_slice = Parameter(
+ Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 1, 128)),
+ dtype=mstype.float16), requires_grad=False)
+ self._flat_templates_slice = Parameter(
+ Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 4, 64)),
+ dtype=mstype.float16), requires_grad=False)
+
+ def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions,
+ template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d):
+ '''construct'''
+ num_templates = template_mask.shape[0]
+ num_channels = self.num_channels
+ num_res = query_embedding.shape[0]
+ query_num_channels = query_embedding.shape[-1]
+ template_mask = P.Cast()(template_mask, query_embedding.dtype)
+
+ mask_2d = F.depend(mask_2d, query_embedding)
+ template_pair_representation = self.template_embedder(query_embedding, mask_2d, template_aatype,
+ template_all_atom_masks, template_all_atom_positions,
+ template_pseudo_beta_mask,
+ template_pseudo_beta)
+
+ template_pair_representation = template_pair_representation.astype(mstype.float32)
+ flat_query = mnp.reshape(query_embedding, [num_res * num_res, 1, query_num_channels])
+ flat_templates = mnp.reshape(
+ mnp.transpose(template_pair_representation.astype(mstype.float16), [1, 2, 0, 3]),
+ [num_res * num_res, num_templates, num_channels]).astype(mstype.float32)
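+ # Padded templates receive a -1e9 bias so they get ~zero attention weight.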
+ bias = (1e9 * (template_mask[None, None, None, :] - 1.))
+ flat_query, flat_templates, bias = flat_query.astype(mstype.float32), flat_templates.astype(
+ mstype.float32), bias.astype(mstype.float32)
+
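+ # When slice_num > 0, run the pointwise attention over chunks of the
+ # num_res * num_res axis to bound peak memory.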
+ if self.slice_num:
+ slice_shape = (self.slice_num, -1)
+ flat_query_shape = P.Shape()(flat_query)
+ flat_query = P.Reshape()(flat_query, slice_shape + flat_query_shape[1:]).astype(mstype.float16)
+ flat_templates_shape = P.Shape()(flat_templates)
+ flat_templates = P.Reshape()(flat_templates, slice_shape + flat_templates_shape[1:]).astype(mstype.float16)
+ slice_idx = 0
+ embedding_tuple = ()
+ while slice_idx < self.slice_num:
+ self._flat_query_slice = flat_query[slice_idx]
+ self._flat_templates_slice = flat_templates[slice_idx]
+ embedding_slice = self.template_pointwise_attention(self._flat_query_slice, self._flat_templates_slice,
+ bias, index=None, nonbatched_bias=None)
+ embedding_slice = P.Reshape()(embedding_slice, ((1,) + P.Shape()(embedding_slice)))
+ embedding_tuple = embedding_tuple + (embedding_slice,)
+ slice_idx += 1
+ embedding = P.Concat()(embedding_tuple)
+
+ embedding = embedding.astype(mstype.float32)
+ embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels])
+ # No gradients if no templates.
+ template_mask = template_mask.astype(embedding.dtype)
+ embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype)
+ return embedding
+
+ embedding = self.template_pointwise_attention(flat_query, flat_templates, bias, index=None,
+ nonbatched_bias=None)
+ embedding = embedding.astype(mstype.float32)
+ embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels])
+ # No gradients if no templates.
+ template_mask = template_mask.astype(embedding.dtype)
+ embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype)
+ return embedding
diff --git a/reproduce/AlphaFold2-Chinese/module/model.py b/reproduce/AlphaFold2-Chinese/module/model.py
new file mode 100644
index 0000000..03ae69b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/module/model.py
@@ -0,0 +1,235 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""AlphaFold Model"""
+
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+from mindspore.common.tensor import Tensor
+from mindspore import Parameter
+from mindspore.ops import functional as F
+
+from commons import residue_constants
+from commons.utils import get_chi_atom_indices, pseudo_beta_fn, dgram_from_positions, atom37_to_torsion_angles
+from module.evoformer_module import TemplateEmbedding, EvoformerIteration
+from module.structure_module import StructureModule, PredictedLDDTHead
+
+class AlphaFold(nn.Cell):
+ """AlphaFold Model"""
+ def __init__(self, config, global_config):
+ super(AlphaFold, self).__init__()
+ self.config = config.model.embeddings_and_evoformer
+ self.preprocess_1d = nn.Dense(22, self.config.msa_channel).to_float(mstype.float16)
+ self.preprocess_msa = nn.Dense(49, self.config.msa_channel).to_float(mstype.float16)
+ self.left_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16)
+ self.right_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16)
+ self.prev_pos_linear = nn.Dense(15, self.config.pair_channel).to_float(mstype.float16)
+ self.pair_activations = nn.Dense(65, self.config.pair_channel).to_float(mstype.float16)
+ self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5)
+ self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5)
+ self.one_hot = nn.OneHot(depth=self.config.max_relative_feature * 2 + 1, axis=-1)
+ self.extra_msa_activations = nn.Dense(25, self.config.extra_msa_channel).to_float(mstype.float16)
+ self.template_single_embedding = nn.Dense(57, self.config.msa_channel).to_float(mstype.float16)
+ self.template_projection = nn.Dense(self.config.msa_channel, self.config.msa_channel).to_float(mstype.float16)
+ self.single_activations = nn.Dense(self.config.msa_channel, self.config.seq_channel).to_float(mstype.float16)
+ self.relu = nn.ReLU()
+ self.recycle_pos = self.config.recycle_pos
+ self.recycle_features = self.config.recycle_features
+ self.template_enable = self.config.template.enabled
+ self.max_relative_feature = self.config.max_relative_feature
+ self.template_enabled = self.config.template.enabled
+ self.template_embed_torsion_angles = self.config.template.embed_torsion_angles
+ self.num_bins = self.config.prev_pos.num_bins
+ self.min_bin = self.config.prev_pos.min_bin
+ self.max_bin = self.config.prev_pos.max_bin
+ self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1)
+ self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1)
+ self.template_embedding = TemplateEmbedding(self.config.template,
+ global_config.template_embedding.slice_num,
+ global_config=global_config)
+ self.extra_msa_stack_iteration = EvoformerIteration(self.config.evoformer,
+ msa_act_dim=64,
+ pair_act_dim=128,
+ is_extra_msa=True,
+ batch_size=self.config.extra_msa_stack_num_block,
+ global_config=global_config)
+
+ self.evoformer_iteration = EvoformerIteration(self.config.evoformer,
+ msa_act_dim=256,
+ pair_act_dim=128,
+ is_extra_msa=False,
+ batch_size=self.config.evoformer_num_block,
+ global_config=global_config)
+
+ self.structure_module = StructureModule(config.model.heads.structure_module,
+ self.config.seq_channel,
+ self.config.pair_channel,
+ global_config=global_config)
+
+ self.module_lddt = PredictedLDDTHead(config.model.heads.predicted_lddt,
+ global_config,
+ self.config.seq_channel)
+ self._init_tensor(global_config)
+
+ def _init_tensor(self, global_config):
+ "initialization of tensors and parameters"
+ self.chi_atom_indices = Tensor(get_chi_atom_indices(), mstype.int32)
+ chi_angles_mask = list(residue_constants.chi_angles_mask)
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+ self.chi_angles_mask = Tensor(chi_angles_mask, mstype.float32)
+ self.mirror_psi_mask = Tensor(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None], mstype.float32)
+ self.chi_pi_periodic = Tensor(residue_constants.chi_pi_periodic, mstype.float32)
+
+ indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int64") # 4 batch
+ indices0 = indices0.repeat(global_config.seq_length, axis=1) # seq_length sequence length
+ indices0 = indices0.repeat(4, axis=2) # 4 chis
+ self.indices0 = Tensor(indices0.repeat(4, axis=3)) # 4 atoms
+
+ indices1 = np.arange(global_config.seq_length).reshape((1, -1, 1, 1, 1)).astype("int64")
+ indices1 = indices1.repeat(4, axis=0)
+ indices1 = indices1.repeat(4, axis=2)
+ self.indices1 = Tensor(indices1.repeat(4, axis=3))
+
+ self.idx_extra_msa_stack = Parameter(Tensor(0, mstype.int32), requires_grad=False)
+ self.extra_msa_stack_num_block = self.config.extra_msa_stack_num_block
+
+ self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False)
+ self.evoformer_num_block = Tensor(self.config.evoformer_num_block, mstype.int32)
+
+ def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype,
+ template_aatype, template_all_atom_masks, template_all_atom_positions,
+ template_mask, template_pseudo_beta_mask, template_pseudo_beta,
+ _, extra_msa, extra_has_deletion,
+ extra_deletion_value, extra_msa_mask,
+ atom14_atom_exists, atom37_atom_exists, residue_index,
+ prev_pos, prev_msa_first_row, prev_pair):
+ """construct"""
+
+ preprocess_1d = self.preprocess_1d(target_feat)
+ preprocess_msa = self.preprocess_msa(msa_feat)
+ msa_activations1 = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
+
+ left_single = self.left_single(target_feat)
+ right_single = self.right_single(target_feat)
+
+ pair_activations = left_single[:, None] + right_single[None]
+ mask_2d = seq_mask[:, None] * seq_mask[None, :]
+
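+ # Recycling: distances computed from the previous iteration's positions are
+ # folded back into the pair representation (Jumper et al. 2021, Alg. 32).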
+ if self.recycle_pos:
+ prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None)
+ dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin)
+ pair_activations += self.prev_pos_linear(dgram)
+ prev_msa_first_row = F.depend(prev_msa_first_row, pair_activations)
+ if self.recycle_features:
+ prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row)
+ msa_activations1 = mnp.concatenate(
+ (mnp.expand_dims(prev_msa_first_row + msa_activations1[0, ...], 0),
+ msa_activations1[1:, ...]), 0)
+ pair_activations += self.prev_pair_norm(prev_pair.astype(mstype.float32))
+
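+ # Relative position encoding: clipped residue-index offsets are one-hot
+ # encoded into 2 * max_relative_feature + 1 bins (Alg. 4 "relpos").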
+ if self.max_relative_feature:
+ offset = residue_index[:, None] - residue_index[None, :]
+ rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature))
+ pair_activations += self.pair_activations(rel_pos)
+
+ template_pair_representation = 0
+ if self.template_enable:
+ template_pair_representation = self.template_embedding(pair_activations, template_aatype,
+ template_all_atom_masks, template_all_atom_positions,
+ template_mask, template_pseudo_beta_mask,
+ template_pseudo_beta, mask_2d)
+ pair_activations += template_pair_representation
+
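+ # Extra MSA features: a 23-way one-hot plus has_deletion and deletion_value
+ # channels, giving the 25 input dims of extra_msa_activations.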
+ msa_1hot = self.extra_msa_one_hot(extra_msa)
+ extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]),
+ axis=-1)
+ extra_msa_activations = self.extra_msa_activations(extra_msa_feat)
+ msa_act = extra_msa_activations
+ pair_act = pair_activations
+
+ msa_act = msa_act.astype(mstype.float32)
+ pair_act = pair_act.astype(mstype.float32)
+ extra_msa_mask = extra_msa_mask.astype(mstype.float32)
+ mask_2d = mask_2d.astype(mstype.float32)
+
+ self.idx_extra_msa_stack = 0
+ idx_extra_msa_stack_int = 0
+ while idx_extra_msa_stack_int < self.extra_msa_stack_num_block:
+ msa_act, pair_act = \
+ self.extra_msa_stack_iteration(msa_act, pair_act, extra_msa_mask, mask_2d, self.idx_extra_msa_stack)
+ self.idx_extra_msa_stack += 1
+ idx_extra_msa_stack_int += 1
+ msa_act = F.depend(msa_act, self.idx_extra_msa_stack)
+ pair_act = F.depend(pair_act, self.idx_extra_msa_stack)
+
+ msa_activations2 = None
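+ # Template torsion-angle features: 22-dim aatype one-hot + 14 sin/cos +
+ # 14 alternative sin/cos + 7 masks = the 57 input dims of
+ # template_single_embedding; the result is appended as extra MSA rows.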
+ if self.template_enabled and self.template_embed_torsion_angles:
+ num_templ, num_res = template_aatype.shape
+ aatype_one_hot = self.template_aatype_one_hot(template_aatype)
+ torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles(
+ template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices,
+ self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1)
+ template_features = mnp.concatenate([aatype_one_hot,
+ mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]),
+ mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]),
+ torsion_angles_mask], axis=-1)
+ template_activations = self.template_single_embedding(template_features)
+ template_activations = self.relu(template_activations.astype(mstype.float32))
+ template_activations = self.template_projection(template_activations)
+ msa_activations2 = mnp.concatenate([msa_activations1, template_activations], axis=0)
+ torsion_angle_mask = torsion_angles_mask[:, :, 2]
+ torsion_angle_mask = torsion_angle_mask.astype(msa_mask.dtype)
+ msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0)
+
+ msa_activations2 = msa_activations2.astype(mstype.float16)
+ pair_activations = pair_act.astype(mstype.float16)
+ msa_mask = msa_mask.astype(mstype.float16)
+ mask_2d = mask_2d.astype(mstype.float16)
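+ # Run the unrolled Evoformer blocks; the Parameter-based block index selects
+ # each block's weights inside the shared EvoformerIteration cell.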
+ self.idx_evoformer_block = self.idx_evoformer_block * 0
+ while self.idx_evoformer_block < self.evoformer_num_block:
+ msa_activations2, pair_activations = \
+ self.evoformer_iteration(msa_activations2,
+ pair_activations,
+ msa_mask,
+ mask_2d,
+ self.idx_evoformer_block)
+ self.idx_evoformer_block += 1
+
+ single_activations = self.single_activations(msa_activations2[0])
+ msa_first_row = msa_activations2[0]
+
+ final_atom_positions, final_atom_mask, rp_structure_module = \
+ self.structure_module(single_activations,
+ pair_activations,
+ seq_mask,
+ aatype,
+ atom14_atom_exists,
+ atom37_atom_exists)
+
+ predicted_lddt_logits = self.module_lddt(rp_structure_module)
+
+ prev_pos = final_atom_positions.astype(mstype.float16)
+ prev_msa_first_row = msa_first_row.astype(mstype.float16)
+ prev_pair = pair_activations.astype(mstype.float16)
+
+ final_atom_positions = final_atom_positions.astype(mstype.float16)
+ final_atom_mask = final_atom_mask.astype(mstype.float16)
+
+ return final_atom_positions, final_atom_mask, predicted_lddt_logits, prev_pos, prev_msa_first_row, prev_pair
diff --git a/reproduce/AlphaFold2-Chinese/module/structure_module.py b/reproduce/AlphaFold2-Chinese/module/structure_module.py
new file mode 100644
index 0000000..c05affe
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/module/structure_module.py
@@ -0,0 +1,443 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""structure module"""
+import numpy as np
+import mindspore.ops as ops
+import mindspore.common.dtype as mstype
+import mindspore.numpy as mnp
+from mindspore import Parameter, ms_function, Tensor
+from mindspore import nn
+from commons import residue_constants
+from commons.utils import generate_new_affine, to_tensor, from_tensor, vecs_to_tensor, atom14_to_atom37, \
+ get_exp_atom_pos, get_exp_frames, pre_compose, scale_translation, to_tensor_new, l2_normalize, \
+ torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, apply_to_point, _invert_point
+
+class InvariantPointAttention(nn.Cell):
+ """Invariant Point attention module."""
+
+ def __init__(self, config, global_config, pair_dim):
+ """Initialize.
+
+ Args:
+ config: Structure Module Config
+ global_config: Global Config of Model.
+ pair_dim: pair representation dimension.
+ """
+
+ super().__init__()
+
+ self._dist_epsilon = 1e-8
+ self.config = config
+ self.num_head = config.num_head
+ self.num_scalar_qk = config.num_scalar_qk
+ self.num_scalar_v = config.num_scalar_v
+ self.num_point_v = config.num_point_v
+ self.num_point_qk = config.num_point_qk
+ self.num_channel = config.num_channel
+ self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 +\
+ self.num_head * pair_dim
+
+ self.global_config = global_config
+ self.q_scalar = nn.Dense(config.num_channel, self.num_head*self.num_scalar_qk).to_float(mstype.float16)
+ self.kv_scalar = nn.Dense(config.num_channel, self.num_head*(self.num_scalar_qk + self.num_scalar_v)
+ ).to_float(mstype.float16)
+ self.q_point_local = nn.Dense(config.num_channel, self.num_head * 3 * self.num_point_qk
+ ).to_float(mstype.float16)
+ self.kv_point_local = nn.Dense(config.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v)
+ ).to_float(mstype.float16)
+ self.soft_max = nn.Softmax()
+ self.soft_plus = ops.Softplus()
+ self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights")
+ self.attention_2d = nn.Dense(pair_dim, self.num_head).to_float(mstype.float16)
+ self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros'
+ ).to_float(mstype.float16)
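+ # Variance-scaling weights for the three attention-logit terms (scalar,
+ # point, 2D bias): 16 = num_scalar_qk and 18 = num_point_qk * 9 / 2,
+ # following Jumper et al. (2021) Alg. 22.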
+ self.scalar_weights = np.sqrt(1.0 / (3 * 16))
+ self.point_weights = np.sqrt(1.0 / (3 * 18))
+ self.attention_2d_weights = np.sqrt(1.0 / 3)
+
+ def construct(self, inputs_1d, inputs_2d, mask, rotation, translation):
+ """Compute geometry-aware attention.
+
+ Args:
+ inputs_1d: (N, C) 1D input embedding that is the basis for the
+ scalar queries.
+ inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
+ mask: (N, 1) mask to indicate which elements of inputs_1d participate
+ in the attention.
+ rotation: describes the orientation of every element in inputs_1d
+ translation: describes the position of every element in inputs_1d
+
+ Returns:
+ Transformation of the input embedding.
+ """
+
+ num_residues, _ = inputs_1d.shape
+
+ # Improve readability by removing a large number of 'self's.
+ num_head = self.num_head
+ num_scalar_qk = self.num_scalar_qk
+ num_point_qk = self.num_point_qk
+ num_scalar_v = self.num_scalar_v
+ num_point_v = self.num_point_v
+
+ # Construct scalar queries of shape:
+ # [num_residues, num_head, num_scalar_qk]
+ q_scalar = self.q_scalar(inputs_1d)
+ q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])
+
+ # Construct scalar keys/values of shape:
+ # [num_residues, num_head, num_scalar_qk + num_scalar_v]
+ kv_scalar = self.kv_scalar(inputs_1d)
+ kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
+ k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1)
+
+ # Construct query points of shape:
+ # [num_residues, num_head, num_point_qk]
+ # First construct query points in local frame.
+ q_point_local = self.q_point_local(inputs_1d)
+ q_point_local = mnp.stack(mnp.split(q_point_local, 3, axis=-1), axis=0)
+
+ # Project query points into global frame.
+ q_point_global = apply_to_point(rotation, translation, q_point_local)
+
+ # Reshape query point for later use.
+ q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk))
+ q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk))
+ q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk))
+
+ # Construct key and value points.
+ # Key points have shape [num_residues, num_head, num_point_qk]
+ # Value points have shape [num_residues, num_head, num_point_v]
+
+ # Construct key and value points in local frame.
+ kv_point_local = self.kv_point_local(inputs_1d)
+
+ kv_point_local = mnp.split(kv_point_local, 3, axis=-1)
+ # Project key and value points into global frame.
+ kv_point_global = apply_to_point(rotation, translation, kv_point_local)
+
+ kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v)))
+ kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v)))
+ kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v)))
+
+ # Split key and value points.
+ k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk,], axis=-1)
+ k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk,], axis=-1)
+ k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk,], axis=-1)
+
+ trainable_point_weights = self.soft_plus(self.trainable_point_weights)
+ point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1)
+
+ v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)]
+ q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)]
+ k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)]
+
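+ # Squared distances between every query point and key point in the global
+ # frame, summed over the three coordinate components.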
+ dist2 = mnp.square(q_point[0][:, :, None, :] - k_point[0][:, None, :, :]) + \
+ mnp.square(q_point[1][:, :, None, :] - k_point[1][:, None, :, :]) + \
+ mnp.square(q_point[2][:, :, None, :] - k_point[2][:, None, :, :])
+
+ attn_qk_point = -0.5 * mnp.sum(
+ point_weights[:, None, None, :] * dist2, axis=-1)
+
+ v = mnp.swapaxes(v_scalar, -2, -3)
+ q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3)
+ k = mnp.swapaxes(k_scalar, -2, -3)
+ attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1))
+ attn_logits = attn_qk_scalar + attn_qk_point
+
+ attention_2d = self.attention_2d(inputs_2d)
+ attention_2d = mnp.transpose(attention_2d, [2, 0, 1])
+ attention_2d = self.attention_2d_weights * attention_2d
+
+ attn_logits += attention_2d
+
+ mask_2d = mask * mnp.swapaxes(mask, -1, -2)
+ attn_logits -= 1e5 * (1. - mask_2d)
+
+ # [num_head, num_query_residues, num_target_residues]
+ attn = self.soft_max(attn_logits)
+
+ # [num_head, num_query_residues, num_scalar_v]
+ result_scalar = ops.matmul(attn, v)
+
+ result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3),
+ mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3),
+ mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3)
+ ]
+
+ result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]),
+ mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]),
+ mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])]
+ result_scalar = mnp.swapaxes(result_scalar, -2, -3)
+
+ result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v])
+
+ result_point_local = _invert_point(result_point_global, rotation, translation)
+
+ output_feature1 = result_scalar
+ output_feature20 = result_point_local[0]
+ output_feature21 = result_point_local[1]
+ output_feature22 = result_point_local[2]
+
+ output_feature3 = mnp.sqrt(self._dist_epsilon +
+ mnp.square(result_point_local[0]) +
+ mnp.square(result_point_local[1]) +
+ mnp.square(result_point_local[2]))
+
+ result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d)
+ num_out = num_head * result_attention_over_2d.shape[-1]
+ output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out])
+
+ final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21,
+ output_feature22, output_feature3, output_feature4], axis=-1)
+ final_result = self.output_projection(final_act)
+ return final_result
+
+
+class MultiRigidSidechain(nn.Cell):
+ """Class to make side chain atoms."""
+
+ def __init__(self, config, global_config, single_repr_dim):
+ super().__init__()
+ self.config = config
+ self.global_config = global_config
+ self.input_projection = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.input_projection_1 = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.relu = nn.ReLU()
+ self.resblock1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16)
+ self.resblock2 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros').to_float(mstype.float16)
+ self.resblock1_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.resblock2_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros'
+ ).to_float(mstype.float16)
+ self.unnormalized_angles = nn.Dense(config.num_channel, 14, weight_init='normal').to_float(mstype.float16)
+ self.print = ops.Print()
+ self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
+ self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
+ self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
+ self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
+
+ def construct(self, rotation, translation, act, initial_act, aatype):
+ """Predict side chains using rotation and translation representations.
+
+ Args:
+ rotation: The rotation matrices.
+ translation: The translation vectors.
+ act: updated single-representation activations from the structure module
+ initial_act: initial single-representation activations (structure module input)
+ aatype: Amino acid type representations
+
+ Returns:
+ angles, positions and new frames
+ """
+
+ act1 = self.input_projection(self.relu(act.astype(mstype.float32)))
+ init_act1 = self.input_projection_1(self.relu(initial_act.astype(mstype.float32)))
+ # Sum the activation list (equivalent to concat then Linear).
+ act = act1 + init_act1
+
+ # Mapping with residual blocks; the config.num_residual_block iterations are
+ # unrolled by hand into the two blocks below.
+ # resblock1
+ old_act = act
+ act = self.resblock1(self.relu(act.astype(mstype.float32)))
+ act = self.resblock2(self.relu(act.astype(mstype.float32)))
+ act += old_act
+ # resblock2
+ old_act = act
+ act = self.resblock1_1(self.relu(act.astype(mstype.float32)))
+ act = self.resblock2_1(self.relu(act.astype(mstype.float32)))
+ act += old_act
+
+ # Map activations to torsion angles. Shape: (num_res, 14).
+ num_res = act.shape[0]
+ unnormalized_angles = self.unnormalized_angles(self.relu(act.astype(mstype.float32)))
+
+ unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
+
+ angles = l2_normalize(unnormalized_angles, axis=-1)
+
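+ # Pack the backbone rotation (row-major 3x3) and translation into the
+ # 12-element rigid representation expected by torsion_angles_to_frames.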
+ backb_to_global = [rotation[0][0], rotation[0][1], rotation[0][2],
+ rotation[1][0], rotation[1][1], rotation[1][2],
+ rotation[2][0], rotation[2][1], rotation[2][2],
+ translation[0], translation[1], translation[2]]
+
+ all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
+ self.restype_rigid_group_default_frame)
+
+ pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
+ self.restype_atom14_to_rigid_group,
+ self.restype_atom14_rigid_group_positions,
+ self.restype_atom14_mask)
+
+ atom_pos = pred_positions
+ frames = all_frames_to_global
+
+ return angles, unnormalized_angles, atom_pos, frames
+
+
+class FoldIteration(nn.Cell):
+ """A single iteration of the main structure module loop."""
+
+ def __init__(self, config, global_config, pair_dim, single_repr_dim):
+ super().__init__()
+ self.config = config
+ self.global_config = global_config
+ self.drop_out = nn.Dropout(keep_prob=0.9)
+ self.attention_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)
+ self.transition_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5)
+ self.transition = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.transition_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.transition_2 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal'
+ ).to_float(mstype.float16)
+ self.relu = nn.ReLU()
+ self.affine_update = nn.Dense(config.num_channel, 6, weight_init='zeros').to_float(mstype.float16)
+ self.attention_module = InvariantPointAttention(self.config, self.global_config, pair_dim)
+ self.mu_side_chain = MultiRigidSidechain(config.sidechain, global_config, single_repr_dim)
+ self.print = ops.Print()
+
+ def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype):
+ '''construct'''
+ # Attention
+ attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation)
+ act += attn
+ act = self.drop_out(act)
+ act = self.attention_layer_norm(act.astype(mstype.float32))
+ # Transition
+ input_act = act
+ act = self.transition(act)
+ act = self.relu(act.astype(mstype.float32))
+ act = self.transition_1(act)
+ act = self.relu(act.astype(mstype.float32))
+ act = self.transition_2(act)
+
+ act += input_act
+ act = self.drop_out(act)
+ act = self.transition_layer_norm(act.astype(mstype.float32))
+
+ # This block corresponds to
+ # Jumper et al. (2021) Alg. 23 "Backbone update"
+ # Affine update
+ affine_update = self.affine_update(act)
+
+ quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update)
+ _, rotation1, translation1 = scale_translation(quaternion, translation, rotation, 10.0)
+
+ angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames =\
+ self.mu_side_chain(rotation1, translation1, act, initial_act, aatype)
+
+ affine_output = to_tensor_new(quaternion, translation)
+
+ return act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
+ atom_pos, frames
+
+
+class StructureModule(nn.Cell):
+ """StructureModule as a network head."""
+
+ def __init__(self, config, single_repr_dim, pair_dim, global_config=None, compute_loss=True):
+ super(StructureModule, self).__init__()
+ self.config = config
+ self.global_config = global_config
+ self.compute_loss = compute_loss
+ self.fold_iteration = FoldIteration(self.config, global_config, pair_dim, single_repr_dim)
+ self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
+ self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel).to_float(mstype.float16)
+ self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
+ self.num_layer = config.num_layer
+ self.indice0 = Tensor(
+ np.arange(global_config.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))
+
+ @ms_function
+ def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None):
+ '''construct'''
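+ # Jumper et al. (2021) Alg. 20 "Structure module": refine backbone frames
+ # and side chains from the single and pair representations.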
+ sequence_mask = seq_mask[:, None]
+ act = self.single_layer_norm(single.astype(mstype.float32))
+ initial_act = act
+ act = self.initial_projection(act)
+ quaternion, rotation, translation = generate_new_affine(sequence_mask)
+ aff_to_tensor = to_tensor(quaternion, mnp.transpose(translation))
+ act_2d = self.pair_layer_norm(pair.astype(mstype.float32))
+ # fold iteration
+ quaternion, rotation, translation = from_tensor(aff_to_tensor)
+
+ act_new, atom_pos, _, _, _, _ =\
+ self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
+ atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]
+
+ atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
+ residx_atom37_to_atom14,
+ atom37_atom_exists,
+ self.indice0)
+
+ final_atom_positions = atom37_pred_positions
+ final_atom_mask = atom37_atom_exists
+ rp_structure_module = act_new
+ return final_atom_positions, final_atom_mask, rp_structure_module
+
+ def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
+ aatype):
+ '''iteration operation'''
+ affine_init = ()
+ angles_sin_cos_init = ()
+ um_angles_sin_cos_init = ()
+ atom_pos = ()
+ frames = ()
+
+ for _ in range(self.num_layer):
+ act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
+ atom_pos, frames = \
+ self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
+ affine_init = affine_init + (affine_output[None, ...],)
+ angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
+ um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
+ atom_pos = get_exp_atom_pos(atom_pos)
+ frames = get_exp_frames(frames)
+ affine_output_new = mnp.concatenate(affine_init, axis=0)
+ angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
+ um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)
+
+ return act, atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames
+
+
+class PredictedLDDTHead(nn.Cell):
+ """Head to predict the per-residue LDDT to be used as a confidence measure."""
+ def __init__(self, config, global_config, seq_channel):
+ super().__init__()
+ self.config = config
+ self.global_config = global_config
+ self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5)
+ self.act_0 = nn.Dense(seq_channel, self.config.num_channels, weight_init='zeros'
+ ).to_float(mstype.float16)
+ self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels, weight_init='zeros'
+ ).to_float(mstype.float16)
+ self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros'
+ ).to_float(mstype.float16)
+ self.relu = nn.ReLU()
+
+ def construct(self, rp_structure_module):
+ """Builds ExperimentallyResolvedHead module."""
+ act = rp_structure_module
+ act = self.input_layer_norm(act.astype(mstype.float32))
+ act = self.act_0(act)
+ act = self.relu(act.astype(mstype.float32))
+ act = self.act_1(act)
+ act = self.relu(act.astype(mstype.float32))
+ logits = self.logits(act)
+ return logits
diff --git a/reproduce/AlphaFold2-Chinese/requirements.txt b/reproduce/AlphaFold2-Chinese/requirements.txt
new file mode 100644
index 0000000..bfb18bd
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/requirements.txt
@@ -0,0 +1,5 @@
+absl-py==0.13.0
+biopython==1.79
+ml-collections==0.1.0
+scipy==1.7.0
+numpy==1.19.5
diff --git a/reproduce/AlphaFold2-Chinese/serving/fold_service/config.py b/reproduce/AlphaFold2-Chinese/serving/fold_service/config.py
new file mode 100644
index 0000000..91b8fe8
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/serving/fold_service/config.py
@@ -0,0 +1,34 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""config for serving mode"""
+
+import ml_collections
+
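+# All UPPER_CASE values below are placeholders and must be replaced with real
+# local paths before starting the serving process.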
+config = ml_collections.ConfigDict({
+ "seq_length": 256,
+ "device_id": 0,
+ "port": 5500,
+ "ckpt_path": "/CHECKPOINT_PATH",
+ "input_fasta_path": "INPUT_FASTA_PATH",
+ "msa_result_path": "MSA_RESULT_PATH",
+ "database_dir": "DATABASE_DIR",
+ "database_envdb_dir": "DATABASE_ENVDB_DIR",
+ "hhsearch_binary_path": "HHSEARCH_BINARY_PATH",
+ "pdb70_database_path": 'PDB&)_DATABASE_PATH',
+ "template_mmcif_dir": 'TEMPLATE_MMCIF_DIR',
+ "max_template_date": "MAX_TEMPLATE_DATE",
+ "kalign_binary_path": 'KALIGN_BINARY_PATH',
+ "obsolete_pdbs_path": 'OBSOLETE_PDBS_PATH',
+})
diff --git a/reproduce/AlphaFold2-Chinese/serving/fold_service/servable_config.py b/reproduce/AlphaFold2-Chinese/serving/fold_service/servable_config.py
new file mode 100644
index 0000000..7d25345
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/serving/fold_service/servable_config.py
@@ -0,0 +1,104 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""serving config for mindspore serving"""
+
+import time
+import os
+import json
+import numpy as np
+
+from mindspore_serving.server import register
+import mindspore.context as context
+from mindspore.common.tensor import Tensor
+from mindspore import load_checkpoint
+
+from data.feature.feature_extraction import process_features
+from data.tools.data_process import data_process
+from commons.utils import compute_confidence
+from commons.generate_pdb import to_pdb, from_prediction
+from model import AlphaFold
+from config import config, global_config
+from fold_service.config import config as serving_config
+
+context.set_context(mode=context.GRAPH_MODE,
+ device_target="Ascend",
+ variable_memory_max_size="31GB",
+ device_id=serving_config.device_id,
+ save_graphs=False)
+model_name = "model_1"
+model_config = config.model_config(model_name)
+num_recycle = model_config.model.num_recycle
+global_config = global_config.global_config(serving_config.seq_length)
+extra_msa_length = global_config.extra_msa_length
+
+fold_net = AlphaFold(model_config, global_config)
+load_checkpoint(serving_config.ckpt_path, fold_net)
+
+def fold_model(input_fasta_path):
+ """defining fold model"""
+
+ seq_files = os.listdir(input_fasta_path)
+ for seq_file in seq_files:
+ print(seq_file)
+ t1 = time.time()
+ seq_name = seq_file.split('.')[0]
+
+ input_features = data_process(seq_name, serving_config)
+ tensors, aatype, residue_index, ori_res_length = process_features(
+ raw_features=input_features, config=model_config, global_config=global_config)
+ prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16))
+ prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16))
+ prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16))
+
+ t2 = time.time()
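+ # num_recycle + 1 forward passes: each pass feeds the previous positions,
+ # first MSA row and pair activations back into the network (recycling).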
+ for i in range(num_recycle+1):
+ tensors_i = [tensor[i] for tensor in tensors]
+ input_feats = [Tensor(tensor) for tensor in tensors_i]
+ final_atom_positions, final_atom_mask, predicted_lddt_logits,\
+ prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats,
+ prev_pos,
+ prev_msa_first_row,
+ prev_pair)
+
+ t3 = time.time()
+
+ final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length]
+ final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length]
+ predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length]
+
+ confidence = compute_confidence(predicted_lddt_logits)
+ unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0])
+ pdb_file = to_pdb(unrelaxed_protein)
+
+ seq_length = aatype.shape[-1]
+ os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True)
+
+ with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f:
+ f.write(pdb_file)
+ t4 = time.time()
+ timings = {"pre_process_time": round(t2 - t1, 2),
+ "model_time": round(t3 - t2, 2),
+ "pos_process_time": round(t4 - t3, 2),
+ "all_time": round(t4 - t1, 2),
+ "confidence": confidence}
+ print(timings)
+ with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f:
+ f.write(json.dumps(timings))
+ return True
+
+@register.register_method(output_names=["res"])
+def folding(input_fasta_path):
+ res = register.add_stage(fold_model, input_fasta_path, outputs_count=1)
+ return res
diff --git a/reproduce/AlphaFold2-Chinese/serving/serving_client.py b/reproduce/AlphaFold2-Chinese/serving/serving_client.py
new file mode 100644
index 0000000..e754cc4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/serving/serving_client.py
@@ -0,0 +1,32 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""client script for serving mode"""
+
+import time
+
+from mindspore_serving.client import Client
+from fold_service.config import config as serving_config
+
+if __name__ == "__main__":
+
+ client = Client("127.0.0.1:" + str(serving_config.port), "fold_service", "folding")
+ instances = [{"input_fasta_path": serving_config.input_fasta_path}]
+
+ print("inferring...")
+ t1 = time.time()
+ result = client.infer(instances)
+ t2 = time.time()
+ print("finish inferring! Time costed:", t2 - t1)
+ print(result)
diff --git a/reproduce/AlphaFold2-Chinese/serving/serving_server.py b/reproduce/AlphaFold2-Chinese/serving/serving_server.py
new file mode 100644
index 0000000..87f18fd
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/serving/serving_server.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""server script for serving mode"""
+
+import os
+from mindspore_serving import server
+from fold_service.config import config as serving_config
+
+def start():
+ servable_dir = os.path.dirname(os.path.realpath(__file__))
+
+ servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="fold_service",
+ device_ids=serving_config.device_id)
+ server.start_servables(servable_configs=servable_config)
+
+ server.start_grpc_server(address="127.0.0.1:" + str(serving_config.port))
+
+if __name__ == "__main__":
+ start()
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/__init__.py
new file mode 100644
index 0000000..b3a552b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init"""
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_activation.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_activation.py
new file mode 100644
index 0000000..07d030d
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_activation.py
@@ -0,0 +1,85 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test Activations """
+import pytest
+import numpy as np
+
+from mindspore import nn
+from mindspore import Tensor
+from mindspore import context
+from mindelec.architecture import get_activation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.srelu = get_activation("srelu")
+
+ def construct(self, x):
+ return self.srelu(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_srelu():
+ """test srelu activation"""
+ net = Net()
+ input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32))
+ output = net(input_tensor)
+ print(input_tensor.asnumpy())
+ print(output.asnumpy())
+
+
+class Net1(nn.Cell):
+ """net"""
+ def __init__(self):
+ super(Net1, self).__init__()
+ self.sin = get_activation("sin")
+
+ def construct(self, x):
+ return self.sin(x)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_sin():
+ """test sin activation"""
+ net = Net1()
+ input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32))
+ output = net(input_tensor)
+ print(input_tensor.asnumpy())
+ print(output.asnumpy())
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_activation_type_error():
+ with pytest.raises(TypeError):
+ get_activation(1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_get_activation():
+ activation = get_activation("softshrink")
+ assert isinstance(activation, nn.Cell)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_block.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_block.py
new file mode 100644
index 0000000..28e968b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_block.py
@@ -0,0 +1,224 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test block """
+import pytest
+import numpy as np
+
+from mindspore import nn, context
+from mindspore import Tensor, Parameter
+
+from mindelec.architecture import LinearBlock, ResBlock
+from mindelec.architecture import InputScaleNet, FCSequential, MultiScaleFCCell
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+ """ Net definition """
+ def __init__(self,
+ input_channels,
+ output_channels,
+ weight='normal',
+ bias='zeros',
+ has_bias=True):
+ super(Net, self).__init__()
+ self.fc = LinearBlock(input_channels, output_channels, weight, bias, has_bias)
+
+ def construct(self, input_x):
+ return self.fc(input_x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_linear():
+ """test linear block"""
+ weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
+ bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
+ net = Net(64, 8, weight=weight, bias=bias)
+ input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_linear_nobias():
+ """test linear block with no bias"""
+ weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
+ net = Net(64, 8, weight=weight, has_bias=False)
+ input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+class Net1(nn.Cell):
+ """ Net definition """
+ def __init__(self,
+ input_channels,
+ output_channels,
+ weight='normal',
+ bias='zeros',
+ has_bias=True,
+ activation=None):
+ super(Net1, self).__init__()
+ self.fc = ResBlock(input_channels, output_channels, weight, bias, has_bias, activation)
+
+ def construct(self, input_x):
+ return self.fc(input_x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_res():
+ """test res block"""
+ weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
+ bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
+ net = Net1(8, 8, weight=weight, bias=bias)
+ input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_res_nobias():
+ """test res block with no bias"""
+ weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
+ net = Net1(8, 8, weight=weight, has_bias=False)
+ input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_res_activation():
+ """test res block with activation"""
+ weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32))
+ bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
+ net = Net1(8, 8, weight=weight, bias=bias, activation='sin')
+ input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_res_channel_error():
+ """ResBlock with mismatched in/out channels should raise ValueError"""
+ with pytest.raises(ValueError):
+ ResBlock(3, 6)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_input_scale():
+ """test input scale cell"""
+ inputs = np.random.uniform(size=(16, 3)) + 3.0
+ inputs = Tensor(inputs.astype(np.float32))
+ input_scale = [1.0, 2.0, 4.0]
+ input_center = [3.5, 3.5, 3.5]
+ net = InputScaleNet(input_scale, input_center)
+ output = net(inputs).asnumpy()
+
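+ # inputs lie in [3, 4); assuming InputScaleNet computes (x - center) * scale,
+ # the three columns should land in [-0.5, 0.5], [-1, 1] and [-2, 2] respectively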
+ assert np.all(output[:, 0] <= 0.5) and np.all(output[:, 0] >= -0.5)
+ assert np.all(output[:, 1] <= 1.0) and np.all(output[:, 1] >= -1.0)
+ assert np.all(output[:, 2] <= 2.0) and np.all(output[:, 2] >= -2.0)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_fc_sequential():
+ """test fc sequential cell"""
+ inputs = np.ones((16, 3))
+ inputs = Tensor(inputs.astype(np.float32))
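+ # FCSequential(in_channels, out_channels, layers, neurons, ...) -- argument names
+ # assumed from the values used here: 3 inputs, 3 outputs, 5 layers of 32 neurons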
+ net = FCSequential(3, 3, 5, 32, weight_init="ones", bias_init="zeros")
+ output = net(inputs).asnumpy()
+ target = np.ones((16, 3)) * -31.998459
+ assert np.allclose(output, target, rtol=5e-2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mulscale_without_latent():
+ """test multi-scale net without latent vector"""
+ inputs = np.ones((16, 3)) + 3.0
+ inputs = Tensor(inputs.astype(np.float32))
+ input_scale = [1.0, 2.0, 4.0]
+ input_center = [3.5, 3.5, 3.5]
+ net = MultiScaleFCCell(3, 3, 5, 32,
+ weight_init="ones", bias_init="zeros",
+ input_scale=input_scale, input_center=input_center)
+ output = net(inputs).asnumpy()
+ target = np.ones((16, 3)) * -61.669254
+ assert np.allclose(output, target, rtol=5e-2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mulscale_with_latent():
+ """test multi-scale net with latent vector and input scale"""
+ inputs = np.ones((64, 3)) + 3.0
+ inputs = Tensor(inputs.astype(np.float32))
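+ # 64 samples split across 4 scenarios (16 per scenario); the trainable latent
+ # vector is assumed to be tiled over the batch and concatenated to the inputs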
+ num_scenarios = 4
+ latent_size = 16
+ latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32)
+ latent_vector = Parameter(Tensor(latent_init), requires_grad=True)
+ input_scale = [1.0, 2.0, 4.0]
+ input_center = [3.5, 3.5, 3.5]
+ net = MultiScaleFCCell(3, 3, 5, 32,
+ weight_init="ones", bias_init="zeros",
+ input_scale=input_scale, input_center=input_center, latent_vector=latent_vector)
+ output = net(inputs).asnumpy()
+ target = np.ones((64, 3)) * -57.8849
+ assert np.allclose(output, target, rtol=5e-2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mulscale_with_latent_noscale():
+ """test multi-scale net with latent vector"""
+ inputs = np.ones((64, 3))
+ inputs = Tensor(inputs.astype(np.float32))
+ num_scenarios = 4
+ latent_size = 16
+ latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32)
+ latent_vector = Parameter(Tensor(latent_init), requires_grad=True)
+ net = MultiScaleFCCell(3, 3, 5, 32,
+ weight_init="ones", bias_init="zeros", latent_vector=latent_vector)
+ output = net(inputs).asnumpy()
+ target = np.ones((64, 3)) * -105.62799
+ assert np.allclose(output, target, rtol=5e-2)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_mlt.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_mlt.py
new file mode 100644
index 0000000..84b909b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/architecture/test_mlt.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test block """
+import pytest
+import numpy as np
+
+from mindspore import context
+from mindspore import Tensor
+from mindelec.architecture import MTLWeightedLossCell
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
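+# MTLWeightedLossCell is assumed to combine multiple loss terms with learnable
+# weights; with num_losses=2 it takes a 2-element loss vector as input.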
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mtl_weighted_loss():
+ """test MTLWeightedLossCell forward pass"""
+ net = MTLWeightedLossCell(num_losses=2)
+ input_data = Tensor(np.array([1.0, 1.0]).astype(np.float32))
+ output = net(input_data)
+ print(output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_mtl_num_losses_error():
+ """num_losses must be an integer, so a string should raise TypeError"""
+ with pytest.raises(TypeError):
+ MTLWeightedLossCell('a')
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_lr_scheduler.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_lr_scheduler.py
new file mode 100644
index 0000000..729d7bb
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_lr_scheduler.py
@@ -0,0 +1,92 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test metrics """
+import pytest
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindelec.common import LearningRate, get_poly_lr
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_learning_rate():
+ """test LearningRate"""
+ learning_rate = LearningRate(0.1, 0.001, 0, 10, 0.5)
+ res = learning_rate(Tensor(10000, mstype.int32))
+ assert res == 0.001
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_learning_rate_power_value_error():
+ with pytest.raises(ValueError):
+ LearningRate(0.1, 0.001, 0, 10, -0.5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_learning_rate_invalid_args():
+ """test invalid LearningRate arguments: TypeError for a float warmup_steps, ValueError for out-of-range values"""
+ with pytest.raises(TypeError):
+ LearningRate(0.1, 0.001, 0.1, 10, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.0, 0.001, 0, 10, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.1, -0.001, 0, 10, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.1, 0.001, 0, 0, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.1, 0.001, -1, 10, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.1, 0.001, 0, -10, 0.5)
+ with pytest.raises(ValueError):
+ LearningRate(0.1, 0.001, 1, -10, 0.5)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_get_poly_lr():
+ """test get_poly_lr"""
+ res = get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, 0.5)
+ assert res.shape == (9900,)
+ with pytest.raises(ValueError):
+ get_poly_lr(-1, 0.001, 0.1, 0.0001, 1000, 10000, 0.5)
+ with pytest.raises(ValueError):
+ get_poly_lr(100, 0.0, 0.1, 0.0001, 1000, 10000, 0.5)
+ with pytest.raises(ValueError):
+ get_poly_lr(100, 0.001, 0.1, 0.0, 1000, 10000, 0.5)
+ with pytest.raises(ValueError):
+ get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 0, 0.5)
+ with pytest.raises(ValueError):
+ get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, -0.5)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_get_poly_lr_no_warmup():
+ """test get_poly_lr with zero warmup steps"""
+ res = get_poly_lr(100, 0.001, 0.1, 0.0001, 0, 10000, 0.5)
+ assert res.shape == (9900,)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_metrics.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_metrics.py
new file mode 100644
index 0000000..8cfcd87
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/common/test_metrics.py
@@ -0,0 +1,39 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test metrics """
+import pytest
+import numpy as np
+
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+from mindelec.common import L2
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_l2():
+ """test l2"""
+ x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
+ y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32)
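+ # assuming L2 is the relative error ||x - y||_2 / ||y||_2:
+ # sqrt(0.0025 + 0.01) / sqrt(1.3725) ~= 0.095433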
+ metric = L2()
+ metric.clear()
+ metric.update(x, y)
+ result = metric.eval()
+ assert np.isclose(result, 0.09543302997807275)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/config.py
new file mode 100644
index 0000000..c6bfcd6
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/config.py
@@ -0,0 +1,82 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+sampling and dataset settings
+"""
+
+from easydict import EasyDict as edict
+
+ds_config = edict({
+ 'train': edict({
+ 'batch_size': 100,
+ 'shuffle': True,
+ 'drop_remainder': True,
+ }),
+ 'eval': edict({
+ 'batch_size': 100,
+ 'shuffle': False,
+ 'drop_remainder': False,
+ }),
+})
+
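+# The three sampling plans below drive the dataset tests: a source region with a
+# denser domain sample (400 points), a source-free region (100 points), and a
+# boundary sampled with outward normals.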
+src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 400,
+ 'sampler': 'uniform'
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform',
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'sampler': 'uniform',
+ }),
+})
+
+no_src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform'
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform',
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'sampler': 'uniform',
+ }),
+})
+
+bc_sampling_config = edict({
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform',
+ 'with_normal': True
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'sampler': 'uniform',
+ }),
+})
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_boundary.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_boundary.py
new file mode 100644
index 0000000..fd4698e
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_boundary.py
@@ -0,0 +1,110 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: geometry with time cases"""
+
+import copy
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain
+from mindelec.data import Boundary, BoundaryBC, BoundaryIC
+
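+# Two variants of the space-time sampling config: one with random BC/IC sampling
+# and boundary normals, one fully mesh-based without normals.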
+reset_geom_time_config = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config2 = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ 'with_normal': False,
+ }),
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+def check_rect_with_time_set_config(config):
+ """check_rect_with_time_set_config"""
+ rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
+ time = TimeDomain("time", 0.0, 1.0)
+ rect_with_time = GeometryWithTime(rect, time)
+
+ with pytest.raises(TypeError):
+ Boundary(0)
+ with pytest.raises(ValueError):
+ Boundary(rect_with_time)
+
+ config1 = copy.deepcopy(config)
+ config1.pop('BC')
+ rect_with_time.set_sampling_config(create_config_from_edict(config1))
+ with pytest.raises(ValueError):
+ BoundaryBC(rect_with_time)
+
+ config2 = copy.deepcopy(config)
+ config2.pop('IC')
+ rect_with_time.set_sampling_config(create_config_from_edict(config2))
+ with pytest.raises(ValueError):
+ BoundaryIC(rect_with_time)
+
+ rect_with_time.set_sampling_config(create_config_from_edict(config))
+
+ bc = BoundaryBC(rect_with_time)
+ for i in range(20):
+ print(bc[i])
+
+ config3 = copy.deepcopy(config)
+ if not config3.IC.random_sampling:
+ config3.IC.size = [4, 4]
+ rect_with_time.set_sampling_config(create_config_from_edict(config3))
+ ic = BoundaryIC(rect_with_time)
+ for i in range(20):
+ print(ic[i])
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_set_config():
+ """test_check_rect_with_time_set_config"""
+ check_rect_with_time_set_config(reset_geom_time_config)
+ check_rect_with_time_set_config(reset_geom_time_config2)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_data_base.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_data_base.py
new file mode 100644
index 0000000..e51e62f
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_data_base.py
@@ -0,0 +1,195 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test data_base."""
+import pytest
+import numpy as np
+from mindelec.data import ExistedDataConfig
+from mindelec.data.data_base import Data
+
+
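+# These cases only exercise argument validation of ExistedDataConfig and the
+# NotImplementedError contract of the abstract Data base class; no real dataset
+# is constructed.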
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_name_error():
+ with pytest.raises(TypeError):
+ ExistedDataConfig(1, "/home/a.npy", "data")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_data_dir_error():
+ with pytest.raises(TypeError):
+ ExistedDataConfig("a", 1, "data")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_column_list_error():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_constraint_type_error():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", contraint_type="a")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_data_format_typeerror():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_data_format_valueerror():
+ with pytest.raises(ValueError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", "csv")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_data_constraint_type_typeerror():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", constraint_type='1')
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_data_constraint_type_typeerror1():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", constraint_type='test')
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_random_merge_error():
+ with pytest.raises(TypeError):
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ ExistedDataConfig("a", input_path, "data", random_merge=1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_name_type_error():
+ with pytest.raises(TypeError):
+ Data(1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_columns_list_type_error():
+ with pytest.raises(TypeError):
+ Data(columns_list=1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_constraint_type_type_error():
+ with pytest.raises(TypeError):
+ Data(constraint_type=1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_constraint_type_type_error1():
+ with pytest.raises(TypeError):
+ Data(constraint_type="labe")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_set_constraint_type_type_error():
+ with pytest.raises(TypeError):
+ data = Data()
+ data.set_constraint_type("test")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_create_dataset_nie_error():
+ with pytest.raises(NotImplementedError):
+ data = Data()
+ data.create_dataset()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_get_item_nie_error():
+ with pytest.raises(NotImplementedError):
+ data = Data()
+ print(data[0])
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_data_len_nie_error():
+ with pytest.raises(NotImplementedError):
+ data = Data()
+ len(data)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_dataset.py
new file mode 100644
index 0000000..ede8a72
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_dataset.py
@@ -0,0 +1,112 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test dataset module"""
+import pytest
+from easydict import EasyDict as edict
+from mindelec.data import Dataset
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
+from mindelec.data import BoundaryBC, BoundaryIC
+from config import ds_config, src_sampling_config, no_src_sampling_config, bc_sampling_config
+
+ic_bc_config = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dataset_allnone():
+ with pytest.raises(ValueError):
+ Dataset()
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dataset():
+ """test dataset"""
+ disk = Disk("src", (0.0, 0.0), 0.2)
+ rectangle = Rectangle("rect", (-1, -1), (1, 1))
+ diff = rectangle - disk
+ time = TimeDomain("time", 0.0, 1.0)
+
+ # check datalist
+ rect_with_time = GeometryWithTime(rectangle, time)
+ rect_with_time.set_sampling_config(create_config_from_edict(ic_bc_config))
+ bc = BoundaryBC(rect_with_time)
+ ic = BoundaryIC(rect_with_time)
+ dataset = Dataset(dataset_list=bc)
+
+ dataset.set_constraint_type("Equation")
+
+ c_type1 = {bc: "Equation", ic: "Equation"}
+ with pytest.raises(ValueError):
+ dataset.set_constraint_type(c_type1)
+
+ no_src_region = GeometryWithTime(diff, time)
+ no_src_region.set_name("no_src")
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
+ src_region = GeometryWithTime(disk, time)
+ src_region.set_name("src")
+ src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
+ boundary = GeometryWithTime(rectangle, time)
+ boundary.set_name("bc")
+ boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))
+
+ geom_dict = ['1', '2']
+ with pytest.raises(TypeError):
+ Dataset(geom_dict)
+
+ geom_dict = {src_region: ["test"]}
+ with pytest.raises(KeyError):
+ Dataset(geom_dict)
+
+ geom_dict = {src_region: ["domain", "IC"],
+ no_src_region: ["domain", "IC"],
+ boundary: ["BC"]}
+ dataset = Dataset(geom_dict)
+
+ with pytest.raises(ValueError):
+ print(dataset[0])
+
+ with pytest.raises(ValueError):
+ len(dataset)
+
+ with pytest.raises(ValueError):
+ dataset.get_columns_list()
+
+ with pytest.raises(ValueError):
+ dataset.create_dataset(batch_size=ds_config.train.batch_size,
+ shuffle=ds_config.train.shuffle,
+ prebatched_data=True,
+ drop_remainder=False)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_equation.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_equation.py
new file mode 100644
index 0000000..d88d71c
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_equation.py
@@ -0,0 +1,77 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: geometry with time cases"""
+
+import copy
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain
+from mindelec.data import Equation
+
+reset_geom_time_config = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [5, 2],
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config2 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
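+# Equation datasets require a 'domain' part in the sampling config; popping it
+# below is expected to raise KeyError.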
+def check_rect_with_time_set_config(config):
+ """check_rect_with_time_set_config"""
+ rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
+ time = TimeDomain("time", 0.0, 1.0)
+ rect_with_time = GeometryWithTime(rect, time)
+
+ with pytest.raises(TypeError):
+ Equation(0)
+ with pytest.raises(ValueError):
+ Equation(rect_with_time)
+
+ config1 = copy.deepcopy(config)
+ config1.pop('domain')
+ rect_with_time.set_sampling_config(create_config_from_edict(config1))
+ with pytest.raises(KeyError):
+ Equation(rect_with_time)
+
+ rect_with_time.set_sampling_config(create_config_from_edict(config))
+ eq = Equation(rect_with_time)
+ for i in range(20):
+ print(eq[i])
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_set_config():
+ """test_check_rect_with_time_set_config"""
+ check_rect_with_time_set_config(reset_geom_time_config)
+ check_rect_with_time_set_config(reset_geom_time_config2)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_existed_data.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_existed_data.py
new file mode 100644
index 0000000..4de5521
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_existed_data.py
@@ -0,0 +1,77 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test existed_data."""
+import pytest
+import numpy as np
+from mindelec.data import ExistedDataset, ExistedDataConfig
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_value_error():
+ with pytest.raises(ValueError):
+ ExistedDataset()
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_type_error():
+ with pytest.raises(TypeError):
+ ExistedDataset(data_config=1)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_config():
+ """test various errors"""
+ with pytest.raises(ValueError):
+ input_path = "./input_data.npy"
+ label_path = "./label.npy"
+ input_data = np.random.randn(10, 3)
+ output = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+ np.save(label_path, output)
+
+ config = ExistedDataConfig(name="existed_data",
+ data_dir=[input_path, label_path],
+ columns_list=["inputs", "label"],
+ constraint_type="Equation",
+ data_format="npz")
+ dataset = ExistedDataset(data_config=config)
+
+ input_path = "./input_data.npy"
+ input_data = np.random.randn(10, 3)
+ np.save(input_path, input_data)
+
+ dataset = ExistedDataset(name="existed_data",
+ data_dir=[input_path],
+ columns_list=["inputs"],
+ constraint_type="Equation",
+ data_format="npy")
+ for i in range(20):
+ print(dataset[i])
+
+ dataset = ExistedDataset(name="existed_data",
+ data_dir=[input_path],
+ columns_list=["inputs"],
+ constraint_type="Equation",
+ data_format="npy",
+ random_merge=False)
+ for i in range(20):
+ print(dataset[i])
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_src_td.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_src_td.py
new file mode 100644
index 0000000..916e3cf
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/data/test_src_td.py
@@ -0,0 +1,84 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test dataset module, call CSG ant GeometryWithTime"""
+import pytest
+import numpy as np
+from easydict import EasyDict as edict
+
+from mindelec.data import Dataset
+from mindelec.geometry import Rectangle, TimeDomain, GeometryWithTime, create_config_from_edict
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_dataset():
+ """check dataset"""
+ sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 10],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 100,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10],
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 40,
+ }),
+ })
+
+ rect_space = Rectangle("rectangle", coord_min=[0, 0], coord_max=[10, 10])
+ time = TimeDomain("time", 0.0, 20)
+ grid = GeometryWithTime(rect_space, time)
+ grid.set_sampling_config(create_config_from_edict(sampling_config))
+ grid.set_name("grid")
+ geom_dict = {grid: ["domain", "BC", "IC"]}
+ dataset = Dataset(geom_dict)
+
+ def preprocess_fn(*data):
+ bc_data = data[1]
+ bc_normal = np.ones(bc_data.shape)
+ return data[0], data[1], bc_normal, data[2], data[3]
+
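+ # preprocess_fn injects a synthetic all-ones column after the BC points, so the
+ # map below expands "grid_BC_points" into two output columns (points + tangent)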
+ colunms_map = {"grid_domain_points": ["grid_domain_points"],
+ "grid_BC_points": ["grid_BC_points", "grid_BC_tangent"],
+ "grid_BC_normal": "grid_BC_normal",
+ "grid_IC_points": "grid_IC_points"}
+ train_data = dataset.create_dataset(batch_size=8192, shuffle=False, drop_remainder=False,
+ preprocess_fn=preprocess_fn,
+ input_output_columns_map=colunms_map)
+ dataset.set_constraint_type({dataset.all_datasets[0]: "Equation",
+ dataset.all_datasets[1]: "BC",
+ dataset.all_datasets[2]: "IC"})
+ print("get merged data: {}".format(dataset[5]))
+ for sub_data in dataset.all_datasets:
+ print("get data: {}".format(sub_data[5]))
+
+ dataset_iter = train_data.create_dict_iterator(num_epochs=1)
+ np.set_printoptions(threshold=np.inf)
+ for _ in range(1):
+ for data in dataset_iter:
+ for k, v in data.items():
+ print("key: ", k)
+ print(v)
+ break
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_1d.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_1d.py
new file mode 100644
index 0000000..0b15b90
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_1d.py
@@ -0,0 +1,133 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: 1d cases"""
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Interval, TimeDomain
+
+line_config_out = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'sampler': 'uniform',
+ }),
+})
+
+
+line_config_out2 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ 'sampler': 'uniform',
+ }),
+})
+
+def check_line_interval_case1(line_config):
+ """check_line_interval_case1"""
+ try:
+ Interval("line", 'test', 1.0, sampling_config=create_config_from_edict(line_config))
+ except ValueError:
+ print("get ValueError when creating Interval with invalid coordinates")
+ line = Interval("line", -1.0, 1.0, sampling_config=create_config_from_edict(line_config))
+
+ domain = line.sampling(geom_type="domain")
+ bc = line.sampling(geom_type="BC")
+ try:
+ line.sampling(geom_type="other")
+ except ValueError:
+ print("get ValueError when sampling other data")
+
+ # uniform sampling
+ if "BC" in line_config.keys():
+ line_config.BC = None
+ if "domain" in line_config.keys():
+ line_config.domain.random_sampling = False
+ line.set_sampling_config(create_config_from_edict(line_config))
+ domain = line.sampling(geom_type="domain")
+ try:
+ line.sampling(geom_type="BC")
+ except KeyError:
+ print("get ValueError when sampling BC data")
+
+ # lhs, halton, sobol
+ for samplers in ["lhs", "halton", "sobol"]:
+ if "domain" in line_config.keys():
+ line_config.domain.random_sampling = True
+ line_config.domain.sampler = samplers
+ line.set_sampling_config(create_config_from_edict(line_config))
+ domain = line.sampling(geom_type="domain")
+ print(domain, bc)
+
+time_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'lhs'
+ })
+})
+
+
+def check_time_interval(line_config):
+ """check_time_interval"""
+ try:
+ create_config_from_edict({"test": "test"})
+ except ValueError:
+ print("get ValueError when creating config from an invalid edict")
+
+ line = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(line_config))
+ domain = line.sampling(geom_type="domain")
+ try:
+ line.sampling(geom_type="BC")
+ except KeyError:
+ print("get ValueError when sampling BC data")
+ print(domain)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_time_interval():
+ """test_check_time_interval"""
+ check_time_interval(time_config)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_line_interval_case1():
+ """test_check_time_interval"""
+ check_line_interval_case1(line_config_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_line_interval_case2():
+ """test_check_time_interval"""
+ check_line_interval_case1(line_config_out2)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_2d.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_2d.py
new file mode 100644
index 0000000..76999a2
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_2d.py
@@ -0,0 +1,247 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: 2d cases"""
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Rectangle, Disk
+
+disk_random = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False,
+ }),
+})
+
+disk_mesh = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [100, 180],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': [20, 10],
+ 'with_normal': False,
+ }),
+})
+
+disk_mesh_wrong_meshsize = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [100, 180, 200],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 200,
+ 'with_normal': False,
+ }),
+})
+
+disk_mesh_nodomain = edict({
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 200,
+ 'with_normal': False,
+ }),
+})
+
+disk_mesh_nobc = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+})
+
+
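+# Disk domain mesh sizes are assumed to be a [radial, angular] pair; a third
+# entry (see disk_mesh_wrong_meshsize) is expected to raise ValueError, and a
+# missing 'domain' or 'BC' section a KeyError.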
+def check_disk_random(disk_config):
+ """check_disk_random"""
+ with pytest.raises(ValueError):
+ disk = Disk("disk", (-1.0, 0), -2.0)
+ with pytest.raises(ValueError):
+ disk = Disk("disk", (-1.0, 0, 3), 2.0)
+
+ disk = Disk("disk", (-1.0, 0), 2.0)
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in disk_config.keys():
+ disk_config.domain.sampler = samplers
+ if "BC" in disk_config.keys():
+ disk_config.BC.sampler = samplers
+
+ try:
+ disk.sampling(geom_type="domain")
+ except ValueError:
+ print("get ValueError when sampling before a config is set")
+
+ disk.set_sampling_config(create_config_from_edict(disk_config))
+ with pytest.raises(ValueError):
+ disk.sampling(geom_type="test")
+
+ domain = disk.sampling(geom_type="domain")
+
+ bc = disk.sampling(geom_type="BC")
+ disk.sampling_config.bc.with_normal = True
+ bc, bc_normal = disk.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+def check_disk_mesh(disk_config):
+ """check_disk_mesh"""
+ disk = Disk("disk", (-1.0, 0), 2.0, sampling_config=create_config_from_edict(disk_config))
+ domain = disk.sampling(geom_type="domain")
+
+ bc = disk.sampling(geom_type="BC")
+ disk.sampling_config.bc.with_normal = True
+ bc, bc_normal = disk.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+rectangle_random = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': True,
+ }),
+})
+
+rectangle_mesh = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [50, 25],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 300,
+ 'with_normal': True,
+ }),
+})
+
+
+def check_rectangle_random(config):
+ """check_rectangle_random"""
+ rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2))
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in config.keys():
+ config.domain.sampler = samplers
+ if "BC" in config.keys():
+ config.BC.sampler = samplers
+ config.BC.with_normal = True
+ rectangle.set_sampling_config(create_config_from_edict(config))
+ domain = rectangle.sampling(geom_type="domain")
+ bc, bc_normal = rectangle.sampling(geom_type="BC")
+
+ if "BC" in config.keys():
+ config.BC.with_normal = False
+ rectangle.set_sampling_config(create_config_from_edict(config))
+ bc = rectangle.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+def check_rectangle_mesh(config):
+ """check_rectangle_mesh"""
+ rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2))
+ if "BC" in config.keys():
+ config.BC.with_normal = True
+ rectangle.set_sampling_config(create_config_from_edict(config))
+ domain = rectangle.sampling(geom_type="domain")
+ bc, bc_normal = rectangle.sampling(geom_type="BC")
+
+ if "BC" in config.keys():
+ config.BC.with_normal = False
+ rectangle.set_sampling_config(create_config_from_edict(config))
+ bc = rectangle.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_disk_random():
+ """test_check_disk_random"""
+ check_disk_random(disk_random)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_disk_mesh():
+ """test_check_disk_mesh"""
+ check_disk_mesh(disk_mesh)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_disk_mesh_wrong_meshsize_error():
+ """test_check_disk_mesh"""
+ with pytest.raises(ValueError):
+ check_disk_mesh(disk_mesh_wrong_meshsize)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_disk_mesh_nodomain_error():
+ """test_check_disk_mesh"""
+ with pytest.raises(KeyError):
+ check_disk_mesh(disk_mesh_nodomain)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_disk_mesh_nobc_error():
+ """test_check_disk_mesh"""
+ with pytest.raises(KeyError):
+ check_disk_mesh(disk_mesh_nobc)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rectangle_random():
+ """test_check_rectangle_random"""
+ check_rectangle_random(rectangle_random)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rectangle_mesh():
+ """test_check_rectangle_mesh"""
+ check_rectangle_mesh(rectangle_mesh)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_base.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_base.py
new file mode 100644
index 0000000..40eabd0
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_base.py
@@ -0,0 +1,264 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: base classes"""
+import pytest
+import numpy as np
+from easydict import EasyDict as edict
+
+from mindelec.geometry import Geometry, PartSamplingConfig, SamplingConfig, create_config_from_edict
+
+def check_create_config_from_edict():
+ try:
+ sampling_config = ["geom", "IC", "BC"]
+ config = create_config_from_edict(sampling_config)
+ print("check config: {}".format(config))
+ except TypeError:
+ print("get sampling config type error")
+
+
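+# PartSamplingConfig positional order -- (size, random_sampling, sampler,
+# random_merge, with_normal) -- mirrors the call sites in
+# test_check_part_sampling_config below.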
+def check_part_sampling_config(size, random_sampling, sampler, random_merge, with_normal):
+ try:
+ config = PartSamplingConfig(size, random_sampling, sampler, random_merge, with_normal)
+ print("check config: {}".format(config.__dict__))
+ except TypeError:
+ print("get TypeError")
+
+
+temp_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ 'sampler': 'uniform',
+ 'random_merge': True,
+ }),
+})
+
+
+def check_sampling_config_case1():
+ """check_sampling_config_case1"""
+ with pytest.raises(TypeError):
+ SamplingConfig("test")
+ with pytest.raises(KeyError):
+ SamplingConfig({"test": "test"})
+ with pytest.raises(TypeError):
+ SamplingConfig({"domain": "test"})
+ with pytest.raises(TypeError):
+ part_sampling_config_dict = {"domain": PartSamplingConfig("test", False, True)}
+ SamplingConfig(part_sampling_config_dict)
+
+ part_sampling_config_dict = {"domain": PartSamplingConfig([100, 100], False, "uniform", True, True),
+ "BC": PartSamplingConfig(100, True, "uniform", True, True)}
+ sampling_config_tmp = SamplingConfig(part_sampling_config_dict)
+ for attr, config in sampling_config_tmp.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config.__dict__))
+
+
+def check_sampling_config_case2(config_in):
+ """check_sampling_config_case2"""
+ sampling_config_tmp = create_config_from_edict(config_in)
+ for attr, config in sampling_config_tmp.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config.__dict__))
+
+
+def check_sampling_config_case3(config_in):
+ """check_sampling_config_case3"""
+ config_in.time = config_in.domain
+ sampling_config_tmp = create_config_from_edict(config_in)
+ for attr, config in sampling_config_tmp.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config.__dict__))
+
+
+def check_geometry_case1():
+ """check_geometry_case1"""
+ with pytest.raises(ValueError):
+ Geometry("geom", 2, 0.0, 1.0)
+ with pytest.raises(TypeError):
+ Geometry("geom", 2, [0, 0], [1, 1], sampling_config="test")
+
+ geom = Geometry("geom", 1, 0.0, 1.0)
+ with pytest.raises(TypeError):
+ geom.set_sampling_config('test')
+
+ with pytest.raises(NotImplementedError):
+ geom.sampling()
+
+ for attr, config in geom.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config))
+
+ try:
+ geom = Geometry("geom", 1, 1.0, 0.0)
+ for attr, config in geom.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config))
+ except ValueError:
+ print("get ValueError")
+
+
+sampling_config2 = create_config_from_edict(temp_config)
+
+
+def check_geometry_case2(sampling_config_tmp):
+ """check_geometry_case2"""
+ geom = Geometry("geom", 1, 0.0, 1.0, sampling_config=sampling_config_tmp)
+ geom.set_name("geom_name")
+ for attr, config in geom.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config))
+
+
+def check_geometry_case3(config_in):
+ """check_geometry_case3"""
+ geom = Geometry("geom", 1, 0.0, 1.0)
+ for attr, configs in geom.__dict__.items():
+ if configs is not None:
+ print("check sampling config: {}: {}".format(attr, configs))
+
+ geom.set_name("geom_name")
+ geom.set_sampling_config(create_config_from_edict(config_in))
+ for attr, config in geom.__dict__.items():
+ if attr == "sampling_config" and config is not None:
+ print("check sampling config after set: {}: {}".format(attr, config.__dict__))
+ for attrs, configs in config.__dict__.items():
+ if configs is not None:
+ print("check sampling config: {}: {}".format(attrs, configs.__dict__))
+ else:
+ if config is not None:
+ print("check sampling config after set: {}: {}".format(attr, config))
+
+
+def check_geometry_case4():
+ """check_geometry_case4"""
+ try:
+ geom = Geometry(10, 1, 1.0, 0.0)
+ for attr, config in geom.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config))
+ except TypeError:
+ print("get geom name type error")
+
+ geom = Geometry("geom", 1, 0.0, 1.0)
+ try:
+ geom.set_name("geom_name")
+ except TypeError:
+ print("get set geom name type error")
+ geom.set_name("geom_name")
+
+ try:
+ geom = Geometry("geom", 1.0, 1.0, 2.0)
+ for attr, config in geom.__dict__.items():
+ if config is not None:
+ print("check sampling config: {}: {}".format(attr, config))
+ except TypeError:
+ print("get geom dim type error")
+
+ try:
+ geom = Geometry("geom", 1, 1.0, 0.0)
+ except ValueError:
+ print("get geom coord value error")
+
+ try:
+ geom = Geometry("geom", 1, {"min": 0.0}, {"max": 1.0})
+ except TypeError:
+ print("get geom coord type error")
+
+ try:
+ geom = Geometry("geom", 1, 0.0, 1.0, dtype=np.uint32)
+ except TypeError:
+ print("get geom data type error")
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_part_sampling_config():
+ """test_check_part_sampling_config"""
+ check_part_sampling_config(100, True, "uniform", True, True)
+ check_part_sampling_config(100, False, "sobol", True, True)
+ check_part_sampling_config([100, 100], False, "sobol", True, True)
+ check_part_sampling_config([100, 100], True, "uniform", True, True)
+ check_part_sampling_config(100, False, "lhs", True, True)
+ check_part_sampling_config(100, False, "halton", True, True)
+ check_create_config_from_edict()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_sampling_config_case1():
+ """test_check_sampling_config_case1"""
+ check_sampling_config_case1()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_sampling_config_case2():
+ """test_check_sampling_config_case2"""
+ check_sampling_config_case2(temp_config)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_sampling_config_case3():
+ """test_check_sampling_config_case3"""
+ check_sampling_config_case3(temp_config)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_geometry_case1():
+ """test_check_geometry_case1"""
+ check_geometry_case1()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_geometry_case2():
+ """test_check_geometry_case2"""
+ check_geometry_case2(sampling_config2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_geometry_case3():
+ """test_check_geometry_case3"""
+ check_geometry_case3(temp_config)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_geometry_case4():
+ """test_check_geometry_case4"""
+ check_geometry_case4()
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_csg.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_csg.py
new file mode 100644
index 0000000..6fec379
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_csg.py
@@ -0,0 +1,240 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: CSG classes"""
+import pytest
+from easydict import EasyDict as edict
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Rectangle, Disk, Interval
+from mindelec.geometry import CSGIntersection, CSGDifference, CSGUnion, CSGXOR, CSG
+
+sampling_config_csg = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': True,
+ }),
+})
+
+sampling_config_csg2 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False,
+ }),
+})
+
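+# Each CSG test builds the composite twice: once via the explicit class
+# (CSGUnion/CSGDifference/CSGIntersection/CSGXOR) and once via the overloaded
+# operator (| - & ^), then samples domain and BC points with every sampler.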
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_csg_union():
+ """test check union"""
+ disk = Disk("disk", (1.2, 0.5), 0.8)
+ rect = Rectangle("rect", (-1.0, 0), (1, 1))
+
+ union = CSGUnion(rect, disk, create_config_from_edict(sampling_config_csg))
+ union = rect | disk
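+    # the | operator overload builds the same CSGUnion; likewise -, & and ^
+    # construct CSGDifference, CSGIntersection and CSGXOR in the tests below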
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in sampling_config_csg.keys():
+ sampling_config_csg.domain.sampler = samplers
+ if "BC" in sampling_config_csg.keys():
+ sampling_config_csg.BC.sampler = samplers
+
+ union.set_sampling_config(create_config_from_edict(sampling_config_csg2))
+ bc = union.sampling(geom_type="BC")
+
+ union.set_sampling_config(create_config_from_edict(sampling_config_csg))
+ domain = union.sampling(geom_type="domain")
+ bc, bc_normal = union.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_csg_difference():
+ """test_check_csg_difference"""
+ disk = Disk("disk", (1.2, 0.5), 0.8)
+ rect = Rectangle("rect", (-1.0, 0), (1, 1))
+
+ difference = CSGDifference(rect, disk, create_config_from_edict(sampling_config_csg))
+ difference = rect - disk
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in sampling_config_csg.keys():
+ sampling_config_csg.domain.sampler = samplers
+ if "BC" in sampling_config_csg.keys():
+ sampling_config_csg.BC.sampler = samplers
+
+ difference.set_sampling_config(create_config_from_edict(sampling_config_csg2))
+ bc = difference.sampling(geom_type="BC")
+
+ difference.set_sampling_config(create_config_from_edict(sampling_config_csg))
+ domain = difference.sampling(geom_type="domain")
+ bc, bc_normal = difference.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_csg_intersection():
+ """test_check_csg_intersection"""
+ disk = Disk("disk", (1.2, 0.5), 0.8)
+ rect = Rectangle("rect", (-1.0, 0), (1, 1))
+
+ intersec = CSGIntersection(rect, disk, create_config_from_edict(sampling_config_csg))
+ intersec = rect & disk
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in sampling_config_csg.keys():
+ sampling_config_csg.domain.sampler = samplers
+ if "BC" in sampling_config_csg.keys():
+ sampling_config_csg.BC.sampler = samplers
+
+ intersec.set_sampling_config(create_config_from_edict(sampling_config_csg2))
+ bc = intersec.sampling(geom_type="BC")
+
+ intersec.set_sampling_config(create_config_from_edict(sampling_config_csg))
+ domain = intersec.sampling(geom_type="domain")
+ bc, bc_normal = intersec.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_csg_xor():
+ """test_check_csg_xor"""
+ disk = Disk("disk", (1.2, 0.5), 0.8)
+ rect = Rectangle("rect", (-1.0, 0), (1, 1))
+
+ xor = CSGXOR(rect, disk, create_config_from_edict(sampling_config_csg))
+ xor = rect ^ disk
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in sampling_config_csg.keys():
+ sampling_config_csg.domain.sampler = samplers
+ if "BC" in sampling_config_csg.keys():
+ sampling_config_csg.BC.sampler = samplers
+
+ xor.set_sampling_config(create_config_from_edict(sampling_config_csg2))
+ bc = xor.sampling(geom_type="BC")
+
+ xor.set_sampling_config(create_config_from_edict(sampling_config_csg))
+ domain = xor.sampling(geom_type="domain")
+ bc, bc_normal = xor.sampling(geom_type="BC")
+ print(bc, bc_normal, domain)
+
+
+no_src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform'
+ }),
+})
+
+no_src_sampling_config1 = edict({
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False
+ }),
+})
+
+no_src_sampling_config2 = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 10]
+ }),
+})
+
+no_src_sampling_config3 = edict({
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10]
+ }),
+})
+
+no_src_sampling_config4 = edict({
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10]
+ }),
+})
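+
+# note: CSG regions only support random sampling, so the three grid-sampling
+# ('random_sampling': False) configs above are expected to be rejected with
+# ValueError in test_check_point_src_csg below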
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_point_src_csg():
+ """test_check_point_src_csg"""
+ src_region = Disk("src", (0.0, 0.0), 0.2)
+ rectangle = Rectangle("rect", (-1, -1), (1, 1))
+ line = Interval("line", -1, 1)
+
+ with pytest.raises(TypeError):
+ no_src_region = CSG("test", rectangle, 2, [0, 0], [1, 1])
+ with pytest.raises(TypeError):
+ no_src_region = CSG("test", 2, rectangle, [0, 0], [1, 1])
+ with pytest.raises(ValueError):
+ no_src_region = CSG("test", rectangle, line, [0, 0], [1, 1])
+
+ no_src_region = rectangle - src_region
+ no_src_region.set_name("no_src")
+
+ with pytest.raises(ValueError):
+ no_src_region.set_sampling_config(None)
+
+ with pytest.raises(TypeError):
+ no_src_region.set_sampling_config("test")
+
+ with pytest.raises(ValueError):
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config2))
+
+ with pytest.raises(ValueError):
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config3))
+
+ with pytest.raises(ValueError):
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config4))
+
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
+ with pytest.raises(KeyError):
+ no_src_region.sampling(geom_type="BC")
+
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config1))
+ with pytest.raises(KeyError):
+ no_src_region.sampling(geom_type="domain")
+ with pytest.raises(ValueError):
+ no_src_region.sampling(geom_type="test")
+
+ no_src_region.sampling(geom_type="BC")
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_nd.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_nd.py
new file mode 100644
index 0000000..9df2743
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_nd.py
@@ -0,0 +1,142 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: nd cases"""
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Cuboid
+
+cuboid_random = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False,
+ }),
+})
+
+cuboid_random2 = edict({
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False,
+ }),
+})
+
+cuboid_mesh = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [20, 30, 10],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 900,
+ 'with_normal': True,
+ }),
+})
+
+
+cuboid_mesh2 = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [20, 30],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 900,
+ 'with_normal': True,
+ }),
+})
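+
+# cuboid_mesh requests a 20 x 30 x 10 grid, i.e. the 6000 domain points
+# asserted in check_cuboid_mesh, while cuboid_mesh2 gives only two mesh sizes
+# for a 3-D cuboid and is expected to fail with ValueError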
+
+def check_cuboid_random(cuboid_config):
+ """check_cuboid_random"""
+
+ cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1])
+ for samplers in ["uniform", "lhs", "halton", "sobol"]:
+ print("check random sampler: {}".format(samplers))
+ if "domain" in cuboid_config.keys():
+ cuboid_config.domain.sampler = samplers
+ if "BC" in cuboid_config.keys():
+ cuboid_config.BC.sampler = samplers
+
+ try:
+ cuboid.set_sampling_config("test")
+ except TypeError:
+ print("set_sampling_config TypeError")
+
+ try:
+ cuboid.sampling(geom_type="domain")
+ except ValueError:
+ print("sampling ValueError")
+
+ cuboid.set_sampling_config(create_config_from_edict(cuboid_config))
+ domain = cuboid.sampling(geom_type="domain")
+ bc = cuboid.sampling(geom_type="BC")
+ assert domain.shape == (1000, 3)
+ assert bc.shape == (199, 3)
+
+
+def check_cuboid_mesh(cuboid_config):
+ """check_cuboid_mesh"""
+ cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1], sampling_config=create_config_from_edict(cuboid_config))
+ domain = cuboid.sampling(geom_type="domain")
+ bc, bc_normal = cuboid.sampling(geom_type="BC")
+ assert domain.shape == (6000, 3)
+ assert bc.shape == (556, 3)
+ assert bc_normal.shape == (556, 3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_cuboid_random():
+ """test_check_cuboid_random"""
+ check_cuboid_random(cuboid_random)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_cuboid_random_nodomain_error():
+ """test_check_cuboid_random"""
+ with pytest.raises(KeyError):
+ check_cuboid_random(cuboid_random2)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_cuboid_mesh():
+ """test_check_cuboid_mesh"""
+ check_cuboid_mesh(cuboid_mesh)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_cuboid_mesh_meshsize_error():
+ """test_check_cuboid_mesh"""
+ with pytest.raises(ValueError):
+ check_cuboid_mesh(cuboid_mesh2)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_td.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_td.py
new file mode 100644
index 0000000..aa1c3c4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/geometry/test_geometry_td.py
@@ -0,0 +1,293 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test geometry module: geometry with time cases"""
+
+import copy
+import pytest
+from easydict import EasyDict as edict
+
+from mindelec.geometry import create_config_from_edict
+from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain
+
+rectangle_random = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 1000,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform',
+ 'with_normal': False,
+ }),
+})
+
+rectangle_mesh = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 50,
+ 'with_normal': False,
+ }),
+})
+
+time_random = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ 'sampler': 'lhs'
+ })
+})
+
+time_mesh = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+time_mesh2 = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': 10000,
+ })
+})
+
+time_mesh3 = edict({
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': 10000,
+ })
+})
+
+reset_geom_time_config = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 50,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10],
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config2 = edict({
+ 'domain': edict({
+ 'random_sampling': False,
+ 'size': [10, 20],
+ }),
+ 'BC': edict({
+ 'random_sampling': False,
+ 'size': 50,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10],
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config3 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 50,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': False,
+ 'size': [10, 10],
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config4 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 50,
+ 'with_normal': True,
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 100,
+ }),
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+reset_geom_time_config5 = edict({
+ 'time': edict({
+ 'random_sampling': False,
+ 'size': 10,
+ })
+})
+
+def check_rect_with_time_init_config(rect_config, time_config):
+ """check_rect_with_time_init_config"""
+ rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5], sampling_config=create_config_from_edict(rect_config))
+ time = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(time_config))
+ rect_with_time = GeometryWithTime(rect, time)
+
+ # check info
+ print("check rect_with_time initial config: {}".format(rect_with_time.__dict__))
+ if rect_with_time.sampling_config is not None:
+ for key, value in rect_with_time.sampling_config.__dict__.items():
+ if value is not None:
+ print(" get attr: {}, value: {}".format(key, value.__dict__))
+
+ # sampling
+ config = rect_with_time.sampling_config
+ if config is None:
+ raise ValueError
+ if config.domain is not None:
+ domain = rect_with_time.sampling(geom_type="domain")
+ print("check domain points: {}".format(domain.shape))
+ if config.bc is not None:
+ if config.bc.with_normal:
+ bc, bc_normal = rect_with_time.sampling(geom_type="BC")
+ print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape))
+ else:
+ bc = rect_with_time.sampling(geom_type="BC")
+ print("check bc points: {}".format(bc.shape))
+ if config.ic is not None:
+ ic = rect_with_time.sampling(geom_type="IC")
+ print("check ic points: {}".format(ic.shape))
+
+
+def check_rect_with_time_set_config(config):
+ """check_rect_with_time_set_config"""
+ rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5])
+ time = TimeDomain("time", 0.0, 1.0)
+ try:
+ GeometryWithTime(rect, time, create_config_from_edict(config))
+ except ValueError:
+ print("create_config_from_edict ValueError")
+
+ rect_with_time = GeometryWithTime(rect, time)
+ try:
+ rect_with_time.sampling(geom_type="domain")
+ except ValueError:
+ print("sampling ValueError")
+
+ rect_with_time.set_sampling_config(create_config_from_edict(config))
+
+ try:
+ rect_with_time.sampling(geom_type="test")
+ except ValueError:
+ print("sampling ValueError")
+ # sampling
+ config = rect_with_time.sampling_config
+ if config.domain is not None:
+ domain = rect_with_time.sampling(geom_type="domain")
+ print("check domain points: {}".format(domain.shape))
+ else:
+        try:
+            rect_with_time.sampling(geom_type="domain")
+        except KeyError:
+            print("sampling KeyError")
+ if config.bc is not None:
+ if config.bc.with_normal:
+ bc, bc_normal = rect_with_time.sampling(geom_type="BC")
+ print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape))
+ normal = copy.deepcopy(bc)
+ normal[:, :2] = bc[:, :2] + bc_normal[:, :]
+ else:
+ bc = rect_with_time.sampling(geom_type="BC")
+ print("check bc points: {}".format(bc.shape))
+ else:
+        try:
+            rect_with_time.sampling(geom_type="BC")
+        except KeyError:
+            print("sampling KeyError")
+ if config.ic is not None:
+ ic = rect_with_time.sampling(geom_type="IC")
+ print("check ic points: {}".format(ic.shape))
+ else:
+        try:
+            rect_with_time.sampling(geom_type="IC")
+        except KeyError:
+            print("sampling KeyError")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_init_config():
+ """test_check_rect_with_time_init_config"""
+ check_rect_with_time_init_config(rectangle_random, time_random)
+ check_rect_with_time_init_config(rectangle_mesh, time_random)
+ check_rect_with_time_init_config(rectangle_mesh, time_mesh2)
+ check_rect_with_time_init_config(rectangle_random, time_mesh)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_init_config_error():
+ """test_check_rect_with_time_init_config"""
+ with pytest.raises(ValueError):
+ check_rect_with_time_init_config(rectangle_mesh, time_mesh3)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_set_config():
+ """test_check_rect_with_time_set_config"""
+ check_rect_with_time_set_config(reset_geom_time_config)
+ check_rect_with_time_set_config(reset_geom_time_config2)
+ check_rect_with_time_set_config(reset_geom_time_config3)
+ check_rect_with_time_set_config(reset_geom_time_config4)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_check_rect_with_time_set_config2():
+ """test_check_rect_with_time_set_config"""
+ with pytest.raises(ValueError):
+ check_rect_with_time_set_config(rectangle_random)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/loss/test_constraints.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/loss/test_constraints.py
new file mode 100644
index 0000000..2e794d4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/loss/test_constraints.py
@@ -0,0 +1,39 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test constraints"""
+import pytest
+from mindelec.data import Dataset
+from mindelec.loss import Constraints
+from mindelec.geometry import Rectangle
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_constraints_dataset_type_error():
+    """Constraints should raise TypeError when dataset is not a Dataset"""
+    with pytest.raises(TypeError):
+ Constraints(1, 1)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_existed_data_config_type_error():
+    """Constraints should raise TypeError for an invalid problem dictionary"""
+    rectangle = Rectangle("rect", (-1, -1), (1, 1))
+ geom_dict = {rectangle: ["domain"]}
+ dataset = Dataset(geom_dict)
+ with pytest.raises(TypeError):
+ Constraints(dataset, 1)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/net_with_loss/test_netwithloss.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/net_with_loss/test_netwithloss.py
new file mode 100644
index 0000000..00cdec4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/net_with_loss/test_netwithloss.py
@@ -0,0 +1,217 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+test net_with_loss
+"""
+import os
+from easydict import EasyDict as edict
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore import context, ms_function
+from mindspore.common import set_seed
+from mindelec.geometry import Rectangle, create_config_from_edict
+from mindelec.data import Dataset, ExistedDataConfig
+from mindelec.loss import Constraints, NetWithLoss, NetWithEval
+from mindelec.architecture import ResBlock, LinearBlock
+from mindelec.solver import Problem
+from mindelec.common import L2
+from mindelec.operators import Grad
+
+set_seed(1)
+np.random.seed(1)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+path = os.getcwd()
+data_config = edict({
+    'data_dir': [path+'/inputs.npy', path+'/label.npy'],  # absolute paths
+ 'columns_list': ['input_data', 'label'],
+ 'data_format': 'npy',
+ 'constraint_type': 'Equation',
+ 'name': 'exist'
+})
+
+if not os.path.exists(data_config['data_dir'][0]):
+ data_in = np.ones((32, 2), dtype=np.float32)
+ np.save(path+"/inputs.npy", data_in)
+
+if not os.path.exists(data_config['data_dir'][1]):
+ data_label = np.ones((32, 1), dtype=np.float32)
+ np.save(path+"/label.npy", data_label)
+
+
+rectangle_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 2500,
+ 'sampler': 'uniform'
+ }),
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 200,
+ 'sampler': 'uniform'
+ })
+})
+
+rectangle_config1 = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 2500,
+ 'sampler': 'uniform'
+ }),
+})
+
+
+class Net(nn.Cell):
+ """net definition"""
+ def __init__(self, input_dim, output_dim, hidden_layer=128, activation="sin"):
+ super(Net, self).__init__()
+ self.resblock = ResBlock(hidden_layer, hidden_layer, activation=activation)
+ self.fc1 = LinearBlock(input_dim, hidden_layer)
+ self.fc2 = LinearBlock(hidden_layer, output_dim)
+
+ def construct(self, *inputs):
+ x = inputs[0]
+ out = self.fc1(x)
+ out = self.resblock(out)
+ out = self.fc2(out)
+ return out
+
+
+class RectPde(Problem):
+ """rectangle pde problem"""
+ def __init__(self, domain_name=None, bc_name=None, label_name=None, net=None):
+ super(RectPde, self).__init__()
+ self.domain_name = domain_name
+ self.bc_name = bc_name
+ self.label_name = label_name
+ self.type = "Equation"
+ self.jacobian = Grad(net)
+
+ @ms_function
+ def governing_equation(self, *output, **kwargs):
+ u = output[0]
+ data = kwargs[self.domain_name]
+ u_x = self.jacobian(data, 0, 0, u)
+ return u_x
+
+ @ms_function
+ def boundary_condition(self, *output, **kwargs):
+ u = output[0]
+ x = kwargs[self.bc_name][:, 0]
+ y = kwargs[self.bc_name][:, 1]
+ return u - ops.sin(x) * ops.cos(y)
+
+ @ms_function
+ def constraint_function(self, *output, **kwargs):
+ u = output[0]
+ label = kwargs[self.label_name]
+ return u - label
+
+
+class RectPde1(Problem):
+ """rectangle pde problem with no boundary condition"""
+ def __init__(self, domain_name, net):
+ super(RectPde1, self).__init__()
+ self.domain_name = domain_name
+ self.type = "Equation"
+ self.jacobian = Grad(net)
+
+ @ms_function
+ def governing_equation(self, *output, **kwargs):
+ u = output[0]
+ data = kwargs[self.domain_name]
+ u_x = self.jacobian(data, 0, 0, u)
+ return u_x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_netwithloss():
+ """test netwithloss function"""
+ model = Net(2, 1, 128, "sin")
+ rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0],
+ sampling_config=create_config_from_edict(rectangle_config))
+
+ geom_dict = {rect_space: ["domain", "BC"]}
+ dataset = Dataset(geom_dict)
+ dataset.create_dataset(batch_size=4, shuffle=True)
+ prob_dict = {rect_space.name: RectPde(domain_name="rectangle_domain_points", bc_name="rectangle_BC_points",
+ net=model)}
+ train_constraints = Constraints(dataset, prob_dict)
+ metrics = {'l2': L2(), 'distance': nn.MAE()}
+ train_input_map = {'rectangle_domain': ['rectangle_domain_points'], 'rectangle_BC': ['rectangle_BC_points']}
+ loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map)
+
+ domain_points = Tensor(np.ones([32, 2]).astype(np.float32))
+ bc_points = Tensor(np.ones([32, 2]).astype(np.float32))
+ bc_normal = Tensor(np.ones([32, 2]).astype(np.float32))
+ out = loss_network(domain_points, bc_points, bc_normal)
+ print(out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_netwithloss1():
+ """test netwithloss function1"""
+ model = Net(2, 1, 128, "sin")
+ rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0],
+ sampling_config=create_config_from_edict(rectangle_config1))
+ geom_dict = {rect_space: ["domain"]}
+ dataset = Dataset(geom_dict)
+ dataset.create_dataset(batch_size=4, shuffle=True)
+ prob_dict = {rect_space.name: RectPde1(domain_name="rectangle_domain_points", net=model)}
+ train_constraints = Constraints(dataset, prob_dict)
+ metrics = {'l2': L2(), 'distance': nn.MAE()}
+ train_input_map = {'rectangle_domain': ['rectangle_domain_points']}
+ loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map)
+
+ domain_points = Tensor(np.ones([32, 2]).astype(np.float32))
+ out = loss_network(domain_points)
+ print(out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_netwitheval():
+ """test netwitheval function"""
+ model = Net(2, 1, 128, "sin")
+ src_domain = ExistedDataConfig(name="src_domain",
+ data_dir=[path+"/inputs.npy", path+"/label.npy"],
+ columns_list=["inputs", "label"],
+ data_format="npy",
+ constraint_type="Label",
+ random_merge=False)
+ test_prob_dict = {src_domain.name: RectPde(domain_name=src_domain.name+"_inputs",
+ label_name=src_domain.name+"_label", net=model)}
+ test_dataset = Dataset(existed_data_list=[src_domain])
+ test_dataset.create_dataset(batch_size=4, shuffle=False)
+ test_constraints = Constraints(test_dataset, test_prob_dict)
+ metrics = {'l2': L2(), 'distance': nn.MAE()}
+ loss_network = NetWithEval(model, test_constraints, metrics)
+
+ data = Tensor(np.ones([32, 2]).astype(np.float32))
+ label = Tensor(np.ones([32, 1]).astype(np.float32))
+ out = loss_network(data, label)
+ print(out)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/config.py
new file mode 100644
index 0000000..07ae2a5
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/config.py
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config settings, used in train.py and eval.py
+"""
+
+# config
+config = {
+ 'base_channels': 8,
+ 'input_channels': 4,
+ 'epochs': 2000,
+ 'batch_size': 8,
+ 'save_epoch': 100,
+ 'lr': 0.01,
+ 'lr_decay_milestones': 5,
+ 'eval_interval': 20,
+ 'patch_shape': [25, 50, 25],
+}
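+
+# patch_shape is the DWH size of each cropped training patch; it is passed to
+# EncoderDecoder as target_shape in test_data_compression.py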
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/dataset.py
new file mode 100644
index 0000000..35ac727
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/dataset.py
@@ -0,0 +1,82 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataset generation and loading"""
+
+import numpy as np
+from mindspore.common import set_seed
+
+from mindelec.data import Dataset, ExistedDataConfig
+
+np.random.seed(0)
+set_seed(0)
+
+PATCH_DIM = [25, 50, 25]
+NUM_SAMPLE = 10000
+INPUT_PATH = ""
+DATA_CONFIG_PATH = "./data_config.npy"
+SAVE_DATA_PATH = "./"
+
+
+def generate_data(input_path):
+ """generate training data and data configuration"""
+ space_temp = np.load(input_path)
+
+ print("data load finish")
+ print("random cropping...")
+ space_data = np.ones((NUM_SAMPLE,
+ PATCH_DIM[0],
+ PATCH_DIM[1],
+ PATCH_DIM[2],
+ space_temp.shape[-1])).astype(np.float32)
+    rand_pos = np.random.randint(low=0,
+                                 high=(space_temp.shape[0] - PATCH_DIM[0]) *
+                                      (space_temp.shape[1] - PATCH_DIM[1]) *
+                                      (space_temp.shape[2] - PATCH_DIM[2]),
+                                 size=NUM_SAMPLE)
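+    # decode each flat index into (x, y, z) crop origins; z varies fastest,
+    # matching row-major order over the grid of valid origins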
+ for i, pos in enumerate(rand_pos):
+ z = pos % (space_temp.shape[2] - PATCH_DIM[2])
+ y = (pos // (space_temp.shape[2] - PATCH_DIM[2])) % (space_temp.shape[1] - PATCH_DIM[1])
+ x = (pos // (space_temp.shape[2] - PATCH_DIM[2])) // (space_temp.shape[1] - PATCH_DIM[1])
+ space_data[i] = space_temp[x : x+PATCH_DIM[0], y : y+PATCH_DIM[1], z : z+PATCH_DIM[2], :]
+ print("random crop finished")
+
+ space_data[:, :, :, :, 2] = np.log10(space_data[:, :, :, :, 2] + 1.0)
+ data_config = np.ones(4)
+ for i in range(4):
+ data_config[i] = np.max(np.abs(space_data[:, :, :, :, i]))
+ space_data[:, :, :, :, i] = space_data[:, :, :, :, i] / data_config[i]
+
+    # transpose to channels-first (N, C, D, H, W) before splitting, so the
+    # saved arrays match the Conv3d input layout expected by the encoder
+    space_data = space_data.transpose((0, 4, 1, 2, 3))
+
+    length = space_data.shape[0] // 10
+    test_data = space_data[:length]
+    train_data = space_data[length:]
+
+    np.save(DATA_CONFIG_PATH, data_config)
+    np.save(SAVE_DATA_PATH+"/train_data.npy", train_data)
+    np.save(SAVE_DATA_PATH+"/test_data.npy", test_data)
+    print("AE train and test data and data config are saved")
+
+
+def create_dataset(input_path, label_path, batch_size=8, shuffle=True):
+ electromagnetic = ExistedDataConfig(name="electromagnetic",
+ data_dir=[input_path, label_path],
+ columns_list=["inputs", "label"],
+ data_format=input_path.split('.')[-1])
+ dataset = Dataset(existed_data_list=[electromagnetic])
+ data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle)
+
+ return data_loader
+
+if __name__ == "__main__":
+ generate_data(INPUT_PATH)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/lr_generator.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/lr_generator.py
new file mode 100644
index 0000000..a70c1aa
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/lr_generator.py
@@ -0,0 +1,30 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+
+
+def step_lr_generator(step_size, epochs, lr, lr_decay_milestones):
+ """generate step decayed learnig rate"""
+
+ total_steps = epochs * step_size
+
+ milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)]
+ milestones.append(total_steps)
+ learning_rates = [lr*0.5**i for i in range(0, lr_decay_milestones - 1)]
+ learning_rates.append(lr*0.5**(lr_decay_milestones - 1))
+
+ print("total_steps: %s, milestones: %s, learning_rates: %s " %(total_steps, milestones, learning_rates))
+
+ return milestones, learning_rates
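+
+# Example with hypothetical values: step_size=100, epochs=10, lr=0.01 and
+# lr_decay_milestones=5 give total_steps=1000,
+# milestones=[200, 400, 600, 800, 1000] and
+# learning_rates=[0.01, 0.005, 0.0025, 0.00125, 0.000625]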
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/metric.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/metric.py
new file mode 100644
index 0000000..e77b6b3
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/metric.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""metrics"""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+
+class MyMSELoss(nn.LossBase):
+ """mse loss function"""
+ def construct(self, base, target):
+ bs, _, _, _, _ = F.shape(target)
+ x = F.square(base - target)
+ return 2*bs*self.get_loss(x)
+
+
+class EvalMetric(nn.Metric):
+ """eval metric"""
+
+ def __init__(self, length):
+ super(EvalMetric, self).__init__()
+ self.clear()
+ self.length = length
+
+    def clear(self):
+        self.error_mean_l1_error = 0
+
+ def update(self, *inputs):
+ test_predict = self._convert_data(inputs[0])
+ test_label = self._convert_data(inputs[1])
+
+ for i in range(len(test_label)):
+ self.error_mean_l1_error += np.mean(np.abs(test_label[i] - test_predict[i]))
+
+ def eval(self):
+ return {'mean_l1_error': self.error_mean_l1_error / self.length}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/model.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/model.py
new file mode 100644
index 0000000..b996dca
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/src/model.py
@@ -0,0 +1,298 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""EncoderDecoder for MindElec."""
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+class EncoderDecoder(nn.Cell):
+ """
+ EncoderDecoder architecture for MindElec.
+
+ Args:
+        input_dim (int): number of input channels.
+        target_shape (list or tuple): output DWH shape.
+        base_channels (int): base channel count; all intermediate layers' channels are multiples of this value.
+        decoding (bool): enable the Decoder. Set True when the reconstructed
+            input is needed, for example during training. Default: False
+
+ Returns:
+        Tensor, output tensor: compressed encodings (decoding=False) or the reconstructed input (decoding=True).
+
+ Examples:
+        >>> # training (encoder + decoder):
+        >>> EncoderDecoder(input_dim=4, target_shape=[25, 50, 25], base_channels=8, decoding=True)
+        >>> # applying the encoder alone (data compression):
+        >>> EncoderDecoder(input_dim=4,
+        ...                target_shape=[25, 50, 25],
+        ...                base_channels=8,
+        ...                decoding=False)
+ """
+
+ def __init__(self, input_dim, target_shape, base_channels=8, decoding=False):
+ super(EncoderDecoder, self).__init__()
+ self.decoding = decoding
+ self.encoder = Encoder(input_dim, base_channels)
+ if self.decoding:
+ self.decoder = Decoder(input_dim, target_shape, base_channels)
+
+ def construct(self, x):
+ encoding = self.encoder(x)
+ if self.decoding:
+ output = self.decoder(encoding)
+ else:
+ output = encoding
+ return output
+
+
+class Encoder(nn.Cell):
+ """
+ Encoder architecture.
+
+ Args:
+        input_dim (int): number of input channels.
+        base_channels (int): base channel count; all intermediate layers' channels are multiples of this value.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+ >>> Encoder(input_dim=4, base_channels=8)
+ """
+
+ def __init__(self, input_dim, base_channels):
+ super(Encoder, self).__init__()
+ print("BASE_CHANNELS: %d" %base_channels)
+ self.input_dim = input_dim
+ self.channels = base_channels
+
+ self.conv0 = nn.Conv3d(self.input_dim, self.channels*2, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv0_1 = nn.Conv3d(self.channels*2, self.channels*2, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv1 = nn.Conv3d(self.channels*2, self.channels*4, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv1_1 = nn.Conv3d(self.channels*4, self.channels*4, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv2 = nn.Conv3d(self.channels*4, self.channels*8, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv2_1 = nn.Conv3d(self.channels*8, self.channels*8, kernel_size=3, pad_mode='pad', padding=1)
+ self.conv3 = nn.Conv3d(self.channels*8, self.channels*16, kernel_size=(2, 3, 2))
+ self.conv4 = nn.Conv2d(self.channels*16, self.channels*32, kernel_size=(1, 3), pad_mode='pad', padding=0)
+
+ self.bn0 = nn.BatchNorm3d(self.channels*2)
+ self.bn1 = nn.BatchNorm3d(self.channels*4)
+ self.bn2 = nn.BatchNorm3d(self.channels*8)
+ self.bn3 = nn.BatchNorm3d(self.channels*16)
+ self.bn4 = nn.BatchNorm2d(self.channels*32)
+
+ self.down1 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
+ self.down2 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
+ self.down3 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2))
+ self.down4 = nn.MaxPool2d(kernel_size=(2, 6), stride=(2, 6))
+
+ self.down_1_1 = ops.MaxPool3D(kernel_size=(4, 5, 4), strides=(4, 5, 4))
+ self.down_1 = nn.MaxPool2d(kernel_size=(3, 5*3))
+
+ self.down_2_1 = ops.MaxPool3D(kernel_size=(3, 4, 3), strides=(3, 4, 3))
+ self.down_2 = nn.MaxPool2d(kernel_size=(2, 3*2))
+
+ self.down_3 = nn.MaxPool2d(kernel_size=(3, 18))
+
+ self.act = nn.Sigmoid()
+
+ self.concat = ops.Concat(axis=1)
+ self.expand_dims = ops.ExpandDims()
+
+ def construct(self, x):
+ """forward"""
+ bs = x.shape[0]
+
+ x = self.conv0(x)
+ x = self.conv0_1(x)
+ x = self.bn0(x)
+ x = self.act(x)
+ x = self.down1(x)
+ x_1 = self.down_1_1(x)
+ x_1 = self.down_1(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1))
+
+ x = self.conv1(x)
+ x = self.conv1_1(x)
+ x = self.bn1(x)
+ x = self.act(x)
+ x = self.down2(x)
+ x_2 = self.down_2_1(x)
+ x_2 = self.down_2(x_2.view(bs, x_2.shape[1], x_2.shape[2], -1))
+
+ x = self.conv2(x)
+ x = self.conv2_1(x)
+ x = self.bn2(x)
+ x = self.act(x)
+ x = self.down3(x)
+ x_3 = self.down_3(x.view(bs, x.shape[1], x.shape[2], -1))
+
+ x = self.act(self.bn3(self.conv3(x)))
+ x = x.view((bs, x.shape[1], x.shape[2], -1))
+ x = self.down4(x)
+
+ x = self.act(self.bn4(self.conv4(x)))
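+        # fuse the pooled side branches from the three earlier stages with the
+        # deepest encoding along the channel axis, giving
+        # (32 + 2 + 4 + 8) * base_channels channels, matching Decoder.up0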
+ x = self.concat((x, x_1, x_2, x_3))
+ x = self.expand_dims(x, 3)
+
+ return x
+
+
+class Decoder(nn.Cell):
+ """
+ Decoder architecture.
+
+ Args:
+        output_dim (int): number of output channels.
+        target_shape (list or tuple): output DWH shape.
+        base_channels (int): base channel count; all intermediate layers' channels are multiples of this value.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+ >>> Decoder(output_dim=4, target_shape=[25, 50, 25], base_channels=8)
+ """
+
+ def __init__(self, output_dim, target_shape, base_channels):
+ super(Decoder, self).__init__()
+
+ self.output_dim = output_dim
+ self.base_channels = base_channels
+ self.up0 = Up((32 + 8 + 4 + 2) * self.base_channels,
+ self.base_channels * 32,
+ [1, 1, 1],
+ [x // 8 for x in target_shape],
+ pad=True)
+ self.up1 = Up(self.base_channels * 32,
+ self.base_channels * 16,
+ [x // 8 for x in target_shape],
+ [x // 4 for x in target_shape],
+ pad=False)
+ self.up2 = Up(self.base_channels * 16,
+                      self.base_channels * 4,
+ [x // 4 for x in target_shape],
+ [x // 2 for x in target_shape],
+ pad=False)
+ self.up3 = Up(self.base_channels * 4,
+ self.output_dim,
+ [x // 2 for x in target_shape],
+ target_shape,
+ pad=False)
+
+ def construct(self, x):
+ x = self.up0(x)
+ x = self.up1(x)
+ x = self.up2(x)
+ x = self.up3(x)
+
+ return x
+
+
+class DoubleConvTranspose(nn.Cell):
+ """
+ DoubleConvTranspose architecture
+
+ Args:
+        input_dim (int): Input channel.
+        out_channels (int): Output channel.
+        mid_channels (int): Channels of the intermediate layer. Defaults to out_channels.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+ >>> DoubleConvTranspose(input_dim=4, out_channels=8)
+ """
+
+ def __init__(self, input_dim, out_channels, mid_channels=None):
+ super(DoubleConvTranspose, self).__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+        self.conv = nn.Conv3dTranspose(input_dim, mid_channels, kernel_size=3)
+        self.conv1 = nn.Conv3dTranspose(mid_channels, out_channels, kernel_size=3)
+        self.bn = nn.BatchNorm3d(out_channels)
+        self.act = nn.Sigmoid()
+
+ def construct(self, x):
+ x = self.conv(x)
+ x = self.conv1(x)
+ x = self.bn(x)
+ x = self.act(x)
+ return x
+
+
+def calculate_difference(input_shape, target_shape):
+    """calculate the difference between the target shape and the output shape of
+    the previous Conv3dTranspose, together with the corresponding padding sizes"""
+
+    # output shape of a Conv3dTranspose with kernel_size=2 and stride=2:
+    # (dim - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
+    input_shape = [(dim - 1)*2 - 2*0 + 1*(2 - 1) + 0 + 1 for dim in input_shape]
+
+ diff = [x - y for x, y in zip(target_shape, input_shape)]
+ paddings = [(0, 0), (0, 0)]
+ for i in range(3):
+ paddings.append((diff[i], 0))
+ return tuple(diff), tuple(paddings)
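+
+# Worked example: the first decoder layer (Up with pad=True) upsamples from
+# input_shape=[1, 1, 1] to [2, 2, 2], while target_shape=[x // 8 for x in
+# [25, 50, 25]] is [3, 6, 3]; hence diff=(1, 4, 1), and each spatial dim is
+# padded at the front by that difference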
+
+
+class Up(nn.Cell):
+ """
+    Upscale the input, then apply Conv3dTranspose blocks to decode.
+
+ Args:
+        input_dim (int): Input channel.
+        out_channels (int): Output channel.
+ input_shape (list or tuple): Input DWH shape. Default: None.
+ target_shape (list or tuple): Output DWH shape. Default: None.
+ pad (bool): Enable manual padding, only needed in the first layer of the Decoder. Default: True.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+        >>> Up(8, 4, input_shape=[10, 10, 10], target_shape=[20, 20, 20], pad=False)
+
+ """
+
+    def __init__(self, input_dim, out_channels, input_shape=None, target_shape=None, pad=True):
+        super(Up, self).__init__()
+        self.use_pad = pad
+
+        diff, paddings = calculate_difference(input_shape, target_shape)
+        if self.use_pad:
+            # first decoder layer: pad manually after upsampling
+            self.pad = ops.Pad(paddings)
+            self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=2, stride=2, pad_mode='pad', padding=0)
+        else:
+            # later layers: absorb the shape difference via output_padding
+            self.up = nn.Conv3dTranspose(input_dim,
+                                         input_dim // 2,
+                                         kernel_size=2,
+                                         stride=2,
+                                         pad_mode='pad',
+                                         padding=0,
+                                         output_padding=diff)
+
+        self.conv = DoubleConvTranspose(input_dim // 2, out_channels)
+
+    def construct(self, x):
+        x = self.up(x)
+        if self.use_pad:
+            x = self.pad(x)
+        x = self.conv(x)
+
+        return x
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/test_data_compression.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/test_data_compression.py
new file mode 100644
index 0000000..3ce60ab
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_data_compression/test_data_compression.py
@@ -0,0 +1,137 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ae_train"""
+
+import os
+import time
+import pytest
+
+import mindspore.nn as nn
+import mindspore.common.initializer as weight_init
+from mindspore import context
+from mindspore.common import set_seed
+from mindspore.train.callback import LossMonitor, Callback
+
+from mindelec.solver import Solver
+
+from src.dataset import create_dataset
+from src.model import EncoderDecoder
+from src.lr_generator import step_lr_generator
+from src.metric import MyMSELoss, EvalMetric
+from src.config import config
+
+train_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/train_data.npy"
+test_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/test_data.npy"
+set_seed(0)
+
+print("pid:", os.getpid())
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+class TimeMonitor(Callback):
+ """
+ Monitor the time in training.
+ """
+
+ def __init__(self, data_size=None):
+ super(TimeMonitor, self).__init__()
+ self.data_size = data_size
+ self.epoch_time = time.time()
+ self.per_step_time = 0
+
+    def epoch_begin(self, run_context):
+        """
+        Record time at the beginning of epoch.
+        """
+        run_context.original_args()
+        self.epoch_time = time.time()
+
+ def epoch_end(self, run_context):
+ """
+ Print process cost time at the end of epoch.
+ """
+ epoch_seconds = (time.time() - self.epoch_time) * 1000
+ step_size = self.data_size
+ cb_params = run_context.original_args()
+ if hasattr(cb_params, "batch_num"):
+ batch_num = cb_params.batch_num
+ if isinstance(batch_num, int) and batch_num > 0:
+ step_size = cb_params.batch_num
+
+ self.per_step_time = epoch_seconds / step_size
+ print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)
+
+    def get_step_time(self):
+ return self.per_step_time
+
+
+def init_weight(net):
+ for _, cell in net.cells_and_names():
+ if isinstance(cell, (nn.Conv3d, nn.Dense)):
+ cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
+ cell.weight.shape,
+ cell.weight.dtype))
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_auto_encoder():
+ """training"""
+
+ model_net = EncoderDecoder(config["input_channels"], config["patch_shape"], config["base_channels"], decoding=True)
+ init_weight(net=model_net)
+
+ train_dataset = create_dataset(input_path=train_data_path,
+ label_path=train_data_path,
+ batch_size=config["batch_size"],
+ shuffle=True)
+
+ eval_dataset = create_dataset(input_path=test_data_path,
+ label_path=test_data_path,
+ batch_size=config["batch_size"],
+ shuffle=False)
+
+
+ step_size = train_dataset.get_dataset_size()
+ milestones, learning_rates = step_lr_generator(step_size,
+ config["epochs"],
+ config["lr"],
+ config["lr_decay_milestones"])
+
+ optimizer = nn.Adam(model_net.trainable_params(),
+ learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))
+
+ loss_net = MyMSELoss()
+ eval_step_size = eval_dataset.get_dataset_size() * config["batch_size"]
+ evl_error_mrc = EvalMetric(eval_step_size)
+
+ solver = Solver(model_net,
+ train_input_map={'train': ['train_input_data']},
+ test_input_map={'test': ['test_input_data']},
+ optimizer=optimizer,
+ metrics={'evl_mrc': evl_error_mrc,},
+ amp_level="O2",
+ loss_fn=loss_net)
+
+ time_cb = TimeMonitor()
+ solver.model.train(20, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False)
+ res = solver.model.eval(eval_dataset, dataset_sink_mode=False)
+ per_step_time = time_cb.get_step_time()
+ l1_error = res['evl_mrc']['mean_l1_error']
+ print('test_res:', f'l1_error: {l1_error:.10f} ')
+ print(f'per step time: {per_step_time:.10f} ')
+ assert l1_error <= 0.03
+ assert per_step_time <= 30
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/callback.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/callback.py
new file mode 100644
index 0000000..12b7e8b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/callback.py
@@ -0,0 +1,124 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+call back functions
+"""
+import time
+
+import numpy as np
+
+from mindspore.train.callback import Callback
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+
+
+class PredictCallback(Callback):
+ """
+ Evaluate the model during training.
+
+ Args:
+ model (Cell): A testing network.
+        predict_interval (int): Specifies how many epochs to train before each evaluation.
+        input_data (Array): Input test dataset.
+        label (Array): Label data.
+        batch_size (int): Batch size for prediction. Default: 8192.
+ """
+
+ def __init__(self, model, predict_interval, input_data, label, batch_size=8192):
+ super(PredictCallback, self).__init__()
+ self.model = model
+ self.input_data = input_data
+ self.label = label
+ self.predict_interval = predict_interval
+ self.batch_size = min(batch_size, len(input_data))
+ self.l2_error = 1.0
+ print("check test dataset shape: {}, {}".format(self.input_data.shape, self.label.shape))
+
+ def epoch_end(self, run_context):
+ """
+ Evaluate the model at the end of epoch.
+
+ Args:
+ run_context (RunContext): Context of the train running.
+ """
+ cb_params = run_context.original_args()
+ if cb_params.cur_epoch_num % self.predict_interval == 0:
+ print("================================Start Evaluation================================")
+
+ test_input_data = self.input_data.reshape(-1, 2)
+ label = self.label.reshape(-1, 1)
+ index = 0
+ prediction = np.zeros(label.shape)
+ time_beg = time.time()
+ while index < len(test_input_data):
+ index_end = min(index + self.batch_size, len(test_input_data))
+ test_batch = Tensor(test_input_data[index: index_end, :], dtype=mstype.float32)
+ predict = self.model(test_batch)
+ predict = predict.asnumpy()
+ prediction[index: index_end, :] = predict[:, :]
+ index = index_end
+ print("Total prediction time: {} s".format(time.time() - time_beg))
+ error = label - prediction
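+            # relative L2 error: ||label - prediction||_2 / ||label||_2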
+ l2_error = np.sqrt(np.sum(np.square(error[:, 0]))) / np.sqrt(np.sum(np.square(label[:, 0])))
+ print("l2_error: ", l2_error)
+ print("=================================End Evaluation=================================")
+ self.l2_error = l2_error
+
+ def get_l2_error(self):
+ return self.l2_error
+
+
+class TimeMonitor(Callback):
+ """
+ Monitor the time in training.
+
+ Args:
+ data_size (int): Iteration steps to run one epoch of the whole dataset.
+ """
+ def __init__(self, data_size=None):
+ super(TimeMonitor, self).__init__()
+ self.data_size = data_size
+ self.epoch_time = time.time()
+ self.per_step_time = 0
+
+ def epoch_begin(self, run_context):
+ """
+ Set begin time at the beginning of epoch.
+
+ Args:
+ run_context (RunContext): Context of the train running.
+ """
+ run_context.original_args()
+ self.epoch_time = time.time()
+
+ def epoch_end(self, run_context):
+ """
+ Print process cost time at the end of epoch.
+ """
+ epoch_seconds = (time.time() - self.epoch_time) * 1000
+ step_size = self.data_size
+ cb_params = run_context.original_args()
+ if hasattr(cb_params, "batch_num"):
+ batch_num = cb_params.batch_num
+ if isinstance(batch_num, int) and batch_num > 0:
+ step_size = cb_params.batch_num
+
+ self.per_step_time = epoch_seconds / step_size
+ print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)
+
+ def get_step_time(self):
+ return self.per_step_time
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/config.py
new file mode 100644
index 0000000..033b9fc
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/config.py
@@ -0,0 +1,44 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config settings, used in train.py and eval.py
+"""
+from easydict import EasyDict as ed
+
+# config
+rectangle_sampling_config = ed({
+ 'domain': ed({
+ 'random_sampling': False,
+ 'size': [100, 100],
+ }),
+ 'BC': ed({
+ 'random_sampling': True,
+ 'size': 128,
+ 'with_normal': False,
+ })
+})
+
+# config
+helmholtz_2d_config = ed({
+ "name": "Helmholtz2D",
+ "columns_list": ["input", "label"],
+ "epochs": 10,
+ "batch_size": 128,
+ "lr": 0.001,
+ "coord_min": [0.0, 0.0],
+ "coord_max": [1.0, 1.0],
+ "axis_size": 101,
+ "wave_number": 2
+})
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/dataset.py
new file mode 100644
index 0000000..c841d80
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/dataset.py
@@ -0,0 +1,42 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+dataset
+"""
+import numpy as np
+
+# prepare test input and label
+def test_data_prepare(config):
+ """create test dataset"""
+ coord_min = config["coord_min"]
+ coord_max = config["coord_max"]
+ axis_size = config["axis_size"]
+ wave_number = config.get("wave_number", 2.0)
+
+ # input
+ axis_x = np.linspace(coord_min[0], coord_max[0], num=axis_size, endpoint=True)
+ axis_y = np.linspace(coord_min[1], coord_max[1], num=axis_size, endpoint=True)
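+    # Note: np.meshgrid(axis_y, axis_x) swaps the conventional argument order, so
+    # the first input column takes its values from axis_y; the domain is square
+    # with identical axes here, so the result is unaffected.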
+ mesh_x, mesh_y = np.meshgrid(axis_y, axis_x)
+ input_data = np.hstack((mesh_x.flatten()[:, None], mesh_y.flatten()[:, None])).astype(np.float32)
+
+ # label
+ label = np.zeros((axis_size, axis_size, 1))
+ for i in range(axis_size):
+ for j in range(axis_size):
+ label[i, j, 0] = np.sin(wave_number * axis_x[j])
+
+ label = label.reshape(-1, 1).astype(np.float32)
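+    # equivalent vectorized form (illustrative):
+    #   label = np.tile(np.sin(wave_number * axis_x), axis_size)[:, None].astype(np.float32)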
+
+ return input_data, label
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/model.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/model.py
new file mode 100644
index 0000000..791826d
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/src/model.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+feedforward neural network
+"""
+
+import mindspore.nn as nn
+from mindelec.architecture import get_activation, LinearBlock
+
+
+class FFNN(nn.Cell):
+ """
+    Fully connected network.
+
+    Args:
+        input_dim (int): dimension of the inputs.
+        output_dim (int): dimension of the outputs.
+        hidden_layer (int): number of neurons in each hidden layer.
+        activation (str or Cell): activation function.
+ """
+
+ def __init__(self, input_dim, output_dim, hidden_layer=64, activation="sin"):
+ super(FFNN, self).__init__()
+ self.activation = get_activation(activation)
+ self.fc1 = LinearBlock(input_dim, hidden_layer)
+ self.fc2 = LinearBlock(hidden_layer, hidden_layer)
+ self.fc3 = LinearBlock(hidden_layer, hidden_layer)
+ self.fc4 = LinearBlock(hidden_layer, hidden_layer)
+ self.fc5 = LinearBlock(hidden_layer, output_dim)
+
+ def construct(self, *inputs):
+ """fc network"""
+ x = inputs[0]
+ out = self.fc1(x)
+ out = self.activation(out)
+ out = self.fc2(out)
+ out = self.activation(out)
+ out = self.fc3(out)
+ out = self.activation(out)
+ out = self.fc4(out)
+ out = self.activation(out)
+ out = self.fc5(out)
+ return out
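+
+
+# Usage sketch (illustrative, not part of the test):
+#   import numpy as np
+#   from mindspore import Tensor
+#   net = FFNN(input_dim=2, output_dim=1)
+#   out = net(Tensor(np.random.rand(16, 2).astype(np.float32)))  # out.shape == (16, 1)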
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/test_frequency_domain_maxwell.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/test_frequency_domain_maxwell.py
new file mode 100644
index 0000000..65c3d50
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_frequency_domain_maxwell/test_frequency_domain_maxwell.py
@@ -0,0 +1,144 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+train
+"""
+import os
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import context, ms_function
+from mindspore.common import set_seed
+from mindspore.train.callback import LossMonitor
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+
+from mindelec.solver import Solver, Problem
+from mindelec.geometry import Rectangle, create_config_from_edict
+from mindelec.common import L2
+from mindelec.data import Dataset
+from mindelec.operators import SecondOrderGrad as Hessian
+from mindelec.loss import Constraints
+
+from src.config import rectangle_sampling_config, helmholtz_2d_config
+from src.model import FFNN
+from src.dataset import test_data_prepare
+from src.callback import PredictCallback, TimeMonitor
+
+set_seed(0)
+np.random.seed(0)
+
+print("pid:", os.getpid())
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
+
+
+# define problem
+class Helmholtz2D(Problem):
+ """2D Helmholtz equation"""
+ def __init__(self, domain_name, bc_name, net, wavenumber=2):
+ super(Helmholtz2D, self).__init__()
+ self.domain_name = domain_name
+ self.bc_name = bc_name
+ self.type = "Equation"
+ self.wave_number = wavenumber
+ self.grad_xx = Hessian(net, input_idx1=0, input_idx2=0, output_idx=0)
+ self.grad_yy = Hessian(net, input_idx1=1, input_idx2=1, output_idx=0)
+ self.reshape = ops.Reshape()
+
+ @ms_function
+ def governing_equation(self, *output, **kwargs):
+ """governing equation"""
+ u = output[0]
+
+ u_xx = self.grad_xx(kwargs[self.domain_name])
+ u_yy = self.grad_yy(kwargs[self.domain_name])
+
+ return u_xx + u_yy + self.wave_number**2 * u
+
+ @ms_function
+ def boundary_condition(self, *output, **kwargs):
+ """boundary condition"""
+ u = output[0]
+        x = kwargs[self.bc_name][:, 0]
+        x = self.reshape(x, (-1, 1))
+
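+        # the residual is up-weighted by 100 so that the Dirichlet condition
+        # u = sin(k * x) dominates the PDE residual in the loss (a common PINN
+        # weighting choice; the exact factor is empirical)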
+ test_label = ops.sin(self.wave_number * x)
+ return 100 * (u - test_label)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_frequency_domain_maxwell():
+ """train process"""
+ net = FFNN(input_dim=2, output_dim=1, hidden_layer=64)
+
+ # define geometry
+ geom_name = "rectangle"
+ rect_space = Rectangle(geom_name,
+ coord_min=helmholtz_2d_config["coord_min"],
+ coord_max=helmholtz_2d_config["coord_max"],
+ sampling_config=create_config_from_edict(rectangle_sampling_config))
+ geom_dict = {rect_space: ["domain", "BC"]}
+
+ # create dataset for train and test
+ train_dataset = Dataset(geom_dict)
+ train_data = train_dataset.create_dataset(batch_size=helmholtz_2d_config.get("batch_size", 128),
+ shuffle=True, drop_remainder=False)
+ test_input, test_label = test_data_prepare(helmholtz_2d_config)
+
+ # define problem and constraints
+ train_prob_dict = {geom_name: Helmholtz2D(domain_name=geom_name + "_domain_points",
+ bc_name=geom_name + "_BC_points",
+ net=net,
+                                              wavenumber=helmholtz_2d_config.get("wave_number", 2)),
+ }
+ train_constraints = Constraints(train_dataset, train_prob_dict)
+
+ # optimizer
+ optim = nn.Adam(net.trainable_params(), learning_rate=helmholtz_2d_config.get("lr", 1e-4))
+
+ # solver
+ solver = Solver(net,
+ optimizer=optim,
+ mode="PINNs",
+ train_constraints=train_constraints,
+ test_constraints=None,
+ amp_level="O2",
+ metrics={'l2': L2(), 'distance': nn.MAE()},
+ loss_scale_manager=DynamicLossScaleManager()
+ )
+
+ # train
+ time_cb = TimeMonitor()
+ loss_cb = PredictCallback(model=net, predict_interval=10, input_data=test_input, label=test_label)
+ solver.train(epoch=helmholtz_2d_config.get("epochs", 10),
+ train_dataset=train_data,
+ callbacks=[time_cb, LossMonitor(), loss_cb])
+ per_step_time = time_cb.get_step_time()
+ l2_error = loss_cb.get_l2_error()
+
+ print(f'l2 error: {l2_error:.10f}')
+ print(f'per step time: {per_step_time:.10f}')
+ assert l2_error <= 0.05
+ assert per_step_time <= 10.0
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/config.py
new file mode 100644
index 0000000..9f7d747
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/config.py
@@ -0,0 +1,31 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Network config settings, used in train.py and eval.py.
+"""
+from easydict import EasyDict as ed
+
+# training/eval configuration; t/x/y/z_solution match the fake data grid generated in sample.py
+config = ed({
+ "epochs": 500,
+ "batch_size": 8,
+ "lr": 0.0001,
+ "t_solution": 162,
+ "x_solution": 50,
+ "y_solution": 50,
+ "z_solution": 8,
+ "save_checkpoint_epochs": 5,
+ "keep_checkpoint_max": 20
+})
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/dataset.py
new file mode 100644
index 0000000..accc7a4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/dataset.py
@@ -0,0 +1,35 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+dataset
+"""
+import numpy as np
+from mindelec.data import Dataset, ExistedDataConfig
+
+
+def create_dataset(data_path, batch_size=8, shuffle=True, drop_remainder=True, is_train=True):
+ """create dataset"""
+ input_path = data_path + "inputs.npy"
+ label_path = data_path + "label.npy"
+ electromagnetic = ExistedDataConfig(name="electromagnetic",
+ data_dir=[input_path, label_path],
+ columns_list=["inputs", "label"],
+ data_format="npy")
+ dataset = Dataset(existed_data_list=[electromagnetic])
+ data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle, drop_remainder=drop_remainder)
+ scale = None
+ if not is_train:
+ scale = np.load(data_path+"scale.npy")
+ return data_loader, scale
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/loss.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/loss.py
new file mode 100644
index 0000000..fdc2358
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/loss.py
@@ -0,0 +1,80 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+loss
+"""
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+from src.config import config
+
+
+class MyMSELoss(nn.LossBase):
+    """MSE loss rescaled by a factor of 2 * batch_size."""
+ def construct(self, base, target):
+ bs, _, _, _, _ = F.shape(target)
+ x = F.square(base - target)
+ return 2*bs*self.get_loss(x)
+
+
+class EvaLMetric(nn.Metric):
+ """eval metric"""
+ def __init__(self, length, scale, batch_size):
+ super(EvaLMetric, self).__init__()
+ self.clear()
+ self.length = length
+ self.batch_size = batch_size
+ self.t = config.t_solution
+ self.x = config.x_solution
+ self.y = config.y_solution
+ self.z = config.z_solution
+ self.predict_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32)
+ self.label_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32)
+ self.scale = scale
+ self.iter_idx = 0
+
+ def clear(self):
+ """clear"""
+ self.iter_idx = 0
+
+ def update(self, *inputs):
+ """update"""
+ y_pred = self._convert_data(inputs[0])
+ y = self._convert_data(inputs[1])
+
+ predict, label = y_pred, y
+ self.predict_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = predict
+ self.label_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = label
+ self.iter_idx += 1
+
+ def eval(self):
+ """eval"""
+ predict_real = np.reshape(self.predict_real, (self.length, self.t, self.x, self.y, self.z, 6))
+ label_real = np.reshape(self.label_real, (self.length, self.t, self.x, self.y, self.z, 6))
+ l2_time = 0.0
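+        # undo the per-component, per-timestep normalization stored in `scale`,
+        # then accumulate the per-sample relative L2 error ||label - pred|| / ||label||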
+ for i in range(self.length):
+ predict_real_temp = predict_real[i:i+1]
+ label_real_temp = label_real[i:i+1]
+            for j in range(self.t):
+                for k in range(6):
+                    predict_real_temp[0, j, :, :, :, k] *= self.scale[k][j]
+ l2_time += (np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) /
+ np.sqrt(np.sum(np.square(label_real_temp))))
+ return {'l2_error': l2_time / self.length}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/maxwell_model.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/maxwell_model.py
new file mode 100644
index 0000000..371f4d9
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/maxwell_model.py
@@ -0,0 +1,176 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+maxwell 3d model
+"""
+
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+import mindspore.ops.functional as F
+from mindelec.architecture import get_activation
+
+
+class Maxwell3D(nn.Cell):
+    """Maxwell 3D surrogate network: multiscale input heads followed by a U-Net style body."""
+ def __init__(self, output_dim):
+ super(Maxwell3D, self).__init__()
+
+ self.output_dim = output_dim
+ width = 64
+ self.net0 = ModelHead(4, width)
+ self.net1 = ModelHead(4, width)
+ self.net2 = ModelHead(4, width)
+ self.net3 = ModelHead(4, width)
+ self.net4 = ModelHead(4, width)
+
+ self.fc0 = nn.Dense(width+33, 128)
+ self.net = ModelOut(128, output_dim, (2, 2, 1), (2, 2, 1))
+ self.cat = P.Concat(axis=-1)
+
+ def construct(self, x):
+ """forward"""
+ x_location = x[..., :4]
+ x_media = x[..., 4:]
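+        # the first 4 channels are location features, passed through five parallel
+        # heads at input scales 1x/2x/4x/8x/16x; the remaining 33 media channels
+        # are concatenated back in before fc0 (hence Dense(width + 33, 128))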
+ out1 = self.net0(x_location)
+ out2 = self.net1(2*x_location)
+ out3 = self.net2(4*x_location)
+ out4 = self.net3(8*x_location)
+ out5 = self.net4(16.0*x_location)
+ out = out1 + out2 + out3 + out4 + out5
+ out = self.cat((out, x_media))
+ out = self.fc0(out)
+ out = self.net(out)
+ return out
+
+
+class ModelHead(nn.Cell):
+ """model_head"""
+ def __init__(self, input_dim, output_dim):
+ super(ModelHead, self).__init__()
+ self.output_dim = output_dim
+ self.fc0 = nn.Dense(input_dim, output_dim)
+ self.fc1 = nn.Dense(output_dim, output_dim)
+ self.act0 = get_activation('srelu')
+ self.act1 = get_activation('srelu')
+
+ def construct(self, x):
+ """forward"""
+ x = self.fc0(x)
+ x = self.act0(x)
+ x = self.fc1(x)
+ x = self.act1(x)
+
+ return x
+
+
+class ModelOut(nn.Cell):
+ """model out"""
+ def __init__(self, input_dim, output_dim, kernel_size=2, strides=2):
+ super(ModelOut, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.base_channels = 64
+ self.inc = DoubleConv(self.input_dim, self.base_channels)
+ self.down1 = Down(self.base_channels, self.base_channels * 2, kernel_size, strides)
+ self.down2 = Down(self.base_channels * 2, self.base_channels * 4, kernel_size, strides)
+ self.down3 = Down(self.base_channels * 4, self.base_channels * 8, kernel_size, strides)
+ self.down4 = Down(self.base_channels * 8, self.base_channels * 16, kernel_size, strides)
+ self.up1 = Up(self.base_channels * 16, self.base_channels * 8, kernel_size, strides)
+ self.up2 = Up(self.base_channels * 8, self.base_channels * 4, kernel_size, strides)
+ self.up3 = Up(self.base_channels * 4, self.base_channels * 2, kernel_size, strides)
+ self.up4 = Up(self.base_channels * 2, self.base_channels, kernel_size, strides)
+
+ self.fc1 = nn.Dense(self.base_channels+128, 64)
+ self.fc2 = nn.Dense(64, output_dim)
+ self.relu = nn.ReLU()
+ self.transpose = P.Transpose()
+ self.cat = P.Concat(axis=1)
+
+ def construct(self, x):
+ """forward"""
+ x0 = self.transpose(x, (0, 4, 1, 2, 3))
+ x1 = self.inc(x0)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x = self.up1(x5, x4)
+ x = self.up2(x, x3)
+ x = self.up3(x, x2)
+ x = self.up4(x, x1)
+ x = self.cat((x, x0))
+ x = self.transpose(x, (0, 2, 3, 4, 1))
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+
+ return x
+
+
+class DoubleConv(nn.Cell):
+ """double conv"""
+ def __init__(self, input_dim, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.SequentialCell(
+ nn.Conv3d(input_dim, mid_channels, kernel_size=3),
+ nn.BatchNorm3d(mid_channels),
+ nn.ReLU(),
+ nn.Conv3d(mid_channels, out_channels, kernel_size=3),
+ nn.BatchNorm3d(out_channels),
+ nn.ReLU()
+ )
+
+ def construct(self, x):
+ """forward"""
+ return self.double_conv(x)
+
+
+class Down(nn.Cell):
+ """down"""
+ def __init__(self, input_dim, out_channels, kernel_size=2, strides=2):
+ super().__init__()
+ self.conv = DoubleConv(input_dim, out_channels)
+ self.maxpool = P.MaxPool3D(kernel_size=kernel_size, strides=strides)
+
+ def construct(self, x):
+ """forward"""
+ x = self.maxpool(x)
+ return self.conv(x)
+
+
+class Up(nn.Cell):
+ """up"""
+ def __init__(self, input_dim, out_channels, kernel_size=2, strides=2):
+ super().__init__()
+ self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=kernel_size, stride=strides)
+ self.conv = DoubleConv(input_dim, out_channels)
+ self.cat = P.Concat(axis=1)
+
+ def construct(self, x1, x2):
+ """forward"""
+ x1 = self.up(x1)
+
+ _, _, h1, w1, c1 = F.shape(x1)
+ _, _, h2, w2, c2 = F.shape(x2)
+ diff_z = c2 - c1
+ diff_y = w2 - w1
+ diff_x = h2 - h1
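+        # pad x1 so its spatial dims match x2 before the skip-connection concat,
+        # mirroring the crop/pad step of a classic U-Net decoder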
+
+ x1 = P.Pad(((0, 0), (0, 0), (diff_x // 2, diff_x - diff_x // 2), (diff_y // 2, diff_y - diff_y // 2),
+ (diff_z // 2, diff_z - diff_z // 2)))(x1)
+ x = self.cat((x2, x1))
+ return self.conv(x)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/sample.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/sample.py
new file mode 100644
index 0000000..998dc88
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/src/sample.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+sample fake data for train and test
+"""
+import os
+import numpy as np
+
+
+def generate_data(train_path, test_path):
+ """generate fake data"""
+ if not os.path.exists(train_path):
+ os.makedirs(train_path)
+ if not os.path.exists(test_path):
+ os.makedirs(test_path)
+
+ inputs = np.ones((162, 50, 50, 8, 37), dtype=np.float32)
+ label = np.ones((162, 50, 50, 8, 6), dtype=np.float32)
+ np.save(train_path + "inputs.npy", inputs)
+ np.save(train_path + "label.npy", label)
+
+ scale = np.ones((6, 162), dtype=np.float32)
+ np.save(test_path + "inputs.npy", inputs)
+ np.save(test_path + "label.npy", label)
+ np.save(test_path + "scale.npy", scale)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/test_full_em.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/test_full_em.py
new file mode 100644
index 0000000..53ced86
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_full_em/test_full_em.py
@@ -0,0 +1,152 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+train
+"""
+import os
+import time
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+import mindspore.common.initializer as weight_init
+from mindspore.common import set_seed
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train.callback import Callback, LossMonitor
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+from mindelec.solver import Solver
+from src.dataset import create_dataset
+from src.loss import MyMSELoss, EvaLMetric
+from src.maxwell_model import Maxwell3D
+from src.config import config
+from src.sample import generate_data
+
+set_seed(0)
+np.random.seed(0)
+
+train_data_path = "./train_data_em/"
+test_data_path = "./test_data_em/"
+
+print("pid:", os.getpid())
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class TimeMonitor(Callback):
+ """
+ Monitor the time in training.
+ """
+
+ def __init__(self, data_size=None):
+ super(TimeMonitor, self).__init__()
+ self.data_size = data_size
+ self.epoch_time = time.time()
+ self.per_step_time = 0
+
+ def epoch_begin(self, run_context):
+ """
+        Record the start time at the beginning of each epoch.
+ """
+ run_context.original_args()
+ self.epoch_time = time.time()
+
+ def epoch_end(self, run_context):
+ """
+        Print the epoch duration and per-step time at the end of each epoch.
+ """
+ epoch_seconds = (time.time() - self.epoch_time) * 1000
+ step_size = self.data_size
+ cb_params = run_context.original_args()
+ if hasattr(cb_params, "batch_num"):
+ batch_num = cb_params.batch_num
+ if isinstance(batch_num, int) and batch_num > 0:
+ step_size = cb_params.batch_num
+
+ self.per_step_time = epoch_seconds / step_size
+ print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)
+
+    def get_step_time(self):
+ return self.per_step_time
+
+
+def get_lr(lr_init, steps_per_epoch, total_epochs):
+    """Step-decay learning-rate schedule."""
+ lr_each_step = []
+ total_steps = steps_per_epoch * total_epochs
+ for i in range(total_steps):
+ epoch = i // steps_per_epoch
+ lr_local = lr_init
+ if epoch <= 15:
+ lr_local = lr_init
+ elif epoch <= 45:
+ lr_local = lr_init * 0.5
+ elif epoch <= 300:
+ lr_local = lr_init * 0.25
+ elif epoch <= 600:
+ lr_local = lr_init * 0.125
+ lr_each_step.append(lr_local)
+ learning_rate = np.array(lr_each_step).astype(np.float32)
+ print(learning_rate)
+ return learning_rate
+
+
+def init_weight(net):
+ """init weight"""
+ for _, cell in net.cells_and_names():
+ if isinstance(cell, (nn.Conv3d, nn.Dense)):
+ cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
+ cell.weight.shape,
+ cell.weight.dtype))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_full_em():
+ """train"""
+ generate_data(train_data_path, test_data_path)
+ train_dataset, _ = create_dataset(train_data_path, batch_size=config.batch_size, shuffle=True)
+ test_dataset, config_scale = create_dataset(test_data_path, batch_size=config.batch_size,
+ shuffle=False, drop_remainder=False, is_train=False)
+ model_net = Maxwell3D(6)
+ init_weight(net=model_net)
+ train_step_size = train_dataset.get_dataset_size()
+ lr = get_lr(config.lr, train_step_size, config.epochs)
+ optimizer = nn.Adam(model_net.trainable_params(), learning_rate=Tensor(lr))
+ loss_net = MyMSELoss()
+ loss_scale = DynamicLossScaleManager()
+
+    test_step_size = test_dataset.get_dataset_size()
+    test_batch_size = test_dataset.get_batch_size()
+    data_length = test_step_size * test_batch_size // config.t_solution
+ evl_error_mrc = EvaLMetric(data_length, config_scale, test_batch_size)
+ solver = Solver(model_net,
+ optimizer=optimizer,
+ loss_scale_manager=loss_scale,
+ amp_level="O2",
+ keep_batchnorm_fp32=False,
+ loss_fn=loss_net,
+ metrics={"evl_mrc": evl_error_mrc})
+
+ time_cb = TimeMonitor()
+ solver.model.train(5, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False)
+ res = solver.model.eval(test_dataset, dataset_sink_mode=False)
+ per_step_time = time_cb.get_step_time()
+ l2_s11 = res['evl_mrc']['l2_error']
+ print('test_res:', f'l2_error: {l2_s11:.10f} ')
+ print(f'per step time: {per_step_time:.10f} ')
+ assert l2_s11 <= 0.05
+ assert per_step_time <= 150
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/pretrain.json b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/pretrain.json
new file mode 100644
index 0000000..9104c66
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/pretrain.json
@@ -0,0 +1,44 @@
+{
+    "Description" : [ "PINNs for solving Maxwell's equations" ],
+
+ "Case" : "2D_Mur_Src_Gauss_Mscale_MTL_PIAD",
+ "coord_min" : [0, 0],
+ "coord_max" : [1, 1],
+ "src_pos" : [0.4975, 0.4975],
+ "SrcFrq": 1e+9,
+ "range_t" : 4e-9,
+ "input_center": [0.5, 0.5, 2.0e-9],
+ "input_scale": [2.0, 2.0, 5.0e+8],
+ "output_scale": [37.67303, 37.67303, 0.1],
+ "src_radius": 0.01,
+ "input_size" : 3,
+ "output_size" : 3,
+ "residual" : true,
+ "num_scales" : 4,
+ "layers" : 7,
+ "neurons" : 64,
+ "amp_factor" : 2,
+ "scale_factor" : 2,
+ "save_ckpt" : true,
+ "load_ckpt" : false,
+ "save_ckpt_path" : "./ckpt",
+ "load_ckpt_path" : "",
+ "train_with_eval": true,
+ "test_data_path" : "./benchmark/",
+ "lr" : 0.001,
+ "milestones" : [2000],
+ "lr_gamma" : 0.25,
+ "train_epoch" : 50,
+ "train_batch_size" : 1024,
+ "test_batch_size" : 8192,
+ "predict_interval" : 500,
+ "vision_path" : "./vision",
+ "summary_path" : "./summary",
+
+ "EPS_candidates": [1, 3, 5],
+ "MU_candidates": [1, 3, 5],
+ "num_scenarios": 9,
+ "latent_vector_size": 16,
+ "latent_reg": 1.0,
+ "latent_init_std": 1.0
+}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/__init__.py
new file mode 100644
index 0000000..638e227
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+init
+"""
+from .dataset import get_test_data, create_random_dataset
+from .maxwell import Maxwell2DMur
+from .lr_scheduler import MultiStepLR
+
+
+__all__ = [
+ "create_random_dataset",
+ "get_test_data",
+ "Maxwell2DMur",
+ "MultiStepLR",
+]
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/dataset.py
new file mode 100644
index 0000000..219a844
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/dataset.py
@@ -0,0 +1,57 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+create dataset
+"""
+import numpy as np
+from mindelec.data import Dataset
+from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
+from mindelec.geometry import create_config_from_edict
+
+from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config
+
+def get_test_data(test_data_path):
+    """load labeled data for evaluation"""
+ # check data
+ paths = [test_data_path + '/input.npy', test_data_path + '/output.npy']
+ input_data = np.load(paths[0])
+ label_data = np.load(paths[1])
+ return input_data, label_data
+
+def create_random_dataset(config):
+ """create training dataset by online sampling"""
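+    # three sampling regions, each extended in time: the source disk ("src"),
+    # the source-free remainder of the rectangle ("no_src" = rect - disk), and
+    # the rectangle boundary ("bc")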
+ radius = config["src_radius"]
+ origin = config["src_pos"]
+
+ disk = Disk("src", origin, radius)
+ rect = Rectangle("rect", config["coord_min"], config["coord_max"])
+ diff = rect - disk
+ interval = TimeDomain("time", 0.0, config["range_t"])
+ no_src = GeometryWithTime(diff, interval)
+ no_src.set_name("no_src")
+ no_src.set_sampling_config(create_config_from_edict(no_src_sampling_config))
+ src = GeometryWithTime(disk, interval)
+ src.set_name("src")
+ src.set_sampling_config(create_config_from_edict(src_sampling_config))
+ bc = GeometryWithTime(rect, interval)
+ bc.set_name("bc")
+ bc.set_sampling_config(create_config_from_edict(bc_sampling_config))
+
+ geom_dict = {src: ["domain", "IC"],
+ no_src: ["domain", "IC"],
+ bc: ["BC"]}
+
+ dataset = Dataset(geom_dict)
+ return dataset
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/lr_scheduler.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/lr_scheduler.py
new file mode 100644
index 0000000..7029722
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/lr_scheduler.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Learning rate scheduler."""
+from collections import Counter
+import numpy as np
+
+class _LRScheduler:
+ """
+ Basic class for learning rate scheduler
+ """
+
+ def __init__(self, lr, max_epoch, steps_per_epoch):
+ self.base_lr = lr
+ self.steps_per_epoch = steps_per_epoch
+ self.total_steps = int(max_epoch * steps_per_epoch)
+
+ def get_lr(self):
+ # Compute learning rate using chainable form of the scheduler
+ raise NotImplementedError
+
+
+class MultiStepLR(_LRScheduler):
+ """
+ Multi-step learning rate scheduler
+
+ Decays the learning rate by gamma once the number of epoch reaches one of the milestones.
+
+ Args:
+ lr (float): Initial learning rate which is the lower boundary in the cycle.
+ milestones (list): List of epoch indices. Must be increasing.
+ gamma (float): Multiplicative factor of learning rate decay.
+ steps_per_epoch (int): The number of steps per epoch to train for.
+ max_epoch (int): The number of epochs to train for.
+
+ Outputs:
+ numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
+
+ Example:
+        >>> # Assuming an initial lr of 0.05 with gamma = 0.1
+        >>> # lr = 0.05    if epoch < 30
+        >>> # lr = 0.005   if 30 <= epoch < 80
+        >>> # lr = 0.0005  if epoch >= 80
+        >>> scheduler = MultiStepLR(lr=0.05, milestones=[30, 80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
+ >>> lr = scheduler.get_lr()
+ """
+
+ def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)
+
+ def get_lr(self):
+ lr_each_step = []
+ current_lr = self.base_lr
+ for i in range(self.total_steps):
+ cur_ep = i // self.steps_per_epoch
+ if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
+ current_lr = current_lr * self.gamma
+ lr = current_lr
+ lr_each_step.append(lr)
+ return np.array(lr_each_step).astype(np.float32)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/maxwell.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/maxwell.py
new file mode 100644
index 0000000..acafe2f
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/maxwell.py
@@ -0,0 +1,184 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#pylint: disable=W0613
+"""
+2D maxwell problem with Mur bc
+"""
+import numpy as np
+
+import mindspore.numpy as ms_np
+from mindspore import ms_function
+from mindspore import ops
+from mindspore import Tensor
+import mindspore.common.dtype as ms_type
+
+from mindelec.solver import Problem
+from mindelec.common import MU, EPS, PI
+from mindelec.operators import Grad
+
+
+class Maxwell2DMur(Problem):
+ r"""
+ The 2D Maxwell's equations with 2nd-order Mur absorbed boundary condition.
+
+ Args:
+ network (Cell): The solving network.
+ config (dict): Setting information.
+        domain_column (str): Column name of the data governed by Maxwell's equations.
+        bc_column (str): Column name of the data governed by the boundary condition.
+        ic_column (str): Column name of the data governed by the initial condition.
+ """
+ def __init__(self, network, config, domain_column=None, bc_column=None, ic_column=None):
+ super(Maxwell2DMur, self).__init__()
+ self.domain_column = domain_column
+ self.bc_column = bc_column
+ self.ic_column = ic_column
+ self.network = network
+
+ # operations
+ self.gradient = Grad(self.network)
+ self.reshape = ops.Reshape()
+ self.cast = ops.Cast()
+ self.mul = ops.Mul()
+ self.split = ops.Split(1, 3)
+ self.concat = ops.Concat(1)
+ self.sqrt = ops.Sqrt()
+
+ # gauss-type pulse source
+ self.pi = Tensor(PI, ms_type.float32)
+        self.src_frq = config.get("SrcFrq", 1e+9)
+ self.tau = Tensor((2.3 ** 0.5) / (PI * self.src_frq), ms_type.float32)
+ self.amp = Tensor(1.0, ms_type.float32)
+ self.t0 = Tensor(3.65 * self.tau, ms_type.float32)
+
+ # src space
+ self.src_x0 = Tensor(config["src_pos"][0], ms_type.float32)
+ self.src_y0 = Tensor(config["src_pos"][1], ms_type.float32)
+ self.src_sigma = Tensor(config["src_radius"] / 4.0, ms_type.float32)
+ self.src_coord_min = config["coord_min"]
+ self.src_coord_max = config["coord_max"]
+
+        input_scales = config.get("input_scale", [1.0, 1.0, 2.5e+8])
+        output_scales = config.get("output_scale", [37.67303, 37.67303, 0.1])
+ self.s_x = Tensor(input_scales[0], ms_type.float32)
+ self.s_y = Tensor(input_scales[1], ms_type.float32)
+ self.s_t = Tensor(input_scales[2], ms_type.float32)
+ self.s_ex = Tensor(output_scales[0], ms_type.float32)
+ self.s_ey = Tensor(output_scales[1], ms_type.float32)
+ self.s_hz = Tensor(output_scales[2], ms_type.float32)
+
+ # set up eps, mu candidates
+ eps_candidates = np.array(config["eps_list"], dtype=np.float32) * EPS
+ mu_candidates = np.array(config["mu_list"], dtype=np.float32) * MU
+ self.epsilon_x = Tensor(eps_candidates, ms_type.float32).view((-1, 1))
+ self.epsilon_y = Tensor(eps_candidates, ms_type.float32).view((-1, 1))
+ self.mu_z = Tensor(mu_candidates, ms_type.float32).view((-1, 1))
+ self.light_speed = 1.0 / ops.Sqrt()(ops.Mul()(self.epsilon_x, self.mu_z))
+
+ def smooth_src(self, x, y, t):
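+        """Gauss-type pulse: amp * exp(-((t - t0) / tau)^2) in time, modulated by a
+        normalized 2D spatial Gaussian centered at (src_x0, src_y0) with std src_sigma."""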
+ source = self.amp * ops.exp(- ((t - self.t0) / self.tau)**2)
+ gauss = 1 / (2 * self.pi * self.src_sigma**2) * \
+ ops.exp(- ((x - self.src_x0)**2 + (y - self.src_y0)**2) / (2 * (self.src_sigma**2)))
+ return self.mul(source, gauss)
+
+ @ms_function
+ def governing_equation(self, *output, **kwargs):
+ """maxwell equation of TE mode wave"""
+ out = output[0]
+ data = kwargs[self.domain_column]
+ x = self.reshape(data[:, 0], (-1, 1))
+ y = self.reshape(data[:, 1], (-1, 1))
+ t = self.reshape(data[:, 2], (-1, 1))
+
+ dex_dxyt = self.gradient(data, None, 0, out)
+ _, dex_dy, dex_dt = self.split(dex_dxyt)
+ dey_dxyt = self.gradient(data, None, 1, out)
+ dey_dx, _, dey_dt = self.split(dey_dxyt)
+ dhz_dxyt = self.gradient(data, None, 2, out)
+ dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)
+
+ dex_dy = self.cast(dex_dy, ms_type.float32)
+ dex_dt = self.cast(dex_dt, ms_type.float32)
+ dey_dx = self.cast(dey_dx, ms_type.float32)
+ dey_dt = self.cast(dey_dt, ms_type.float32)
+ dhz_dx = self.cast(dhz_dx, ms_type.float32)
+ dhz_dy = self.cast(dhz_dy, ms_type.float32)
+ dhz_dt = self.cast(dhz_dt, ms_type.float32)
+
+ loss_a1 = (self.s_hz * dhz_dy) / (self.s_ex * self.s_t * self.epsilon_x)
+ loss_a2 = dex_dt / self.s_t
+
+ loss_b1 = -(self.s_hz * dhz_dx) / (self.s_ey * self.s_t * self.epsilon_y)
+ loss_b2 = dey_dt / self.s_t
+
+ loss_c1 = (self.s_ey * dey_dx - self.s_ex * dex_dy) / (self.s_hz * self.s_t * self.mu_z)
+ loss_c2 = - dhz_dt / self.s_t
+
+ source = self.smooth_src(x, y, t) / (self.s_hz * self.s_t * self.mu_z)
+
+ pde_res1 = loss_a1 - loss_a2
+ pde_res2 = loss_b1 - loss_b2
+ pde_res3 = loss_c1 - loss_c2 - source
+ pde_r = ops.Concat(1)((pde_res1, pde_res2, pde_res3))
+ return pde_r
+
+ @ms_function
+ def boundary_condition(self, *output, **kwargs):
+ """2nd-order mur boundary condition"""
+ u = output[0]
+ data = kwargs[self.bc_column]
+
+ coord_min = self.src_coord_min
+ coord_max = self.src_coord_max
+ batch_size, _ = data.shape
+ bc_attr = ms_np.zeros(shape=(batch_size, 4))
+ bc_attr[:, 0] = ms_np.where(ms_np.isclose(data[:, 0], coord_min[0]), 1.0, 0.0)
+ bc_attr[:, 1] = ms_np.where(ms_np.isclose(data[:, 0], coord_max[0]), 1.0, 0.0)
+ bc_attr[:, 2] = ms_np.where(ms_np.isclose(data[:, 1], coord_min[1]), 1.0, 0.0)
+ bc_attr[:, 3] = ms_np.where(ms_np.isclose(data[:, 1], coord_max[1]), 1.0, 0.0)
+
+ dex_dxyt = self.gradient(data, None, 0, u)
+ _, dex_dy, _ = self.split(dex_dxyt)
+ dey_dxyt = self.gradient(data, None, 1, u)
+ dey_dx, _, _ = self.split(dey_dxyt)
+ dhz_dxyt = self.gradient(data, None, 2, u)
+ dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)
+
+ dex_dy = self.cast(dex_dy, ms_type.float32)
+ dey_dx = self.cast(dey_dx, ms_type.float32)
+ dhz_dx = self.cast(dhz_dx, ms_type.float32)
+ dhz_dy = self.cast(dhz_dy, ms_type.float32)
+ dhz_dt = self.cast(dhz_dt, ms_type.float32)
+
+ bc_r1 = dhz_dx / self.s_x - dhz_dt / (self.light_speed * self.s_x) + \
+ self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=0
+ bc_r2 = dhz_dx / self.s_x + dhz_dt / (self.light_speed * self.s_x) - \
+ self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=L
+ bc_r3 = dhz_dy / self.s_y - dhz_dt / (self.light_speed * self.s_y) - \
+ self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=0
+ bc_r4 = dhz_dy / self.s_y + dhz_dt / (self.light_speed * self.s_y) + \
+ self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=L
+
+ bc_r_all = self.concat((bc_r1, bc_r2, bc_r3, bc_r4))
+ bc_r = self.mul(bc_r_all, bc_attr)
+ return bc_r
+
+ @ms_function
+ def initial_condition(self, *output, **kwargs):
+ """initial condition: u = 0"""
+ net_out = output[0]
+ return net_out
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/sampling_config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/sampling_config.py
new file mode 100644
index 0000000..6d17871
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/src/sampling_config.py
@@ -0,0 +1,52 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""sampling information"""
+import copy
+from easydict import EasyDict as edict
+
+
+src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 262144,
+ 'sampler': 'uniform'
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 262144,
+ 'sampler': 'uniform',
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 262144,
+ 'sampler': 'uniform',
+ }),
+})
+
+no_src_sampling_config = copy.deepcopy(src_sampling_config)
+
+bc_sampling_config = edict({
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 262144,
+ 'sampler': 'uniform',
+ 'with_normal': False
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 262144,
+ 'sampler': 'uniform',
+ }),
+})
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/test_incremental_learning.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/test_incremental_learning.py
new file mode 100644
index 0000000..d46ee8b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_incremental_learning/test_incremental_learning.py
@@ -0,0 +1,192 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""pretrain/reconstruct process for incremental learning."""
+import os
+import json
+import math
+import pytest
+import numpy as np
+
+from mindspore.common import set_seed
+from mindspore import context, Tensor, nn, Parameter
+from mindspore.train import DynamicLossScaleManager
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.common.dtype as ms_type
+from mindspore.common.initializer import HeUniform
+
+
+from mindelec.loss import Constraints
+from mindelec.solver import Solver, LossAndTimeMonitor
+from mindelec.common import L2
+from mindelec.architecture import MultiScaleFCCell, MTLWeightedLossCell
+
+from src.dataset import create_random_dataset
+from src.lr_scheduler import MultiStepLR
+from src.maxwell import Maxwell2DMur
+
+set_seed(123456)
+np.random.seed(123456)
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_incremental_learning():
+ """pretraining process"""
+ print("pid:", os.getpid())
+ mode = "pretrain"
+    with open("./pretrain.json") as f:
+        config = json.load(f)
+ preprocess_config(config)
+ elec_train_dataset = create_random_dataset(config)
+ train_dataset = elec_train_dataset.create_dataset(batch_size=config["batch_size"],
+ shuffle=True,
+ prebatched_data=True,
+ drop_remainder=True)
+ epoch_steps = len(elec_train_dataset)
+ print("check train dataset size: ", len(elec_train_dataset))
+
+ # load ckpt
+ if config.get("load_ckpt", False):
+ param_dict = load_checkpoint(config["load_ckpt_path"])
+ if mode == "pretrain":
+ loaded_ckpt_dict = param_dict
+ else:
+ loaded_ckpt_dict = {}
+ latent_vector_ckpt = 0
+ for name in param_dict:
+ if name == "model.latent_vector":
+ latent_vector_ckpt = param_dict[name].data.asnumpy()
+ elif "network" in name and "moment" not in name:
+ loaded_ckpt_dict[name] = param_dict[name]
+
+ # initialize latent vector
+ num_scenarios = config["num_scenarios"]
+ latent_size = config["latent_vector_size"]
+ if mode == "pretrain":
+ latent_init = np.random.randn(num_scenarios, latent_size) / np.sqrt(latent_size)
+ else:
+ latent_norm = np.mean(np.linalg.norm(latent_vector_ckpt, axis=1))
+ print("check mean latent vector norm: ", latent_norm)
+ latent_init = np.zeros((num_scenarios, latent_size))
+ latent_vector = Parameter(Tensor(latent_init, ms_type.float32), requires_grad=True)
+
+ network = MultiScaleFCCell(config["input_size"],
+ config["output_size"],
+ layers=config["layers"],
+ neurons=config["neurons"],
+ residual=config["residual"],
+ weight_init=HeUniform(negative_slope=math.sqrt(5)),
+ act="sin",
+ num_scales=config["num_scales"],
+ amp_factor=config["amp_factor"],
+ scale_factor=config["scale_factor"],
+ input_scale=config["input_scale"],
+ input_center=config["input_center"],
+ latent_vector=latent_vector
+ )
+
+ network = network.to_float(ms_type.float16)
+ network.input_scale.to_float(ms_type.float32)
+
+ if config.get("enable_mtl", True):
+ mtl_cell = MTLWeightedLossCell(num_losses=elec_train_dataset.num_dataset)
+ else:
+ mtl_cell = None
+
+ # define problem
+ train_prob = {}
+ for dataset in elec_train_dataset.all_datasets:
+ train_prob[dataset.name] = Maxwell2DMur(network=network, config=config,
+ domain_column=dataset.name + "_points",
+ ic_column=dataset.name + "_points",
+ bc_column=dataset.name + "_points")
+ print("check problem: ", train_prob)
+ train_constraints = Constraints(elec_train_dataset, train_prob)
+
+ # optimizer
+ if mode == "pretrain":
+ params = network.trainable_params() + mtl_cell.trainable_params()
+ if config.get("load_ckpt", False):
+ load_param_into_net(network, loaded_ckpt_dict)
+ load_param_into_net(mtl_cell, loaded_ckpt_dict)
+ else:
+ if config.get("finetune_model"):
+ model_params = network.trainable_params()
+ else:
+ model_params = [param for param in network.trainable_params()
+ if ("bias" not in param.name and "weight" not in param.name)]
+ params = model_params + mtl_cell.trainable_params() if mtl_cell else model_params
+ load_param_into_net(network, loaded_ckpt_dict)
+
+ lr_scheduler = MultiStepLR(config["lr"], config["milestones"], config["lr_gamma"],
+ epoch_steps, config["train_epoch"])
+ optimizer = nn.Adam(params, learning_rate=Tensor(lr_scheduler.get_lr()))
+
+ # problem solver
+ solver = Solver(network,
+ optimizer=optimizer,
+ mode="PINNs",
+ train_constraints=train_constraints,
+ test_constraints=None,
+ metrics={'l2': L2(), 'distance': nn.MAE()},
+ loss_fn='smooth_l1_loss',
+ loss_scale_manager=DynamicLossScaleManager(),
+ mtl_weighted_cell=mtl_cell,
+ latent_vector=latent_vector,
+ latent_reg=config["latent_reg"]
+ )
+
+ loss_time_callback = LossAndTimeMonitor(epoch_steps)
+ callbacks = [loss_time_callback]
+ if config["save_ckpt"]:
+ config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=2)
+ prefix = 'pretrain_maxwell_frq1e9' if mode == "pretrain" else 'reconstruct_maxwell_frq1e9'
+ ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=config["save_ckpt_path"], config=config_ck)
+ callbacks += [ckpoint_cb]
+
+ solver.train(config["train_epoch"], train_dataset, callbacks=callbacks, dataset_sink_mode=True)
+ assert loss_time_callback.get_loss() <= 2.0
+ assert loss_time_callback.get_step_time() <= 115.0
+
+
+def preprocess_config(config):
+ """preprocess to get the coefficients of electromagnetic field for each scenario"""
+ eps_candidates = config["EPS_candidates"]
+ mu_candidates = config["MU_candidates"]
+ config["num_scenarios"] = len(eps_candidates) * len(mu_candidates)
+ batch_size_single_scenario = config["train_batch_size"]
+ config["batch_size"] = batch_size_single_scenario * config["num_scenarios"]
+ eps_list = []
+ for eps in eps_candidates:
+ eps_list.extend([eps] * (batch_size_single_scenario * len(mu_candidates)))
+ mu_list = []
+ for mu in mu_candidates:
+ mu_list.extend([mu] * batch_size_single_scenario)
+ mu_list = mu_list * (len(eps_candidates))
+
+ exp_name = "_" + config["Case"] + '_num_scenarios_' + str(config["num_scenarios"]) \
+ + "_latent_reg_" + str(config["latent_reg"])
+ if config["save_ckpt"]:
+ config["save_ckpt_path"] += exp_name
+
+ config["vision_path"] += exp_name
+ config["summary_path"] += exp_name
+ print("check config: {}".format(config))
+ config["eps_list"] = eps_list
+ config["mu_list"] = mu_list
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_input.npy b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_input.npy
new file mode 100644
index 0000000..01b3618
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_input.npy differ
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_label.npy b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_label.npy
new file mode 100644
index 0000000..c2832f0
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Butterfly_antenna/data_label.npy differ
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_input.npy b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_input.npy
new file mode 100644
index 0000000..c294967
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_input.npy differ
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_label.npy b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_label.npy
new file mode 100644
index 0000000..5715764
Binary files /dev/null and b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/dataset/Phone/data_label.npy differ
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/dataset.py
new file mode 100644
index 0000000..00f41cd
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/dataset.py
@@ -0,0 +1,130 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+dataset
+"""
+import os
+import shutil
+import numpy as np
+from mindelec.data import Dataset, ExistedDataConfig
+
+
+def custom_normalize(data):
+ """
+    Normalize data to zero mean and unit variance per feature.
+ """
+ print("Custom normalization is called")
+ ori_shape = data.shape
+ data = data.reshape(ori_shape[0], -1)
+ data = np.transpose(data)
+ mean = np.mean(data, axis=1)
+ data = data - mean[:, None]
+ std = np.std(data, axis=1)
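+    # guard against division by zero: features with |std| < 1e-7 get their std bumped by 1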
+ std += (np.abs(std) < 0.0000001)
+ data = data / std[:, None]
+ data = np.transpose(data)
+ data = data.reshape(ori_shape)
+ return data
+
+
+def create_dataset(opt):
+ """
+ load data
+ """
+ data_input_path = opt.input_path
+ data_label_path = opt.label_path
+
+ data_input = np.load(data_input_path)
+ data_label = np.load(data_label_path)
+
+ frequency = data_label[0, :, 0]
+ data_label = data_label[:, :, 1]
+
+ print(data_input.shape)
+ print(data_label.shape)
+ print("data load finish")
+
+ data_input = custom_normalize(data_input)
+
+ config_data_prepare = {}
+
+ config_data_prepare["scale_input"] = 0.5 * np.max(np.abs(data_input), axis=0)
+ config_data_prepare["scale_S11"] = 0.5 * np.max(np.abs(data_label))
+
+ data_input[:, :] = data_input[:, :] / config_data_prepare["scale_input"]
+ data_label[:, :] = data_label[:, :] / config_data_prepare["scale_S11"]
+
+ permutation = np.random.permutation(data_input.shape[0])
+ data_input = data_input[permutation]
+ data_label = data_label[permutation]
+
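+    # hold out the first 10% of the shuffled samples for evaluation, train on the rest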
+ length = data_input.shape[0] // 10
+ train_input, train_label = data_input[length:], data_label[length:]
+ eval_input, eval_label = data_input[:length], data_label[:length]
+
+ print(np.shape(train_input))
+ print(np.shape(train_label))
+ print(np.shape(eval_input))
+ print(np.shape(eval_label))
+
+ if not os.path.exists('./data_prepare'):
+ os.mkdir('./data_prepare')
+ else:
+ shutil.rmtree('./data_prepare')
+ os.mkdir('./data_prepare')
+
+ train_input = train_input.astype(np.float32)
+ np.save('./data_prepare/train_input', train_input)
+ train_label = train_label.astype(np.float32)
+ np.save('./data_prepare/train_label', train_label)
+ eval_input = eval_input.astype(np.float32)
+ np.save('./data_prepare/eval_input', eval_input)
+ eval_label = eval_label.astype(np.float32)
+ np.save('./data_prepare/eval_label', eval_label)
+
+ electromagnetic_train = ExistedDataConfig(name="electromagnetic_train",
+ data_dir=['./data_prepare/train_input.npy',
+ './data_prepare/train_label.npy'],
+ columns_list=["inputs", "label"],
+ data_format="npy")
+ electromagnetic_eval = ExistedDataConfig(name="electromagnetic_eval",
+ data_dir=['./data_prepare/eval_input.npy',
+ './data_prepare/eval_label.npy'],
+ columns_list=["inputs", "label"],
+ data_format="npy")
+ train_batch_size = opt.batch_size
+ eval_batch_size = len(eval_input)
+
+ train_dataset = Dataset(existed_data_list=[electromagnetic_train])
+ train_loader = train_dataset.create_dataset(batch_size=train_batch_size, shuffle=True)
+
+ eval_dataset = Dataset(existed_data_list=[electromagnetic_eval])
+ eval_loader = eval_dataset.create_dataset(batch_size=eval_batch_size, shuffle=False)
+
+ data = {
+ "train_loader": train_loader,
+ "eval_loader": eval_loader,
+
+ "train_data": train_input,
+ "train_label": train_label,
+ "eval_data": eval_input,
+ "eval_label": eval_label,
+
+ "train_data_length": len(train_label),
+ "eval_data_length": len(eval_label),
+ "frequency": frequency,
+ }
+
+ return data, config_data_prepare
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/loss.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/loss.py
new file mode 100644
index 0000000..97b7bdf
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/loss.py
@@ -0,0 +1,98 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+loss
+"""
+
+import os
+import shutil
+import mindspore.nn as nn
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+
+
+class EvalMetric(nn.Metric):
+ """
+ eval metric
+ """
+
+ def __init__(self, scale_s11, length, frequency, show_pic_number, file_path):
+ super(EvalMetric, self).__init__()
+ self.clear()
+ self.scale_s11 = scale_s11
+ self.length = length
+ self.frequency = frequency
+ self.show_pic_number = show_pic_number
+ self.file_path = file_path
+ self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False)
+
+ def clear(self):
+ """
+ clear error
+ """
+ self.error_sum_l2_error = 0
+ self.error_sum_loss_error = 0
+ self.pic_res = None
+
+ def update(self, *inputs):
+ """
+ update error
+ """
+ if not os.path.exists(self.file_path):
+ os.mkdir(self.file_path)
+ else:
+ shutil.rmtree(self.file_path)
+ os.mkdir(self.file_path)
+
+ y_pred = self._convert_data(inputs[0])
+ y_label = self._convert_data(inputs[1])
+
+ test_predict, test_label = y_pred, y_label
+ test_predict[:, :] = test_predict[:, :] * self.scale_s11
+ test_label[:, :] = test_label[:, :] * self.scale_s11
+ self.pic_res = []
+
+ for i in range(len(test_label)):
+ predict_real_temp = test_predict[i]
+ label_real_temp = test_label[i]
+ l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \
+ np.sqrt(np.sum(np.square(label_real_temp)))
+ self.error_sum_l2_error += l2_error_temp
+ self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2)
+
+ s11_label, s11_predict = label_real_temp, predict_real_temp
+ plt.figure(dpi=250)
+ plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2)
+ plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1)
+ plt.title('s11(dB)')
+ plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10])
+ plt.ylabel('dB')
+ plt.legend()
+ plt.savefig(self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg')
+ plt.close()
+ if i in self.show_pic_id:
+ self.pic_res.append(cv2.imread(
+ self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg'))
+
+ self.pic_res = np.array(self.pic_res).astype(np.float32)
+
+ def eval(self):
+ """
+ compute final error
+ """
+ return {'l2_error': self.error_sum_l2_error / self.length,
+ 'loss_error': self.error_sum_loss_error / self.length,
+ 'pic_res': self.pic_res}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/maxwell_model.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/maxwell_model.py
new file mode 100644
index 0000000..5b5aecf
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/src/maxwell_model.py
@@ -0,0 +1,47 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+maxwell S11 model
+"""
+
+import mindspore.nn as nn
+
+
+class S11Predictor(nn.Cell):
+ """
+    Maxwell S11 prediction model definition.
+ """
+ def __init__(self, input_dimension):
+ super(S11Predictor, self).__init__()
+ self.fc1 = nn.Dense(input_dimension, 128)
+ self.fc2 = nn.Dense(128, 128)
+ self.fc3 = nn.Dense(128, 128)
+ self.fc4 = nn.Dense(128, 128)
+ self.fc5 = nn.Dense(128, 128)
+ self.fc6 = nn.Dense(128, 128)
+ self.fc7 = nn.Dense(128, 1001)
+ self.relu = nn.ReLU()
+
+ def construct(self, x):
+ """forward"""
+ x0 = x
+ x1 = self.relu(self.fc1(x0))
+ x2 = self.relu(self.fc2(x1))
+ x3 = self.relu(self.fc3(x1 + x2))
+ x4 = self.relu(self.fc4(x1 + x2 + x3))
+ x5 = self.relu(self.fc5(x1 + x2 + x3 + x4))
+ x6 = self.relu(self.fc6(x1 + x2 + x3 + x4 + x5))
+ x = self.fc7(x1 + x2 + x3 + x4 + x5 + x6)
+ return x
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/test_parameterization.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/test_parameterization.py
new file mode 100644
index 0000000..e0c61d9
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_parameterization/test_parameterization.py
@@ -0,0 +1,166 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+test parameterization
+"""
+
+import time
+
+import pytest
+import numpy as np
+import mindspore.nn as nn
+from mindspore.common import set_seed
+import mindspore.common.dtype as mstype
+from mindspore import context
+from mindspore.train.callback import Callback
+
+from easydict import EasyDict as edict
+from mindelec.solver import Solver
+from mindelec.vision import MonitorTrain, MonitorEval
+
+from src.dataset import create_dataset
+from src.maxwell_model import S11Predictor
+from src.loss import EvalMetric
+
+set_seed(123456)
+np.random.seed(123456)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+opt = edict({
+ 'epochs': 100,
+ 'print_interval': 50,
+ 'batch_size': 8,
+ 'save_epoch': 10,
+ 'lr': 0.0001,
+ 'input_dim': 3,
+ 'device_num': 1,
+ 'device_target': "Ascend",
+ 'checkpoint_dir': './ckpt/',
+ 'save_graphs_path': './graph_result/',
+ 'input_path': './dataset/Butterfly_antenna/data_input.npy',
+ 'label_path': './dataset/Butterfly_antenna/data_label.npy',
+})
+
+
+class TimeMonitor(Callback):
+ """
+ Monitor the time in training.
+ """
+
+ def __init__(self, data_size=None):
+ super(TimeMonitor, self).__init__()
+ self.data_size = data_size
+ self.epoch_time = time.time()
+ self.per_step_time = 0
+ self.t = 0
+
+ def epoch_begin(self, run_context):
+ """
+        Record the start time at the beginning of each epoch.
+ """
+ self.t = run_context
+ self.epoch_time = time.time()
+
+ def epoch_end(self, run_context):
+ """
+        Compute the average per-step time at the end of each epoch.
+ """
+ epoch_seconds = (time.time() - self.epoch_time) * 1000
+ step_size = self.data_size
+ cb_params = run_context.original_args()
+ if hasattr(cb_params, "batch_num"):
+ batch_num = cb_params.batch_num
+ if isinstance(batch_num, int) and batch_num > 0:
+ step_size = cb_params.batch_num
+
+ self.per_step_time = epoch_seconds / step_size
+
+    def get_step_time(self):
+ return self.per_step_time
+
+
+def get_lr(data):
+ """
+ get_lr
+ """
+ num_milestones = 10
+ if data['train_data_length'] % opt.batch_size == 0:
+ iter_number = int(data['train_data_length'] / opt.batch_size)
+ else:
+ iter_number = int(data['train_data_length'] / opt.batch_size) + 1
+ iter_number = opt.epochs * iter_number
+ milestones = [int(iter_number * i / num_milestones) for i in range(1, num_milestones)]
+ milestones.append(iter_number)
+ learning_rates = [opt.lr * 0.5 ** i for i in range(0, num_milestones - 1)]
+ learning_rates.append(opt.lr * 0.5 ** (num_milestones - 1))
+ return milestones, learning_rates
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_parameterization():
+ """
+ test parameterization
+ """
+ data, config_data = create_dataset(opt)
+
+ model_net = S11Predictor(opt.input_dim)
+ model_net.to_float(mstype.float16)
+
+ milestones, learning_rates = get_lr(data)
+
+ optim = nn.Adam(model_net.trainable_params(),
+ learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))
+
+ eval_error_mrc = EvalMetric(scale_s11=config_data["scale_S11"],
+ length=data["eval_data_length"],
+ frequency=data["frequency"],
+ show_pic_number=4,
+ file_path='./eval_res')
+
+ solver = Solver(network=model_net,
+ mode="Data",
+ optimizer=optim,
+ metrics={'eval_mrc': eval_error_mrc},
+ loss_fn=nn.MSELoss())
+
+ monitor_train = MonitorTrain(per_print_times=1,
+ summary_dir='./summary_dir_train')
+
+ monitor_eval = MonitorEval(summary_dir='./summary_dir_eval',
+ model=solver,
+ eval_ds=data["eval_loader"],
+ eval_interval=opt.print_interval,
+ draw_flag=True)
+
+ time_monitor = TimeMonitor()
+ callbacks_train = [monitor_train, time_monitor, monitor_eval]
+
+ solver.model.train(epoch=opt.epochs,
+ train_dataset=data["train_loader"],
+ callbacks=callbacks_train,
+ dataset_sink_mode=True)
+
+ loss_print, l2_s11_print = monitor_eval.loss_final, monitor_eval.l2_s11_final
+ per_step_time = time_monitor.get_step_time()
+
+ print('loss_mse:', loss_print)
+ print('l2_s11:', l2_s11_print)
+ print('per_step_time:', per_step_time)
+
+ assert l2_s11_print <= 1.0
+ assert per_step_time <= 10.0
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/config.py
new file mode 100644
index 0000000..691ec49
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/config.py
@@ -0,0 +1,26 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Network config settings, used in train.py and eval.py.
+"""
+
+# config
+config = {
+ 'input_channels': 496,
+ 'epochs': 20,
+ 'batch_size': 8,
+ 'lr': 0.0001,
+ 'lr_decay_milestones': 10,
+}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/dataset.py
new file mode 100644
index 0000000..9562362
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/dataset.py
@@ -0,0 +1,79 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataset utilities"""
+
+import os
+import numpy as np
+
+from mindelec.data import Dataset, ExistedDataConfig
+
+INPUT_PATH = ""
+LABEL_PATH = ""
+DATA_CONFIG_PATH = "./data_config.npz"
+SAVE_DATA_PATH = "./"
+
+def custom_normalize(dataset, mean=None, std=None):
+ """ custom normalization """
+ ori_shape = dataset.shape
+ dataset = dataset.reshape(ori_shape[0], -1)
+ dataset = np.transpose(dataset)
+ if mean is None:
+ mean = np.mean(dataset, axis=1)
+ dataset = dataset - mean[:, None]
+ if std is None:
+ std = np.std(dataset, axis=1)
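+    # bump near-zero std entries up to ~1 so the division below cannot blow up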
+ std += (np.abs(std) < 0.0000001)
+ dataset = dataset / std[:, None]
+ dataset = np.transpose(dataset)
+ dataset = dataset.reshape(ori_shape)
+ return dataset, mean, std
+
+def generate_data(input_path, label_path):
+ """generate dataset for s11 parameter prediction"""
+
+    data_input = np.load(input_path)
+    mean, std = None, None
+    if os.path.exists(DATA_CONFIG_PATH):
+        data_config = np.load(DATA_CONFIG_PATH)
+        mean = data_config["mean"]
+        std = data_config["std"]
+    data_input, mean, std = custom_normalize(data_input, mean, std)
+ data_label = np.load(label_path)
+
+ print(data_input.shape)
+ print(data_label.shape)
+
+ data_input = data_input.transpose((0, 4, 1, 2, 3))
+ data_label[:, :] = np.log10(-data_label[:, :] + 1.0)
+ scale_s11 = 0.5 * np.max(np.abs(data_label[:, :]))
+ data_label[:, :] = data_label[:, :] / scale_s11
+
+ np.savez(DATA_CONFIG_PATH, scale_s11=scale_s11, mean=mean, std=std)
+ np.save(os.path.join(SAVE_DATA_PATH, 'data_input.npy'), data_input)
+ np.save(os.path.join(SAVE_DATA_PATH, 'data_label.npy'), data_label)
+ print("data saved in target path")
+
+
+def create_dataset(input_path, label_path, batch_size=8, shuffle=True):
+ electromagnetic = ExistedDataConfig(name="electromagnetic",
+ data_dir=[input_path, label_path],
+ columns_list=["inputs", "label"],
+ data_format=input_path.split('.')[-1])
+ dataset = Dataset(existed_data_list=[electromagnetic])
+ data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle)
+
+ return data_loader
+
+if __name__ == "__main__":
+ generate_data(input_path=INPUT_PATH, label_path=LABEL_PATH)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/lr_generator.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/lr_generator.py
new file mode 100644
index 0000000..9a2920a
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/lr_generator.py
@@ -0,0 +1,29 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""learning rate generator"""
+
+def step_lr_generator(step_size, epochs, lr, lr_decay_milestones):
+ """generate step decayed learning rate"""
+
+ total_steps = epochs * step_size
+
+ milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)]
+ milestones.append(total_steps)
+ learning_rates = [lr*0.5**i for i in range(0, lr_decay_milestones - 1)]
+ learning_rates.append(lr*0.5**(lr_decay_milestones - 1))
+
+ print("total_steps: %s, milestones: %s, learning_rates: %s " %(total_steps, milestones, learning_rates))
+
+ return milestones, learning_rates
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/metric.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/metric.py
new file mode 100644
index 0000000..104a618
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/metric.py
@@ -0,0 +1,108 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""metrics"""
+
+import os
+import shutil
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+
+class MyMSELoss(nn.LossBase):
+ """mse loss function"""
+ def construct(self, base, target):
+ x = F.square(base - target)
+ return self.get_loss(x)
+
+
+class EvalMetric(nn.Metric):
+ """
+ eval metric
+ """
+
+ def __init__(self, scale_s11, length, frequency, show_pic_number, file_path):
+ super(EvalMetric, self).__init__()
+ self.clear()
+ self.scale_s11 = scale_s11
+ self.length = length
+ self.frequency = frequency
+ self.show_pic_number = show_pic_number
+ self.file_path = file_path
+ self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False)
+
+ if not os.path.exists(self.file_path):
+ os.mkdir(self.file_path)
+ else:
+ shutil.rmtree(self.file_path)
+ os.mkdir(self.file_path)
+
+ def clear(self):
+ """
+ clear error
+ """
+ self.error_sum_l2_error = 0
+ self.error_sum_loss_error = 0
+ self.pic_res = None
+ self.index = 0
+
+ def update(self, *inputs):
+ """
+ update error
+ """
+
+ y_pred = self._convert_data(inputs[0])
+ y_label = self._convert_data(inputs[1])
+
+ test_predict, test_label = y_pred, y_label
+ test_predict[:, :] = test_predict[:, :] * self.scale_s11
+ test_label[:, :] = test_label[:, :] * self.scale_s11
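+        # invert the log10(1 - s11) transform applied when the dataset was generated (see src/dataset.py)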
+ test_predict[:, :] = 1.0 - np.power(10, test_predict[:, :])
+ test_label[:, :] = 1.0 - np.power(10, test_label[:, :])
+ self.pic_res = []
+
+ for i in range(len(test_label)):
+ self.index += 1
+ predict_real_temp = test_predict[i]
+ label_real_temp = test_label[i]
+ l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \
+ np.sqrt(np.sum(np.square(label_real_temp)))
+ self.error_sum_l2_error += l2_error_temp
+ self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2)
+
+ s11_label, s11_predict = label_real_temp, predict_real_temp
+ plt.figure(dpi=250)
+ plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2)
+ plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1)
+ plt.title('s11(dB)')
+ plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10])
+ plt.ylabel('dB')
+ plt.legend()
+ plt.savefig(self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg')
+ plt.close()
+ if i in self.show_pic_id:
+ self.pic_res.append(cv2.imread(
+                    self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg'))
+
+ self.pic_res = np.array(self.pic_res).astype(np.float32)
+
+ def eval(self):
+ """
+ compute final error
+ """
+ return {'l2_error': self.error_sum_l2_error / self.length,
+ 'loss_error': self.error_sum_loss_error / self.length,
+ 'pic_res': self.pic_res}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/model.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/model.py
new file mode 100644
index 0000000..6e61654
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/src/model.py
@@ -0,0 +1,89 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""S parameter prediction model"""
+
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+
+class S11Predictor(nn.Cell):
+ """
+ S11Predictor architecture for MindElec.
+
+ Args:
+ input_dim (int): input channel.
+
+ Returns:
+ Tensor, output tensor.
+
+ Examples:
+ >>> S11Predictor(input_dim=128)
+ """
+
+ def __init__(self, input_dim):
+ super(S11Predictor, self).__init__()
+ # input shape is [20, 40, 3]
+ self.conv1 = nn.Conv3d(input_dim, 512, kernel_size=(3, 3, 1))
+ self.conv2 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1))
+ self.conv3 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1))
+ self.conv4 = nn.Conv3d(512, 512, kernel_size=(2, 1, 3), pad_mode='pad', padding=0)
+
+ self.down1 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))
+ self.down2 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))
+ self.down3 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1))
+
+ self.down_1_1 = ops.MaxPool3D(kernel_size=(1, 13, 1), strides=(1, 13, 1))
+ self.down_1_2 = nn.MaxPool2d(kernel_size=(10, 3))
+
+ self.down_2 = nn.MaxPool2d((5, 4*3))
+
+ self.fc1 = nn.Dense(1536, 2048)
+ self.fc2 = nn.Dense(2048, 2048)
+ self.fc3 = nn.Dense(2048, 1001)
+
+ self.concat = ops.Concat(axis=1)
+ self.relu = nn.ReLU()
+
+
+ def construct(self, x):
+ """forward"""
+ bs = x.shape[0]
+
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.down1(x)
+ x_1 = self.down_1_1(x)
+ x_1 = self.down_1_2(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1)).view((bs, -1))
+
+ x = self.conv2(x)
+ x = self.relu(x)
+ x = self.down2(x)
+ x_2 = self.down_2(x.view(bs, x.shape[1], x.shape[2], -1)).view((bs, -1))
+
+ x = self.conv3(x)
+ x = self.relu(x)
+ x = self.down3(x)
+
+ x = self.conv4(x)
+ x = self.relu(x).view((bs, -1))
+
+ x = self.concat([x, x_1, x_2])
+ x = self.relu(x).view(bs, -1)
+
+ x = self.relu(self.fc1(x))
+ x = self.relu(self.fc2(x))
+ x = self.fc3(x)
+
+ return x
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/test_s_parameter.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/test_s_parameter.py
new file mode 100644
index 0000000..b831f10
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_s_parameter/test_s_parameter.py
@@ -0,0 +1,144 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""s parameter prediction model training test"""
+import time
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.common import set_seed
+import mindspore.common.initializer as weight_init
+from mindspore.train.callback import LossMonitor, Callback
+
+from mindelec.solver import Solver
+
+from src.dataset import create_dataset
+from src.model import S11Predictor
+from src.lr_generator import step_lr_generator
+from src.config import config
+from src.metric import MyMSELoss, EvalMetric
+
+set_seed(0)
+np.random.seed(0)
+
+context.set_context(mode=context.GRAPH_MODE,
+ save_graphs=False,
+ device_target="Ascend")
+
+
+class TimeMonitor(Callback):
+ """
+ Monitor the time in training.
+ """
+
+ def __init__(self, data_size=None):
+ super(TimeMonitor, self).__init__()
+ self.data_size = data_size
+ self.epoch_time = time.time()
+ self.per_step_time = 0
+ self._tmp = None
+
+ def epoch_begin(self, run_context):
+ """
+        Record the start time at the beginning of each epoch.
+ """
+ self.epoch_time = time.time()
+ self._tmp = run_context
+
+ def epoch_end(self, run_context):
+ """
+        Print the epoch time and average per-step time at the end of each epoch.
+ """
+ epoch_seconds = (time.time() - self.epoch_time) * 1000
+ step_size = self.data_size
+ cb_params = run_context.original_args()
+ if hasattr(cb_params, "batch_num"):
+ batch_num = cb_params.batch_num
+ if isinstance(batch_num, int) and batch_num > 0:
+ step_size = cb_params.batch_num
+
+ self.per_step_time = epoch_seconds / step_size
+ print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True)
+
+    def get_step_time(self):
+ return self.per_step_time
+
+def init_weight(net):
+ """init_weight"""
+ for _, cell in net.cells_and_names():
+ if isinstance(cell, (nn.Conv3d, nn.Dense)):
+ cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
+ cell.weight.shape,
+ cell.weight.dtype))
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_s_predictor_train():
+ """training"""
+ train_input_path = "input.npy"
+ train_label_path = "label.npy"
+ train_input = np.ones((100, 496, 20, 40, 3), np.float32)
+ train_label = np.ones((100, 1001), np.float32)
+ np.save(train_input_path, train_input)
+ np.save(train_label_path, train_label)
+
+ model_net = S11Predictor(input_dim=config["input_channels"])
+ init_weight(net=model_net)
+
+ train_dataset = create_dataset(input_path=train_input_path,
+ label_path=train_label_path,
+ batch_size=config["batch_size"],
+ shuffle=True)
+
+ step_size = train_dataset.get_dataset_size()
+ milestones, learning_rates = step_lr_generator(step_size,
+ config["epochs"],
+ config["lr"],
+ config["lr_decay_milestones"])
+ optimizer = nn.Adam(model_net.trainable_params(),
+ learning_rate=nn.piecewise_constant_lr(milestones, learning_rates))
+
+ loss_net = MyMSELoss()
+
+ eval_step_size = train_dataset.get_dataset_size() * config["batch_size"]
+ evl_error_mrc = EvalMetric(scale_s11=1,
+ length=eval_step_size,
+ frequency=np.linspace(0, 4*10**8, 1001),
+ show_pic_number=4,
+ file_path='./eval_result')
+
+ solver = Solver(model_net,
+ train_input_map={'train': ['train_input_data']},
+ test_input_map={'test': ['test_input_data']},
+ optimizer=optimizer,
+ metrics={'evl_mrc': evl_error_mrc,},
+ amp_level="O2",
+ loss_fn=loss_net)
+
+ time_cb = TimeMonitor()
+ solver.model.train(config["epochs"],
+ train_dataset,
+ callbacks=[LossMonitor(), time_cb],
+ dataset_sink_mode=False)
+ res = solver.model.eval(train_dataset, dataset_sink_mode=False)
+ per_step_time = time_cb.get_step_time()
+    l2_error = res['evl_mrc']['l2_error']
+    print('test_res:', f'l2_error: {l2_error:.10f} ')
+    print(f'per step time: {per_step_time:.10f} ')
+    assert l2_error <= 0.01
+ assert per_step_time <= 100
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/config.json b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/config.json
new file mode 100644
index 0000000..f3b622e
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/config.json
@@ -0,0 +1,37 @@
+{
+ "Description" : [ "PINNs for solve Maxwell's equations" ],
+
+ "Case" : "2D_Mur_Src_Gauss_Mscale_MTL",
+ "random_sampling" : true,
+ "coord_min" : [0, 0],
+ "coord_max" : [1, 1],
+ "src_pos" : [0.4975, 0.4975],
+ "src_frq": 1e+9,
+ "range_t" : 4e-9,
+ "input_scale": [1.0, 1.0, 2.5e+8],
+ "output_scale": [37.67303, 37.67303, 0.1],
+ "src_radius": 0.01,
+ "input_size" : 3,
+ "output_size" : 3,
+ "residual" : true,
+ "num_scales" : 4,
+ "layers" : 7,
+ "neurons" : 64,
+ "amp_factor" : 10.0,
+ "scale_factor" : 2.0,
+ "save_ckpt" : true,
+ "load_ckpt" : false,
+ "save_ckpt_path" : "./ckpt",
+ "load_ckpt_path" : "",
+ "train_data_path" : "./input/",
+ "test_data_path" : "./input/benchmark/",
+ "lr" : 0.002,
+ "milestones" : [300],
+ "lr_gamma" : 0.1,
+ "train_epoch" : 300,
+ "train_batch_size" : 8192,
+ "test_batch_size" : 32768,
+ "predict_interval" : 6000,
+ "vision_path" : "./vision",
+ "summary_path" : "./summary"
+}
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/__init__.py
new file mode 100644
index 0000000..dcbda75
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+init
+"""
+from .dataset import create_train_dataset, get_test_data, create_random_dataset
+from .maxwell import Maxwell2DMur
+from .lr_scheduler import MultiStepLR
+from .callback import PredictCallback
+from .utils import visual_result
+
+__all__ = [
+ "create_train_dataset",
+ "create_random_dataset",
+ "get_test_data",
+ "Maxwell2DMur",
+ "MultiStepLR",
+ "PredictCallback",
+ "visual_result"
+]
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/callback.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/callback.py
new file mode 100644
index 0000000..7a4f401
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/callback.py
@@ -0,0 +1,127 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Callback functions.
+"""
+import time
+import copy
+
+import numpy as np
+from mindspore.train.callback import Callback
+from mindspore.train.summary import SummaryRecord
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+
+class PredictCallback(Callback):
+ """
+ Monitor the prediction accuracy in training.
+
+ Args:
+ model (Cell): Prediction network cell.
+ inputs (Array): Input data of prediction.
+ label (Array): Label data of prediction.
+ config (dict): config info of prediction.
+ visual_fn (dict): Visualization function. Default: None.
+ """
+ def __init__(self, model, inputs, label, config, visual_fn=None):
+ super(PredictCallback, self).__init__()
+ self.model = model
+ self.inputs = inputs
+ self.label = label
+ self.label_shape = label.shape
+ self.visual_fn = visual_fn
+ self.vision_path = config.get("vision_path", "./vision")
+ self.summary_dir = config.get("summary_path", "./summary")
+
+ self.output_size = config.get("output_size", 3)
+ self.input_size = config.get("input_size", 3)
+ self.output_scale = np.array(config["output_scale"], dtype=np.float32)
+ self.predict_interval = config.get("predict_interval", 10)
+ self.batch_size = config.get("test_batch_size", 8192*4)
+
+ self.dx = inputs[0, 1, 0, 0] - inputs[0, 0, 0, 0]
+ self.dy = inputs[0, 0, 1, 1] - inputs[0, 0, 0, 1]
+ self.dt = inputs[1, 0, 0, 2] - inputs[0, 0, 0, 2]
+ print("check yee delta: {}, {}, {}".format(self.dx, self.dy, self.dt))
+
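+        # stagger the evaluation grids Yee-style: Ex is sampled half a cell off in y and t, Ey in x and t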
+ self.ex_inputs = copy.deepcopy(inputs)
+ self.ey_inputs = copy.deepcopy(inputs)
+ self.hz_inputs = copy.deepcopy(inputs)
+ self.ex_inputs = self.ex_inputs.reshape(-1, self.input_size)
+ self.ey_inputs = self.ey_inputs.reshape(-1, self.input_size)
+ self.hz_inputs = self.hz_inputs.reshape(-1, self.input_size)
+ self.ex_inputs[:, 1] += self.dy / 2.0
+ self.ex_inputs[:, 2] += self.dt / 2.0
+ self.ey_inputs[:, 0] += self.dx / 2.0
+ self.ey_inputs[:, 2] += self.dt / 2.0
+ self.inputs_each = [self.ex_inputs, self.ey_inputs, self.hz_inputs]
+ self._step_counter = 0
+ self.l2_error = (1.0, 1.0, 1.0)
+
+ def __enter__(self):
+ self.summary_record = SummaryRecord(self.summary_dir)
+ return self
+
+ def __exit__(self, *exc_args):
+ self.summary_record.close()
+
+ def epoch_end(self, run_context):
+ """
+ Evaluate the model at the end of epoch.
+
+ Args:
+ run_context (RunContext): Context of the train running.
+ """
+ cb_params = run_context.original_args()
+ if cb_params.cur_epoch_num % self.predict_interval == 0:
+ # predict each quantity
+ index = 0
+ prediction_each = np.zeros(self.label_shape)
+ prediction_each = prediction_each.reshape((-1, self.output_size))
+ time_beg = time.time()
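+            # predict in batches; output column i is evaluated on the i-th staggered input grid and rescaled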
+ while index < len(self.inputs_each[0]):
+ index_end = min(index + self.batch_size, len(self.inputs_each[0]))
+ for i in range(self.output_size):
+ test_batch = Tensor(self.inputs_each[i][index: index_end, :], mstype.float32)
+ predict = self.model(test_batch)
+ predict = predict.asnumpy()
+ prediction_each[index: index_end, i] = predict[:, i] * self.output_scale[i]
+ index = index_end
+ print("==================================================================================================")
+ print("predict total time: {} s".format(time.time() - time_beg))
+ prediction = prediction_each.reshape(self.label_shape)
+ if self.visual_fn is not None:
+ self.visual_fn(self.inputs, self.label, prediction, path=self.vision_path,
+ name="epoch" + str(cb_params.cur_epoch_num))
+
+ label = self.label.reshape((-1, self.output_size))
+ prediction = prediction.reshape((-1, self.output_size))
+ self.l2_error = self._calculate_error(label, prediction)
+
+ def _calculate_error(self, label, prediction):
+ """calculate l2-error to evaluate accuracy"""
+ self._step_counter += 1
+ error = label - prediction
+ l2_error_ex = np.sqrt(np.sum(np.square(error[..., 0]))) / np.sqrt(np.sum(np.square(label[..., 0])))
+ l2_error_ey = np.sqrt(np.sum(np.square(error[..., 1]))) / np.sqrt(np.sum(np.square(label[..., 1])))
+ l2_error_hz = np.sqrt(np.sum(np.square(error[..., 2]))) / np.sqrt(np.sum(np.square(label[..., 2])))
+ print("l2_error, Ex: ", l2_error_ex, ", Ey: ", l2_error_ey, ", Hz: ", l2_error_hz)
+ self.summary_record.add_value('scalar', 'l2_ex', Tensor(l2_error_ex))
+ self.summary_record.add_value('scalar', 'l2_ey', Tensor(l2_error_ey))
+ self.summary_record.add_value('scalar', 'l2_hz', Tensor(l2_error_hz))
+ return l2_error_ex, l2_error_ey, l2_error_hz
+
+ def get_l2_error(self):
+ return self.l2_error
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/dataset.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/dataset.py
new file mode 100644
index 0000000..d807a4a
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/dataset.py
@@ -0,0 +1,95 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+create dataset
+"""
+import numpy as np
+
+from mindelec.data import Dataset, ExistedDataConfig
+from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
+from mindelec.geometry import create_config_from_edict
+
+from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config
+
+def get_test_data(test_data_path):
+ """load labeled data for evaluation"""
+ # check data
+ paths = [test_data_path + '/input.npy', test_data_path + '/output.npy']
+ inputs = np.load(paths[0])
+ label = np.load(paths[1])
+ return inputs, label
+
+def create_train_dataset(train_data_path):
+ """create training dataset from existed data files"""
+ src_ic = ExistedDataConfig(name="src_ic",
+ data_dir=[train_data_path + "elec_src_ic.npy"],
+ columns_list=["points"],
+ data_format="npy",
+ constraint_type="IC",
+ random_merge=False)
+ boundary = ExistedDataConfig(name="boundary",
+ data_dir=[train_data_path + "elec_no_src_bc.npy"],
+ columns_list=["points"],
+ data_format="npy",
+ constraint_type="BC",
+ random_merge=True)
+ no_src_ic = ExistedDataConfig(name="no_src_ic",
+ data_dir=[train_data_path + "elec_no_src_ic.npy"],
+ columns_list=["points"],
+ data_format="npy",
+ constraint_type="IC",
+ random_merge=True)
+ src_domain = ExistedDataConfig(name="src_domain",
+ data_dir=[train_data_path + "elec_src_domain.npy"],
+ columns_list=["points"],
+ data_format="npy",
+ constraint_type="Equation",
+ random_merge=True)
+ no_src_domain = ExistedDataConfig(name="no_src_domain",
+ data_dir=[train_data_path + "elec_no_src_domain.npy"],
+ columns_list=["points"],
+ data_format="npy",
+ constraint_type="Equation",
+ random_merge=True)
+ dataset = Dataset(existed_data_list=[no_src_domain, no_src_ic, boundary, src_domain, src_ic])
+ return dataset
+
+def create_random_dataset(config):
+ """create training dataset by online sampling"""
+ disk_radius = config["src_radius"]
+ disk_origin = config["src_pos"]
+ coord_min = config["coord_min"]
+ coord_max = config["coord_max"]
+
+ disk = Disk("src", disk_origin, disk_radius)
+ rectangle = Rectangle("rect", coord_min, coord_max)
+ diff = rectangle - disk
+ time_interval = TimeDomain("time", 0.0, config["range_t"])
+ no_src_region = GeometryWithTime(diff, time_interval)
+ no_src_region.set_name("no_src")
+ no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
+ src_region = GeometryWithTime(disk, time_interval)
+ src_region.set_name("src")
+ src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
+ boundary = GeometryWithTime(rectangle, time_interval)
+ boundary.set_name("bc")
+ boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))
+
+ geom_dict = {src_region: ["domain", "IC"],
+ no_src_region: ["domain", "IC"],
+ boundary: ["BC"]}
+
+ dataset = Dataset(geom_dict)
+ return dataset
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/lr_scheduler.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/lr_scheduler.py
new file mode 100644
index 0000000..9813001
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/lr_scheduler.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Learning rate scheduler."""
+from collections import Counter
+import numpy as np
+
+class _LRScheduler():
+ """
+    Base class for learning rate schedulers
+ """
+
+ def __init__(self, lr, max_epoch, steps_per_epoch):
+ self.base_lr = lr
+ self.steps_per_epoch = steps_per_epoch
+ self.total_steps = int(max_epoch * steps_per_epoch)
+
+ def get_lr(self):
+ # Compute learning rate using chainable form of the scheduler
+ raise NotImplementedError
+
+
+class MultiStepLR(_LRScheduler):
+ """
+ Multi-step learning rate scheduler
+
+ Decays the learning rate by gamma once the number of epoch reaches one of the milestones.
+
+ Args:
+ lr (float): Initial learning rate which is the lower boundary in the cycle.
+ milestones (list): List of epoch indices. Must be increasing.
+ gamma (float): Multiplicative factor of learning rate decay.
+ steps_per_epoch (int): The number of steps per epoch to train for.
+ max_epoch (int): The number of epochs to train for.
+
+ Outputs:
+ numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
+
+ Example:
+ >>> # Assuming optimizer uses lr = 0.05 for all groups
+ >>> # lr = 0.05 if epoch < 30
+ >>> # lr = 0.005 if 30 <= epoch < 80
+ >>> # lr = 0.0005 if epoch >= 80
+ >>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
+ >>> lr = scheduler.get_lr()
+ """
+
+ def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)
+
+ def get_lr(self):
+ lr_each_step = []
+ current_lr = self.base_lr
+ for i in range(self.total_steps):
+ cur_ep = i // self.steps_per_epoch
+ if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
+ current_lr = current_lr * self.gamma
+ lr = current_lr
+ lr_each_step.append(lr)
+ return np.array(lr_each_step).astype(np.float32)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/maxwell.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/maxwell.py
new file mode 100644
index 0000000..639bfbe
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/maxwell.py
@@ -0,0 +1,178 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#pylint: disable=W0613
+"""
+2D maxwell problem with Mur bc
+"""
+import mindspore.numpy as ms_np
+from mindspore import ms_function
+from mindspore import ops
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+
+from mindelec.solver import Problem
+from mindelec.common import MU, EPS, LIGHT_SPEED, PI
+from mindelec.operators import Grad
+
+
+class Maxwell2DMur(Problem):
+ r"""
+ The 2D Maxwell's equations with 2nd-order Mur absorbed boundary condition.
+
+ Args:
+ model (Cell): The solving network.
+ config (dict): Setting information.
+        domain_name (str): Column name of the data governed by Maxwell's equations.
+        bc_name (str): Column name of the data governed by the boundary condition.
+        bc_normal (str): Column name of the normal direction vector for the specified boundary.
+        ic_name (str): Column name of the data governed by the initial condition.
+ """
+ def __init__(self, model, config, domain_name=None, bc_name=None, bc_normal=None, ic_name=None):
+ super(Maxwell2DMur, self).__init__()
+ self.domain_name = domain_name
+ self.bc_name = bc_name
+ self.bc_normal = bc_normal
+ self.ic_name = ic_name
+ self.model = model
+ self.grad = Grad(self.model)
+ self.reshape = ops.Reshape()
+ self.cast = ops.Cast()
+ self.mul = ops.Mul()
+ self.split = ops.Split(1, 3)
+ self.concat = ops.Concat(1)
+
+ # constants
+ self.pi = Tensor(PI, mstype.float32)
+ self.eps_x = Tensor(EPS, mstype.float32)
+ self.eps_y = Tensor(EPS, mstype.float32)
+ self.mu_z = Tensor(MU, mstype.float32)
+ self.light_speed = Tensor(LIGHT_SPEED, mstype.float32)
+
+ # gauss-type pulse source
+ self.src_frq = config.get("src_frq", 1e+9)
+ self.tau = Tensor((2.3 ** 0.5) / (PI * self.src_frq), mstype.float32)
+ self.amp = Tensor(1.0, mstype.float32)
+ self.t0 = Tensor(3.65 * self.tau, mstype.float32)
+
+ # src space
+ self.x0 = Tensor(config["src_pos"][0], mstype.float32)
+ self.y0 = Tensor(config["src_pos"][1], mstype.float32)
+ self.sigma = Tensor(config["src_radius"] / 4.0, mstype.float32)
+ self.coord_min = config["coord_min"]
+ self.coord_max = config["coord_max"]
+
+ input_scale = config.get("input_scale", [1.0, 1.0, 2.5e+8])
+ output_scale = config.get("output_scale", [37.67303, 37.67303, 0.1])
+ self.s_x = Tensor(input_scale[0], mstype.float32)
+ self.s_y = Tensor(input_scale[1], mstype.float32)
+ self.s_t = Tensor(input_scale[2], mstype.float32)
+ self.s_ex = Tensor(output_scale[0], mstype.float32)
+ self.s_ey = Tensor(output_scale[1], mstype.float32)
+ self.s_hz = Tensor(output_scale[2], mstype.float32)
+
+ def smooth_src(self, x, y, t):
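+        """Gauss-type source: a Gaussian pulse in time centered at t0, localized in space around (x0, y0)."""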
+ source = self.amp * ops.exp(- ((t - self.t0) / self.tau)**2)
+ gauss = 1 / (2 * self.pi * self.sigma**2) * \
+ ops.exp(- ((x - self.x0)**2 + (y - self.y0)**2) / (2 * (self.sigma**2)))
+ return self.mul(source, gauss)
+
+ @ms_function
+ def governing_equation(self, *output, **kwargs):
+ """maxwell equation of TE mode wave"""
+ u = output[0]
+ data = kwargs[self.domain_name]
+ x = self.reshape(data[:, 0], (-1, 1))
+ y = self.reshape(data[:, 1], (-1, 1))
+ t = self.reshape(data[:, 2], (-1, 1))
+
+ dex_dxyt = self.grad(data, None, 0, u)
+ _, dex_dy, dex_dt = self.split(dex_dxyt)
+ dey_dxyt = self.grad(data, None, 1, u)
+ dey_dx, _, dey_dt = self.split(dey_dxyt)
+ dhz_dxyt = self.grad(data, None, 2, u)
+ dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)
+
+ dex_dy = self.cast(dex_dy, mstype.float32)
+ dex_dt = self.cast(dex_dt, mstype.float32)
+ dey_dx = self.cast(dey_dx, mstype.float32)
+ dey_dt = self.cast(dey_dt, mstype.float32)
+ dhz_dx = self.cast(dhz_dx, mstype.float32)
+ dhz_dy = self.cast(dhz_dy, mstype.float32)
+ dhz_dt = self.cast(dhz_dt, mstype.float32)
+
+ loss_a1 = (self.s_hz * dhz_dy) / (self.s_ex * self.s_t * self.eps_x)
+ loss_a2 = dex_dt / self.s_t
+
+ loss_b1 = -(self.s_hz * dhz_dx) / (self.s_ey * self.s_t * self.eps_y)
+ loss_b2 = dey_dt / self.s_t
+
+ loss_c1 = (self.s_ey * dey_dx - self.s_ex * dex_dy) / (self.s_hz * self.s_t * self.mu_z)
+ loss_c2 = - dhz_dt / self.s_t
+
+ src = self.smooth_src(x, y, t) / (self.s_hz * self.s_t * self.mu_z)
+
+ pde_r1 = loss_a1 - loss_a2
+ pde_r2 = loss_b1 - loss_b2
+ pde_r3 = loss_c1 - loss_c2 - src
+ pde_r = ops.Concat(1)((pde_r1, pde_r2, pde_r3))
+ return pde_r
+
+ @ms_function
+ def boundary_condition(self, *output, **kwargs):
+ """2nd-order mur boundary condition"""
+ u = output[0]
+ data = kwargs[self.bc_name]
+
+ coord_min = self.coord_min
+ coord_max = self.coord_max
+ batch_size, _ = data.shape
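+        # one-hot edge masks: columns flag points on x=min, x=max, y=min, y=max so only the matching Mur residual survives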
+ attr = ms_np.zeros(shape=(batch_size, 4))
+ attr[:, 0] = ms_np.where(ms_np.isclose(data[:, 0], coord_min[0]), 1.0, 0.0)
+ attr[:, 1] = ms_np.where(ms_np.isclose(data[:, 0], coord_max[0]), 1.0, 0.0)
+ attr[:, 2] = ms_np.where(ms_np.isclose(data[:, 1], coord_min[1]), 1.0, 0.0)
+ attr[:, 3] = ms_np.where(ms_np.isclose(data[:, 1], coord_max[1]), 1.0, 0.0)
+
+ dex_dxyt = self.grad(data, None, 0, u)
+ _, dex_dy, _ = self.split(dex_dxyt)
+ dey_dxyt = self.grad(data, None, 1, u)
+ dey_dx, _, _ = self.split(dey_dxyt)
+ dhz_dxyt = self.grad(data, None, 2, u)
+ dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt)
+
+ dex_dy = self.cast(dex_dy, mstype.float32)
+ dey_dx = self.cast(dey_dx, mstype.float32)
+ dhz_dx = self.cast(dhz_dx, mstype.float32)
+ dhz_dy = self.cast(dhz_dy, mstype.float32)
+ dhz_dt = self.cast(dhz_dt, mstype.float32)
+
+ bc_r1 = dhz_dx / self.s_x - dhz_dt / (self.light_speed * self.s_x) + \
+ self.s_ex * self.light_speed * self.eps_x / (2 * self.s_hz * self.s_x) * dex_dy # x=0
+ bc_r2 = dhz_dx / self.s_x + dhz_dt / (self.light_speed * self.s_x) - \
+ self.s_ex * self.light_speed * self.eps_x / (2 * self.s_hz * self.s_x) * dex_dy # x=L
+ bc_r3 = dhz_dy / self.s_y - dhz_dt / (self.light_speed * self.s_y) - \
+ self.s_ey * self.light_speed * self.eps_y / (2 * self.s_hz * self.s_y) * dey_dx # y=0
+ bc_r4 = dhz_dy / self.s_y + dhz_dt / (self.light_speed * self.s_y) + \
+ self.s_ey * self.light_speed * self.eps_y / (2 * self.s_hz * self.s_y) * dey_dx # y=L
+
+ bc_r_all = self.concat((bc_r1, bc_r2, bc_r3, bc_r4))
+ bc_r = self.mul(bc_r_all, attr)
+ return bc_r
+
+ @ms_function
+ def initial_condition(self, *output, **kwargs):
+ """initial condition: u = 0"""
+ u = output[0]
+ return u
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/sampling_config.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/sampling_config.py
new file mode 100644
index 0000000..3e2ec10
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/sampling_config.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""sampling information"""
+from easydict import EasyDict as edict
+
+
+src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform'
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ }),
+})
+
+no_src_sampling_config = edict({
+ 'domain': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform'
+ }),
+ 'IC': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ }),
+})
+
+bc_sampling_config = edict({
+ 'BC': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ 'with_normal': False
+ }),
+ 'time': edict({
+ 'random_sampling': True,
+ 'size': 65536,
+ 'sampler': 'uniform',
+ }),
+})
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/utils.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/utils.py
new file mode 100644
index 0000000..0be4d14
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/src/utils.py
@@ -0,0 +1,173 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""visualization of field quantities"""
+import os
+
+import copy
+import io
+import cv2
+import PIL
+import numpy as np
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
+
+
+plt.rcParams['figure.dpi'] = 300
+
+def visual_result(input_data, label, predict, path, name):
+ """visulization of original field and normalized field"""
+ input_data = copy.deepcopy(input_data)
+ label = copy.deepcopy(label)
+ predict = copy.deepcopy(predict)
+ visual(input_data, label, predict, path, name)
+
+ # normalize
+ ex_min = label[:, :, :, 0].min()
+ ex_max = label[:, :, :, 0].max()
+ ey_min = label[:, :, :, 1].min()
+ ey_max = label[:, :, :, 1].max()
+ hz_min = label[:, :, :, 2].min()
+ hz_max = label[:, :, :, 2].max()
+ if ex_min == ex_max:
+ ex_min = ex_min - 1
+ ex_max = ex_max + 1
+ if ey_min == ey_max:
+ ey_min = ey_min - 1
+ ey_max = ey_max + 1
+ if hz_min == hz_max:
+ hz_min = hz_min - 1
+ hz_max = hz_max + 1
+
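+    # rescale each field linearly onto [-1, 1]: 2 * (value - midpoint) / range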
+ label[:, :, :, 0] = 2 * (label[:, :, :, 0] - np.mean([ex_max, ex_min])) / (ex_max - ex_min)
+ label[:, :, :, 1] = 2 * (label[:, :, :, 1] - np.mean([ey_max, ey_min])) / (ey_max - ey_min)
+ label[:, :, :, 2] = 2 * (label[:, :, :, 2] - np.mean([hz_max, hz_min])) / (hz_max - hz_min)
+
+ predict[:, :, :, 0] = 2 * (predict[:, :, :, 0] - np.mean([ex_max, ex_min])) / (ex_max - ex_min)
+ predict[:, :, :, 1] = 2 * (predict[:, :, :, 1] - np.mean([ey_max, ey_min])) / (ey_max - ey_min)
+ predict[:, :, :, 2] = 2 * (predict[:, :, :, 2] - np.mean([hz_max, hz_min])) / (hz_max - hz_min)
+    visual(input_data, label, predict, path, str(name) + "_normalized")
+
+def visual(input_data, label, predict, path, name):
+ """visulization of ex/ey/hz"""
+ [sample_t, sample_x, sample_y, _] = np.shape(input_data)
+
+    # robust color ranges for label/predict from the 0.5/99.5 percentiles
+ ex_vmin, ex_vmax = np.percentile(label[:, :, :, 0], [0.5, 99.5])
+ ey_vmin, ey_vmax = np.percentile(label[:, :, :, 1], [0.5, 99.5])
+ hz_vmin, hz_vmax = np.percentile(label[:, :, :, 2], [0.5, 99.5])
+
+ vmin_list = [ex_vmin, ey_vmin, hz_vmin]
+ vmax_list = [ex_vmax, ey_vmax, hz_vmax]
+
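+    # error-map denominators; fixed at 1.0 so the plotted errors are absolute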
+ mean_abs_ex_label = 1.0
+ mean_abs_ey_label = 1.0
+ mean_abs_hz_label = 1.0
+
+ output_names = ["Ex", "Ey", "Hz"]
+
+ if not os.path.isdir(path):
+ os.makedirs(path)
+
+ fourcc = cv2.VideoWriter_fourcc('D', 'I', 'V', 'X')
+ fps = 10
+ size = (1920, 1440)
+ video = cv2.VideoWriter(os.path.join(path, "EH_" + str(name) + ".avi"), fourcc, fps, size)
+
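+    # render every frame for short sequences; subsample long ones to roughly 21 frames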
+ t_set = []
+ if sample_t < 100:
+ t_set = np.arange(sample_t, dtype=np.int32)
+ else:
+ for t in range(sample_t):
+ if t % int(sample_t / 20) == 0 or t == sample_t - 1:
+ t_set.append(t)
+
+ for t in t_set:
+ ex_label = label[t, :, :, 0]
+ ey_label = label[t, :, :, 1]
+ hz_label = label[t, :, :, 2]
+
+ ex_predict = predict[t, :, :, 0]
+ ey_predict = predict[t, :, :, 1]
+ hz_predict = predict[t, :, :, 2]
+
+ ex_label_2d = np.reshape(np.array(ex_label), (sample_x, sample_y))
+ ey_label_2d = np.reshape(np.array(ey_label), (sample_x, sample_y))
+ hz_label_2d = np.reshape(np.array(hz_label), (sample_x, sample_y))
+
+ ex_predict_2d = np.reshape(np.array(ex_predict), (sample_x, sample_y))
+ ey_predict_2d = np.reshape(np.array(ey_predict), (sample_x, sample_y))
+ hz_predict_2d = np.reshape(np.array(hz_predict), (sample_x, sample_y))
+
+ ex_error_2d = np.abs(ex_predict_2d - ex_label_2d) / mean_abs_ex_label
+ ey_error_2d = np.abs(ey_predict_2d - ey_label_2d) / mean_abs_ey_label
+ hz_error_2d = np.abs(hz_predict_2d - hz_label_2d) / mean_abs_hz_label
+
+ label_2d = [ex_label_2d, ey_label_2d, hz_label_2d]
+ predict_2d = [ex_predict_2d, ey_predict_2d, hz_predict_2d]
+ error_2d = [ex_error_2d, ey_error_2d, hz_error_2d]
+
+ lpe_2d = [label_2d, predict_2d, error_2d]
+ lpe_names = ["label", "predict", "error"]
+
+ fig = plt.figure()
+
+ gs = gridspec.GridSpec(3, 3)
+
+ title = "t={:d}".format(t)
+ plt.suptitle(title, fontsize=14)
+
+    gs_idx = 0
+
+ for i, data_2d in enumerate(lpe_2d):
+ for j, data in enumerate(data_2d):
+ ax = fig.add_subplot(gs[gs_idx])
+ gs_idx += 1
+
+ if lpe_names[i] == "error":
+ img = ax.imshow(data.T, vmin=0, vmax=1,
+ cmap=plt.get_cmap("jet"), origin='lower')
+ else:
+ img = ax.imshow(data.T, vmin=vmin_list[j], vmax=vmax_list[j],
+ cmap=plt.get_cmap("jet"), origin='lower')
+
+ ax.set_title(output_names[j] + " " + lpe_names[i], fontsize=4)
+ plt.xticks(size=4)
+ plt.yticks(size=4)
+
+ aspect = 20
+ pad_fraction = 0.5
+ divider = make_axes_locatable(ax)
+ width = axes_size.AxesY(ax, aspect=1/aspect)
+ pad = axes_size.Fraction(pad_fraction, width)
+ cax = divider.append_axes("right", size=width, pad=pad)
+ cb = plt.colorbar(img, cax=cax)
+ cb.ax.tick_params(labelsize=4)
+
+ gs.tight_layout(fig, pad=0.4, w_pad=0.4, h_pad=0.4)
+
+ # save image to memory buffer
+ buffer_ = io.BytesIO()
+ fig.savefig(buffer_, format="jpg")
+ buffer_.seek(0)
+ image = PIL.Image.open(buffer_)
+
+            # PIL returns RGB, while OpenCV's VideoWriter expects BGR
+            video.write(cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR))
+
+ buffer_.close()
+
+ plt.close()
+
+ video.release()
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/test_time_domain_maxwell.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/test_time_domain_maxwell.py
new file mode 100644
index 0000000..e1ed748
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/networks/test_time_domain_maxwell/test_time_domain_maxwell.py
@@ -0,0 +1,126 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train process"""
+import json
+import math
+import pytest
+import numpy as np
+
+from mindspore.common import set_seed
+from mindspore import context, Tensor, nn
+from mindspore.train import DynamicLossScaleManager
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.common.dtype as mstype
+from mindspore.common.initializer import HeUniform
+
+from mindelec.loss import Constraints
+from mindelec.solver import Solver, LossAndTimeMonitor
+from mindelec.common import L2
+from mindelec.architecture import MultiScaleFCCell, MTLWeightedLossCell
+
+from src.dataset import create_train_dataset, create_random_dataset
+from src.maxwell import Maxwell2DMur
+from src.lr_scheduler import MultiStepLR
+
+set_seed(123456)
+np.random.seed(123456)
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend", save_graphs_path="./graph")
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_time_domain_maxwell():
+ """training process"""
+    with open("./config.json") as f:
+        config = json.load(f)
+ print("check config: {}".format(config))
+ if config["random_sampling"]:
+ elec_train_dataset = create_random_dataset(config)
+ else:
+ elec_train_dataset = create_train_dataset(config["train_data_path"])
+ train_dataset = elec_train_dataset.create_dataset(batch_size=config["train_batch_size"],
+ shuffle=True,
+ prebatched_data=True,
+ drop_remainder=True)
+ steps_per_epoch = len(elec_train_dataset)
+ print("check train dataset size: ", len(elec_train_dataset))
+
+ # define network
+ model = MultiScaleFCCell(config["input_size"],
+ config["output_size"],
+ layers=config["layers"],
+ neurons=config["neurons"],
+ input_scale=config["input_scale"],
+ residual=config["residual"],
+ weight_init=HeUniform(negative_slope=math.sqrt(5)),
+ act="sin",
+ num_scales=config["num_scales"],
+ amp_factor=config["amp_factor"],
+ scale_factor=config["scale_factor"]
+ )
+
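+    # cast the network to float16 for speed but keep the input scaling in float32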
+ model.to_float(mstype.float16)
+ model.input_scale.to_float(mstype.float32)
+ mtl = MTLWeightedLossCell(num_losses=elec_train_dataset.num_dataset)
+
+ # define problem
+ train_prob = {}
+ for dataset in elec_train_dataset.all_datasets:
+ train_prob[dataset.name] = Maxwell2DMur(model=model, config=config,
+ domain_name=dataset.name + "_points",
+ ic_name=dataset.name + "_points",
+ bc_name=dataset.name + "_points")
+ print("check problem: ", train_prob)
+ train_constraints = Constraints(elec_train_dataset, train_prob)
+
+ # optimizer
+ params = model.trainable_params() + mtl.trainable_params()
+ lr_scheduler = MultiStepLR(config["lr"], config["milestones"], config["lr_gamma"],
+ steps_per_epoch, config["train_epoch"])
+ lr = lr_scheduler.get_lr()
+ optim = nn.Adam(params, learning_rate=Tensor(lr))
+
+ if config["load_ckpt"]:
+ param_dict = load_checkpoint(config["load_ckpt_path"])
+ load_param_into_net(model, param_dict)
+ load_param_into_net(mtl, param_dict)
+ # define solver
+ solver = Solver(model,
+ optimizer=optim,
+ mode="PINNs",
+ train_constraints=train_constraints,
+ test_constraints=None,
+ metrics={'l2': L2(), 'distance': nn.MAE()},
+ loss_fn='smooth_l1_loss',
+ loss_scale_manager=DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_window=2000),
+ mtl_weighted_cell=mtl,
+ )
+
+ loss_time_callback = LossAndTimeMonitor(steps_per_epoch)
+ callbacks = [loss_time_callback]
+ if config["save_ckpt"]:
+ config_ck = CheckpointConfig(save_checkpoint_steps=10,
+ keep_checkpoint_max=2)
+ ckpoint_cb = ModelCheckpoint(prefix='ckpt_maxwell_frq1e9',
+ directory=config["save_ckpt_path"], config=config_ck)
+ callbacks += [ckpoint_cb]
+
+ solver.train(config["train_epoch"], train_dataset, callbacks=callbacks, dataset_sink_mode=True)
+
+ assert loss_time_callback.get_loss() <= 2.5
+ assert loss_time_callback.get_step_time() <= 75.0
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/operators/test_derivatives.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/operators/test_derivatives.py
new file mode 100644
index 0000000..13500dd
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/operators/test_derivatives.py
@@ -0,0 +1,235 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#pylint: disable=W0235
+"""
+test derivatives
+"""
+import pytest
+import numpy as np
+
+from mindelec.operators import Grad, SecondOrderGrad, Jacobian, Hessian
+from mindspore import Tensor, ops
+from mindspore import dtype as mstype
+from mindspore import context
+from mindspore import nn
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+def func(x):
+ return x * x * x
+
+
+class Net(nn.Cell):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.w = Tensor(np.array([[1, 5, 2], [8, 4, 6], [7, 6, 9], [5, 1, 5]], np.float32))
+ self.matmul = ops.MatMul()
+
+ def construct(self, a, b):
+ x = ops.Concat(1)((a, b))
+ return self.matmul(self.w, x)
+
+
+class Net1(nn.Cell):
+ def __init__(self):
+ super(Net1, self).__init__()
+
+ def construct(self, x, y):
+ out = x * x * x * y * 5 + 3 * y * x * x + 15 * x * y
+ return out.sum()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_grad_error():
+ """test grad error"""
+ # check argnum error
+ with pytest.raises(TypeError):
+ Grad(func, "a")
+ # check model error
+ with pytest.raises(TypeError):
+ Grad("a", 0)
+
+ x = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
+ x1 = Tensor(np.array([[[1.0, -2.0], [-3.0, 4.0]]]).astype(np.float32))
+ out = func(x)
+ grad = Grad(func)
+ # check net input type error
+ with pytest.raises(TypeError):
+ grad(0.0, 0, 0, out)
+ # check input index type error
+ with pytest.raises(TypeError):
+ grad(x, [0], 0, out)
+ # check output index type error
+ with pytest.raises(TypeError):
+ grad(x, 0, 1.2, out)
+ # check net output type error
+ with pytest.raises(TypeError):
+ grad(x, 0, 0, (1,))
+ # check net input value error
+ with pytest.raises(ValueError):
+ grad(x1, 0, 0, out)
+ # check input index value error
+ with pytest.raises(ValueError):
+ grad(x, 7, 0, out)
+ # check output index value error
+ with pytest.raises(ValueError):
+ grad(x, 0, 7, out)
+ # check net output value error
+ with pytest.raises(ValueError):
+ grad(x, 0, 0, func(x1))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_grad():
+ """test grad"""
+ x = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
+ out = func(x)
+ grad = Grad(func)
+ output = grad(x, 0, 0, out).asnumpy()
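+    # analytic gradient of x**3 w.r.t. the first input column: 3 * x[:, 0] ** 2 = [3, 27]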
+ res = np.array([[3.0], [27.0]], np.float32)
+    assert np.allclose(output, res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_grad_two_input():
+ """test_grad_two_input"""
+ a = Tensor(np.array([[1, 3], [5, 9], [8, 2]], np.float32))
+ b = Tensor(np.array([[4, 6], [7, 2], [2, 1]], np.float32))
+ net = Net()
+ out = net(a, b)
+ grad = Grad(net)
+ output = grad(a, b, 0, 0, out).asnumpy()
+ res = np.array([[21.0], [16.0], [22.0]], np.float32)
+    assert np.allclose(output, res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_second_order_grad_error():
+ """test_second_order_grad_error"""
+ # check input_idx1 type error
+ with pytest.raises(TypeError):
+ SecondOrderGrad(func, "a", 1, 0)
+ # check input_idx2 type error
+ with pytest.raises(TypeError):
+ SecondOrderGrad(func, 1, 1.0, 0)
+ # check output_idx type error
+ with pytest.raises(TypeError):
+ SecondOrderGrad(func, 1, 1, True)
+ # check model error
+ with pytest.raises(TypeError):
+ Grad("a", 0, 0, 0)
+
+ x = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
+ x1 = Tensor(np.array([[[1.0, -2.0], [-3.0, 4.0]]]).astype(np.float32))
+ # check net input type error
+ with pytest.raises(TypeError):
+ second_order_grad = SecondOrderGrad(func, 0, 0, 0)
+ second_order_grad(0.0)
+ # check net input value error
+ with pytest.raises(ValueError):
+ second_order_grad = SecondOrderGrad(func, 0, 0, 0)
+ second_order_grad(x1)
+ # check input index 1 value error
+ with pytest.raises(ValueError):
+ second_order_grad = SecondOrderGrad(func, 7, 0, 0)
+ second_order_grad(x)
+ # check input index 2 value error
+ with pytest.raises(ValueError):
+ second_order_grad = SecondOrderGrad(func, 0, 7, 0)
+ second_order_grad(x)
+ # check output index value error
+ with pytest.raises(ValueError):
+ second_order_grad = SecondOrderGrad(func, 0, 0, 7)
+ second_order_grad(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_second_order_grad():
+ """test_second_order_grad"""
+ x = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
+ second_order_grad = SecondOrderGrad(func, 0, 0, 0)
+ output = second_order_grad(x).asnumpy()
+ res = np.array([[6.0], [-18.0]], np.float32)
+    assert np.allclose(output, res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_class_jacobian_type_error():
+ with pytest.raises(TypeError):
+ Jacobian(Net1(), "a", 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_class_jacobian():
+ """test Jacobian"""
+ a = Tensor(np.array([[1, 3], [5, 9], [8, 2]], np.float32))
+ b = Tensor(np.array([[4, 6], [7, 2], [2, 1]], np.float32))
+ jac = Jacobian(Net(), 0, 0)
+ output = jac(a, b).asnumpy()
+ res = Tensor([[[[1, 0], [5, 0], [2, 0]], [[0, 1], [0, 5], [0, 2]],
+ [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]],
+ [[[8, 0], [4, 0], [6, 0]], [[0, 8], [0, 4], [0, 6]],
+ [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]],
+ [[[7, 0], [6, 0], [9, 0]], [[0, 7], [0, 6], [0, 9]],
+ [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]],
+ [[[5, 0], [1, 0], [5, 0]], [[0, 5], [0, 1], [0, 5]],
+ [[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]]], mstype.float32).asnumpy()
+    assert np.allclose(output, res)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_class_hessian_type_error():
+ with pytest.raises(TypeError):
+ Hessian(Net1(), "a", 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_class_hessian():
+ """test Hessian"""
+ a = Tensor(np.array([[1, 3], [5, 9], [8, 2]], np.float32))
+ b = Tensor(np.array([[4, 6], [7, 2], [2, 1]], np.float32))
+ hes = Hessian(Net1(), 0, 1)
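+    # mixed second derivatives of the scalar output w.r.t. input 0 and input 1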
+ output = hes(a, b).asnumpy()
+ res = Tensor([[[[36, 0], [0, 0], [0, 0]], [[0, 168], [0, 0], [0, 0]]],
+ [[[0, 0], [420, 0], [0, 0]], [[0, 0], [0, 1284], [0, 0]]],
+ [[[0, 0], [0, 0], [1023, 0]], [[0, 0], [0, 0], [0, 87]]]], mstype.float32).asnumpy()
+    assert np.allclose(output, res)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/solver/test_solver.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/solver/test_solver.py
new file mode 100644
index 0000000..e8f8f24
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/solver/test_solver.py
@@ -0,0 +1,104 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#pylint: disable=W0622
+"""
+test solver
+"""
+
+import pytest
+import numpy as np
+import mindspore.nn as nn
+from mindspore.common import set_seed
+import mindspore.common.dtype as mstype
+from mindspore import context, Tensor
+
+from mindelec.solver import Solver, Problem, LossAndTimeMonitor
+from mindelec.data import ExistedDataConfig, Dataset
+from mindelec.loss import Constraints
+from mindelec.common import L2
+
+set_seed(0)
+np.random.seed(0)
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class NetWithoutLoss(nn.Cell):
+ """define network"""
+ def __init__(self, input_dim, output_dim):
+ super(NetWithoutLoss, self).__init__()
+ self.fc1 = nn.Dense(input_dim, 64)
+ self.fc2 = nn.Dense(64, output_dim)
+
+ def construct(self, *input):
+ x = input[0]
+ out = self.fc1(x)
+ out = self.fc2(out)
+ return out
+
+
+class RectPde(Problem):
+ def __init__(self, domain_name):
+ self.domain_name = domain_name
+
+ def governing_equation(self, *output, **kwargs):
+ u = output[0]
+ x = kwargs[self.domain_name]
+ return u - x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_solver():
+ """
+ test solver
+ """
+ input_path = "./input.npy"
+ label_path = "./label.npy"
+ input = np.random.randn(1000, 3)
+ output = np.random.randn(1000, 3)
+ np.save(input_path, input)
+ np.save(label_path, output)
+
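+    # register the two .npy files just written as an existing-data source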
+ exist_train = ExistedDataConfig(name="existed_data",
+ data_dir=[input_path, label_path],
+ columns_list=["inputs", "label"],
+ constraint_type="Equation",
+ data_format="npy")
+
+ dataset = Dataset(existed_data_list=[exist_train])
+ train_dataset = dataset.create_dataset(batch_size=500, shuffle=True)
+ steps_per_epoch = len(dataset)
+ prob_dict = {exist_train.name: RectPde(domain_name="existed_data_inputs")}
+ train_constraints = Constraints(dataset, prob_dict)
+
+ model = NetWithoutLoss(3, 3)
+ optim = nn.Adam(model.trainable_params(), learning_rate=1e-4)
+
+ solver = Solver(network=model,
+ mode="PINNs",
+ optimizer=optim,
+ train_constraints=train_constraints,
+ train_input_map={"existed_data": ["existed_data_inputs"]},
+ metrics={'l2': L2(), 'distance': nn.MAE()})
+
+ loss_time_callback = LossAndTimeMonitor(steps_per_epoch)
+ solver.train(5, train_dataset, callbacks=[loss_time_callback])
+
+ pred_input = Tensor(np.random.randn(20, 3), mstype.float32)
+ pred_output = solver.predict(pred_input)
+ assert pred_output.shape == (20, 3)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/test_mindelec.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/test_mindelec.py
new file mode 100644
index 0000000..bba4aa2
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/test_mindelec.py
@@ -0,0 +1,24 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test mindelec."""
+import pytest
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_empty():
+    """placeholder test verifying the test environment runs"""
+    assert 1 < 2
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_body.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_body.py
new file mode 100644
index 0000000..06af229
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_body.py
@@ -0,0 +1,34 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Visualization of the results 3D VTK form"""
+
+import pytest
+import numpy as np
+from mindelec.vision import vtk_structure
+
+def vtk_structure_all():
+ """vtk_structure_all"""
+ grid = np.random.rand(20, 10, 10, 10, 4).astype(np.float32)
+ eh = np.random.rand(20, 10, 10, 10, 6).astype(np.float32)
+ vtk_structure(grid, eh, './result_vtk')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_body():
+ """test_body"""
+ vtk_structure_all()
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_plane.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_plane.py
new file mode 100644
index 0000000..b832139
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_plane.py
@@ -0,0 +1,41 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Visualization of the results in 2D image form"""
+
+import numpy as np
+import pytest
+from mindelec.vision import plot_s11, plot_eh
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_plane_plot_s11():
+ """test plane plot S11"""
+    s11 = np.random.rand(1001, 2).astype(np.float32)
+    s11[:, 0] = np.linspace(0, 4 * 10 ** 9, 1001)
+ plot_s11(s11, './result_s11', 's11')
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_plane_plot_eh():
+ """test plane plot eh"""
+ eh = np.random.rand(20, 10, 10, 10, 6).astype(np.float32)
+ plot_eh(eh, './result_eh', 5, 300)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_print_scatter.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_print_scatter.py
new file mode 100644
index 0000000..cfad2e2
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_print_scatter.py
@@ -0,0 +1,32 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Visualization of the results in graph 1d 2d"""
+
+import os
+import numpy as np
+import pytest
+from mindelec.vision import print_graph_1d, print_graph_2d
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_print_scatter():
+ """test print scatter"""
+ print_graph_1d("output.jpg", np.ones(10), "./graph_1d")
+ print_graph_2d("output.jpg", np.ones(10), np.ones(10), "./graph_2d")
+ assert os.path.exists("./graph_1d/output.jpg")
+ assert os.path.exists("./graph_2d/output.jpg")
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_video.py b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_video.py
new file mode 100644
index 0000000..48140db
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindelec/vision/test_video.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Visualization of the results 3D VTK form"""
+
+import os
+import pytest
+import numpy as np
+from mindelec.vision import plot_eh, image_to_video
+
+
+def image_to_video_temp():
+ """image to video test"""
+ path_image = './image'
+ eh = np.random.rand(5, 10, 10, 10, 6).astype(np.float32)
+ plot_eh(eh, path_image, 5, 300)
+
+ path_video = './result_video'
+ video_name = 'video.avi'
+ fps = 20
+ image_to_video(path_image, path_video, video_name, fps)
+
+ assert os.path.exists(path_video)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_image_to_video():
+ """test image to video"""
+ image_to_video_temp()
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/min/test_case_covid_min.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/min/test_case_covid_min.py
new file mode 100644
index 0000000..6c756c5
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/min/test_case_covid_min.py
@@ -0,0 +1,64 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case covid min"""
+
+import time
+import numpy as np
+import pytest
+from mindspore import context, Tensor
+from mindsponge.md.simulation import Simulation
+
+class ArgsOpt():
+ """ArgsOpt"""
+ def __init__(self):
+ self.amber_parm = '/home/workspace/mindspore_dataset/mindsponge_data/min1/s1ace2.parm7'
+ self.box = ''
+ self.c = '/home/workspace/mindspore_dataset/mindsponge_data/min1/s1ace2_min1.rst7'
+ self.checkpoint = ''
+ self.device_id = 0
+ self.i = '/home/workspace/mindspore_dataset/mindsponge_data/min1/min1.in'
+ self.o = ''
+ self.r = ''
+ self.u = False
+ self.x = ''
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_case_covid_min():
+    """test case covid min"""
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+ args_opt = ArgsOpt()
+ simulation = Simulation(args_opt)
+ for i in range(10):
+ temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene, _, _, _, _ = \
+ simulation(Tensor(i), Tensor(0))
+ if i == 0:
+ print(temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene)
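+            # start the timer after the first call so graph compilation is excluded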
+ start = time.time()
+ assert np.allclose(round(float(temperature.asnumpy()), 3), 0.000, rtol=0.1)
+ assert np.allclose(round(float(total_potential_energy.asnumpy()), 3), 39327864.000, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_bond_ene.asnumpy()), 3), 418.748, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_angle_ene.asnumpy()), 3), 1351.111, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_dihedral_ene.asnumpy()), 3), 9382.757, rtol=0.1)
+ assert np.allclose(round(float(nb14_lj_energy_sum.asnumpy()), 3), 3714.295, rtol=0.1)
+ assert np.allclose(round(float(nb14_cf_energy_sum.asnumpy()), 3), 36175.125, rtol=0.1)
+ assert np.allclose(round(float(lj_energy_sum.asnumpy()), 3), 39634900.000, rtol=0.1)
+ assert np.allclose(round(float(ee_ene.asnumpy()), 3), -358078.625, rtol=0.1)
+ end = time.time()
+
+ assert ((end - start) / 9) < 0.1
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/pres/test_case_covid_pres.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/pres/test_case_covid_pres.py
new file mode 100644
index 0000000..7000f47
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_covid/pres/test_case_covid_pres.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case covid pres"""
+
+import time
+import numpy as np
+import pytest
+from mindspore import context, Tensor
+import mindspore.common.dtype as mstype
+from mindsponge.md.npt import NPT as Simulation
+
+
+class ArgsOpt():
+ """ArgsOpt"""
+ def __init__(self):
+ self.amber_parm = '/home/workspace/mindspore_dataset/mindsponge_data/pres/s1ace2.parm7'
+ self.box = ''
+ self.c = '/home/workspace/mindspore_dataset/mindsponge_data/pres/s1ace2_heat.rst7'
+ self.checkpoint = ''
+ self.device_id = 0
+ self.i = '/home/workspace/mindspore_dataset/mindsponge_data/pres/pres.in'
+ self.o = ''
+ self.r = ''
+ self.u = False
+ self.x = ''
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_case_covid_pres():
+    """test case covid pres"""
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+ args_opt = ArgsOpt()
+ simulation = Simulation(args_opt)
+ for i in range(1, 11):
+ print_step = 1 if i % simulation.ntwx == 0 or i == 1 or i == simulation.md_info.step_limit else 0
+ update_step = 1 if (i != 1 and i % simulation.update_interval == 0) else 0
+ temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene, _, _, _, _, _, _, _, _ = \
+ simulation(Tensor(i), Tensor(print_step), Tensor(update_step, mstype.int32))
+ if i == 1:
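+            # start the timer on the first step so graph compilation is excluded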
+ start = time.time()
+ print(temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene)
+ assert np.allclose(round(float(temperature.asnumpy()), 3), 298.406, rtol=0.1)
+ assert np.allclose(round(float(total_potential_energy.asnumpy()), 3), -320432.750, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_bond_ene.asnumpy()), 3), 4228.548, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_angle_ene.asnumpy()), 3), 6081.921, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_dihedral_ene.asnumpy()), 3), 10484.753, rtol=0.1)
+ assert np.allclose(round(float(nb14_lj_energy_sum.asnumpy()), 3), 2990.386, rtol=0.1)
+ assert np.allclose(round(float(nb14_cf_energy_sum.asnumpy()), 3), 34394.328, rtol=0.1)
+ assert np.allclose(round(float(lj_energy_sum.asnumpy()), 3), 36317.559, rtol=0.1)
+ assert np.allclose(round(float(ee_ene.asnumpy()), 3), -414930.250, rtol=0.1)
+ end = time.time()
+ assert ((end - start) / 9) < 0.1
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/test_case_mct.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/test_case_mct.py
new file mode 100644
index 0000000..7d7f270
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_mct/test_case_mct.py
@@ -0,0 +1,527 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case mct"""
+
+import time
+import numpy as np
+import pytest
+from mindspore import context, load_checkpoint, ops, nn, Tensor
+
+import mindspore.numpy as msnp
+import mindspore.common.dtype as mstype
+from mindspore.common.parameter import Parameter
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.ops import constexpr
+from mindspore.ops import composite as C
+
+from mindsponge import Angle
+from mindsponge import Bond
+from mindsponge import Dihedral
+from mindsponge import LennardJonesInformation
+from mindsponge import NonBond14
+from mindsponge import ParticleMeshEwald
+from mindsponge import LangevinLiujian
+from mindsponge import MdInformation
+from mindsponge import NeighborList
+from mindsponge.md.cybertron.meta_dynamics import Bias
+from mindsponge.md.cybertron.units import units
+from mindsponge.md.cybertron.models import MolCT
+from mindsponge.md.cybertron.readouts import AtomwiseReadout
+from mindsponge.md.cybertron.cybertron import Cybertron
+
+standard_normal = ops.StandardNormal()
+zeros = ops.Zeros()
+
+WALL_P = 9e08
+WALL_POTENTIAL = np.zeros(200, dtype=np.float32)
+WALL_POTENTIAL[0] = WALL_P
+WALL_POTENTIAL[1] = WALL_P
+WALL_POTENTIAL[2] = WALL_P
+WALL_POTENTIAL[-1] = WALL_P
+WALL_POTENTIAL[-2] = WALL_P
+WALL_POTENTIAL[-3] = WALL_P
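+# a large potential on the three lowest/highest grid points acts as a wall on the CV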
+SMIN = 0
+SMAX = 8
+DS = 0.04
+OMEGA = 50
+SIGMA = 0.005
+DDT = 0.001
+T = 300
+ALPHA = 0.5
+GAMMA = 6
+KAPPA = 4
+UPPER_BOUND_INDEX = 190
+LOWER_BOUND_INDEX = 10
+WALL_FACTOR = 0.1
+
+@constexpr
+def get_full_tensor(shape, fill_value, dtype=np.float32):
+ '''get_full_tensor'''
+ return msnp.full(shape, fill_value, dtype)
+
+
+class Controller:
+ '''controller'''
+
+ def __init__(self, args_opt):
+ self.input_file = args_opt.i
+ self.initial_coordinates_file = args_opt.c
+ self.amber_parm = args_opt.amber_parm
+ self.restrt = args_opt.r
+ self.mdcrd = args_opt.x
+ self.mdout = args_opt.o
+ self.mdbox = args_opt.box
+ self.meta = args_opt.meta
+ self.with_box = args_opt.with_box
+ self.np_iter = args_opt.np_iter
+ self.command_set = {}
+ self.md_task = None
+ self.commands_from_in_file()
+
+ def commands_from_in_file(self):
+ '''command from in file'''
+        with open(self.input_file, 'r') as file:
+            ct = file.readlines()
+ self.md_task = ct[0].strip()
+ for val in ct:
+ if "=" in val:
+ assert len(val.strip().split("=")) == 2
+ flag, value = val.strip().split("=")
+ value = value.replace(",", '')
+ flag = flag.replace(" ", "")
+ if flag not in self.command_set:
+ self.command_set[flag] = value
+            else:
+                print("ERROR: duplicate command flag '{}' in input file".format(flag))
+
+class SimulationCybertron(nn.Cell):
+ '''simulation'''
+
+ def __init__(self, args_opt, network=None):
+ super().__init__()
+ self.control = Controller(args_opt)
+ if self.control.meta:
+ self.meta = Tensor([1], mstype.int32)
+ else:
+ self.meta = Tensor([0], mstype.int32)
+ self.md_info = MdInformation(self.control)
+ self.bond = Bond(self.control)
+ self.angle = Angle(self.control)
+ self.dihedral = Dihedral(self.control)
+ self.nb14 = NonBond14(self.control, self.dihedral, self.md_info.atom_numbers)
+ self.nb_info = NeighborList(self.control, self.md_info.atom_numbers, self.md_info.box_length)
+ self.lj_info = LennardJonesInformation(self.control, self.md_info.nb.cutoff, self.md_info.sys.box_length)
+ self.liujian_info = LangevinLiujian(self.control, self.md_info.atom_numbers)
+ self.pme_method = ParticleMeshEwald(self.control, self.md_info)
+ self.bond_energy_sum = Tensor(0, mstype.int32)
+ self.angle_energy_sum = Tensor(0, mstype.int32)
+ self.dihedral_energy_sum = Tensor(0, mstype.int32)
+ self.nb14_lj_energy_sum = Tensor(0, mstype.int32)
+ self.nb14_cf_energy_sum = Tensor(0, mstype.int32)
+ self.lj_energy_sum = Tensor(0, mstype.int32)
+ self.ee_ene = Tensor(0, mstype.int32)
+ self.total_energy = Tensor(0, mstype.int32)
+ # Init scalar
+ self.ntwx = self.md_info.ntwx
+ self.atom_numbers = self.md_info.atom_numbers
+ self.residue_numbers = self.md_info.residue_numbers
+ self.bond_numbers = self.bond.bond_numbers
+ self.angle_numbers = self.angle.angle_numbers
+ self.dihedral_numbers = self.dihedral.dihedral_numbers
+ self.nb14_numbers = self.nb14.nb14_numbers
+ self.nxy = self.nb_info.nxy
+ self.grid_numbers = self.nb_info.grid_numbers
+ self.max_atom_in_grid_numbers = self.nb_info.max_atom_in_grid_numbers
+ self.max_neighbor_numbers = self.nb_info.max_neighbor_numbers
+ self.excluded_atom_numbers = self.nb_info.excluded_atom_numbers
+ self.refresh_count = Parameter(Tensor(self.nb_info.refresh_count, mstype.int32), requires_grad=False)
+ self.refresh_interval = self.nb_info.refresh_interval
+ self.skin = self.nb_info.skin
+ self.cutoff = self.nb_info.cutoff
+ self.cutoff_square = self.nb_info.cutoff_square
+ self.cutoff_with_skin = self.nb_info.cutoff_with_skin
+ self.half_cutoff_with_skin = self.nb_info.half_cutoff_with_skin
+ self.cutoff_with_skin_square = self.nb_info.cutoff_with_skin_square
+ self.half_skin_square = self.nb_info.half_skin_square
+ self.beta = self.pme_method.beta
+ self.fftx = self.pme_method.fftx
+ self.ffty = self.pme_method.ffty
+ self.fftz = self.pme_method.fftz
+ self.random_seed = self.liujian_info.random_seed
+ self.dt = self.liujian_info.dt
+ self.half_dt = self.liujian_info.half_dt
+ self.exp_gamma = self.liujian_info.exp_gamma
+
+ self.tmp_forces = Tensor(np.zeros((self.atom_numbers, 3)), dtype=mstype.float32)
+
+ self.bias_potential = Parameter(Tensor(WALL_POTENTIAL, mstype.float32), requires_grad=True)
+ self.grid_num = 200
+ self.wall_potential = WALL_P
+
+ self.meta_interval = 5
+
+ self.wall_factor = WALL_FACTOR
+ self.upper_bound_index = UPPER_BOUND_INDEX
+ self.lower_bound_index = LOWER_BOUND_INDEX
+ self.kappa = KAPPA
+        self.smin = Tensor(SMIN, mstype.float32)
+ self.smax = SMAX
+ self.t = T
+ self.alpha = ALPHA
+ self.gamma = GAMMA
+ self.ds = DS
+ self.ddt = DDT
+ self.sum = ops.ReduceSum()
+ self.omega = OMEGA
+ self.sigma = Tensor(SIGMA, mstype.float32)
+ self.exp = ops.Exp()
+ self.square = ops.Square()
+ self.sqrt = ops.Sqrt()
+ self.zeros = ops.Zeros()
+ self.ones = ops.Ones()
+ self.norm = nn.Norm()
+ self.add = ops.Add()
+ self.cast = ops.Cast()
+ self.cv_list = Tensor(np.arange(SMIN, SMAX, DS, dtype=np.float32)[0:self.grid_num], dtype=mstype.float32)
+ self.init_tensor()
+ self.op_define()
+ self.update = False
+ self.constant_random_force = Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32)
+ self.max_vel = 20
+ self.hsigmoid = nn.HSigmoid()
+ self.one_hill = Tensor([1], mstype.int32)
+ self.sqrt2 = Tensor(np.sqrt(2), mstype.float32)
+ self.kb = units.boltzmann()
+ self.kbt = self.kb * self.t
+        # kept separate from self.beta, which already holds the PME beta
+        self.inv_kbt = 1.0 / self.kbt
+        self.wt_factor = -1.0 / (self.gamma - 1.0) * self.inv_kbt
+
+ self.network = network
+ self.index_add = ops.IndexAdd(axis=-1)
+ self.bias = Bias
+ self.keep_sum = P.ReduceSum(keep_dims=True)
+ self.grad = C.GradOperation()
+ self.squeeze = P.Squeeze(0)
+        self.file = None
+        self.datfile = None
+
+ def init_tensor(self):
+ '''init tensor'''
+ self.hills = Parameter(Tensor(np.zeros(self.grid_num), mstype.float32), requires_grad=False)
+ self.crd = Parameter(
+ Tensor(np.float32(np.asarray(self.md_info.coordinate).reshape([self.atom_numbers, 3])), mstype.float32),
+ requires_grad=False)
+ self.crd_to_uint_crd_cof = Tensor(np.asarray(self.md_info.pbc.crd_to_uint_crd_cof, np.float32), mstype.float32)
+ self.uint_dr_to_dr_cof = Parameter(
+ Tensor(np.asarray(self.md_info.pbc.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False)
+ self.box_length = Tensor(self.md_info.box_length, mstype.float32)
+ self.virtual_box_length = Tensor([0., 0., 0.], mstype.float32)
+ self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32),
+ requires_grad=False)
+ self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32),
+ requires_grad=False)
+ self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32)
+ self.res_start = Tensor(self.md_info.h_res_start, mstype.int32)
+ self.res_end = Tensor(self.md_info.h_res_end, mstype.int32)
+ self.mass = Tensor(self.md_info.h_mass, mstype.float32)
+ self.velocity = Parameter(Tensor(self.md_info.velocity, mstype.float32), requires_grad=False)
+ self.acc = Parameter(Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32), requires_grad=False)
+ self.grid_n = Tensor(self.nb_info.grid_n, mstype.int32)
+ self.grid_length_inverse = Tensor(self.nb_info.grid_length_inverse, mstype.float32)
+ self.bucket = Parameter(Tensor(
+ np.asarray(self.nb_info.bucket, np.int32).reshape([self.grid_numbers, self.max_atom_in_grid_numbers]),
+ mstype.int32), requires_grad=False)
+ self.atom_numbers_in_grid_bucket = Parameter(Tensor(self.nb_info.atom_numbers_in_grid_bucket, mstype.int32),
+ requires_grad=False)
+ self.atom_in_grid_serial = Parameter(Tensor(np.zeros([self.nb_info.atom_numbers,], np.int32), mstype.int32),
+ requires_grad=False)
+ self.pointer = Parameter(
+ Tensor(np.asarray(self.nb_info.pointer, np.int32).reshape([self.grid_numbers, 125]), mstype.int32),
+ requires_grad=False)
+ self.nl_atom_numbers = Parameter(Tensor(np.zeros([self.atom_numbers,], np.int32), mstype.int32),
+ requires_grad=False)
+ self.nl_atom_serial = Parameter(
+ Tensor(np.zeros([self.atom_numbers, self.max_neighbor_numbers], np.int32), mstype.int32),
+ requires_grad=False)
+
+ self.excluded_list_start = Tensor(np.asarray(self.nb_info.excluded_list_start, np.int32), mstype.int32)
+ self.excluded_list = Tensor(np.asarray(self.nb_info.excluded_list, np.int32), mstype.int32)
+ self.excluded_numbers = Tensor(np.asarray(self.nb_info.excluded_numbers, np.int32), mstype.int32)
+ self.need_refresh_flag = Tensor(np.asarray([0], np.int32), mstype.int32)
+ self.sqrt_mass = Tensor(self.liujian_info.h_sqrt_mass, mstype.float32)
+ self.rand_state = Parameter(Tensor(self.liujian_info.rand_state, mstype.float32))
+ self.zero_fp_tensor = Tensor(np.asarray([0,], np.float32))
+
+ def op_define(self):
+ '''op define'''
+ self.mdtemp = P.MDTemperature(self.residue_numbers, self.atom_numbers)
+ self.setup_random_state = P.MDIterationSetupRandState(self.atom_numbers, self.random_seed)
+
+ self.md_iteration_leap_frog_liujian = P.MDIterationLeapFrogLiujian(self.atom_numbers, self.half_dt, self.dt,
+ self.exp_gamma)
+
+ self.neighbor_list_update_init = P.NeighborListUpdate(grid_numbers=self.grid_numbers,
+ atom_numbers=self.atom_numbers, not_first_time=0,
+ nxy=self.nxy,
+ excluded_atom_numbers=self.excluded_atom_numbers,
+ cutoff_square=self.cutoff_square,
+ half_skin_square=self.half_skin_square,
+ cutoff_with_skin=self.cutoff_with_skin,
+ half_cutoff_with_skin=self.half_cutoff_with_skin,
+ cutoff_with_skin_square=self.cutoff_with_skin_square,
+ refresh_interval=self.refresh_interval,
+ cutoff=self.cutoff, skin=self.skin,
+ max_atom_in_grid_numbers=self.max_atom_in_grid_numbers,
+ max_neighbor_numbers=self.max_neighbor_numbers)
+
+ self.random_force = Tensor(np.zeros([self.atom_numbers, 3], np.float32), mstype.float32)
+
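+    # deposit a hill weight (0 or 1) onto the bias grid via scatter-add at the CV index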
+ def update_hills(self, index, value):
+ hills = ops.TensorScatterAdd()(self.hills,
+ F.expand_dims(index, -1),
+ F.expand_dims(self.cast(value, mstype.float32), -1))
+ return hills
+
+    def simulation_calculate_cybertron_force(self, positions, step, atom_types=None):
+        """calculate Cybertron network forces plus the meta-dynamics bias force"""
+ forces = -1 * self.grad(self.network)(positions,
+ atom_types,
+ None,
+ None)
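+        # collective variable: distance between atoms 11 and 14 in the previous frame,
+        # mapped onto the bias grid with the index clamped to [0, grid_num - 1]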
+ cv = self.norm(self.add(self.last_crd[11], -self.last_crd[14]))
+ cv_index = self.cast((cv - self.smin) / self.ds, mstype.int32)
+ cv_index = cv_index * (cv_index >= 0)
+ cv_index = cv_index * (cv_index < self.grid_num) + (self.grid_num - 1) * (cv_index >= self.grid_num)
+ self.hills = self.update_hills(cv_index, step % self.meta_interval == 0)
+ bias_cell = self.bias(self.hills,
+ smin=self.smin,
+ smax=self.smax,
+ ds=self.ds,
+ omega=self.omega,
+ sigma=self.sigma,
+ dt=self.ddt,
+ t=self.t,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ wall_potential=self.wall_potential,
+ kappa=self.kappa,
+ upper_bound=self.upper_bound_index,
+ lower_bound=self.lower_bound_index,
+ factor=self.wall_factor)
+ entropy_force = self.grad(bias_cell)(self.last_crd)
+ tforces = P.AddN()([self.squeeze(forces), -self.meta * entropy_force])
+ return tforces
+
+    def simulation_calculate_cybertron_energy(self, positions, atom_types=None):
+ energy = self.network(positions, atom_types, None, None)
+ energy = self.squeeze(energy)
+ return energy
+
+ def simulation_temperature(self):
+        '''calculate temperature'''
+ res_ek_energy = self.mdtemp(self.res_start, self.res_end, self.velocity, self.mass)
+ temperature = P.ReduceSum()(res_ek_energy)
+ return temperature
+
+ def simulation_mditeration_leapfrog_liujian(self, inverse_mass, sqrt_mass_inverse, crd, frc, rand_state,
+ random_frc):
+ '''simulation leap frog iteration liujian'''
+ crd = self.md_iteration_leap_frog_liujian(inverse_mass, sqrt_mass_inverse, self.velocity, crd, frc, self.acc,
+ rand_state, random_frc)
+
+ vel = F.depend(self.velocity, crd)
+ vel = (self.hsigmoid(vel * 3 / self.max_vel) - 0.5) * 2 * self.max_vel
+ acc = F.depend(self.acc, crd)
+ return vel, crd, acc
+
+ def main_print(self, *args):
+ """compute the temperature"""
+ _, temperature, total_potential_energy, _, _, _, _, _, _, _ = list(args)
+
+ temperature = temperature.asnumpy()
+ total_potential_energy = total_potential_energy.asnumpy()
+ cv = self.norm(self.add(self.last_crd[11], -self.last_crd[14]))
+ biasp = self.sum(
+ self.dt * self.hills * self.omega * self.exp(-self.square(cv - self.cv_list) / 2 / self.square(self.sigma)))
+ return cv, biasp
+
+ def main_initial(self):
+ """main initial"""
+ if self.control.mdout:
+ self.file = open(self.control.mdout, 'w')
+ self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _CVariable_ _Bias_Potential_\n")
+ if self.control.mdcrd:
+ self.datfile = open(self.control.mdcrd, 'wb')
+
+ def main_destroy(self):
+ """main destroy"""
+ if self.file is not None:
+ self.file.close()
+ print("Save .out file successfully!")
+ if self.datfile is not None:
+ self.datfile.close()
+ print("Save .dat file successfully!")
+
+ def construct(self, step, print_step):
+ '''construct'''
+ self.last_crd = self.crd
+ if step == 0:
+ res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd,
+ self.virtual_box_length, self.grid_n, self.grid_length_inverse,
+ self.atom_in_grid_serial, self.old_crd, self.crd_to_uint_crd_cof,
+ self.uint_crd, self.pointer, self.nl_atom_numbers, self.nl_atom_serial,
+ self.uint_dr_to_dr_cof, self.excluded_list_start, self.excluded_list,
+ self.excluded_numbers, self.need_refresh_flag, self.refresh_count)
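+            # F.depend pins these reads after the neighbor-list update in graph mode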
+ self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
+ self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
+ self.uint_dr_to_dr_cof = F.depend(self.uint_dr_to_dr_cof, res)
+ self.old_crd = F.depend(self.old_crd, res)
+ self.atom_numbers_in_grid_bucket = F.depend(self.atom_numbers_in_grid_bucket, res)
+ self.bucket = F.depend(self.bucket, res)
+ self.atom_in_grid_serial = F.depend(self.atom_in_grid_serial, res)
+ self.pointer = F.depend(self.pointer, res)
+
+ positions = F.expand_dims(self.crd, 0)
+            force = self.simulation_calculate_cybertron_force(positions, step)
+ bond_energy_sum = self.zero_fp_tensor
+ angle_energy_sum = self.zero_fp_tensor
+ dihedral_energy_sum = self.zero_fp_tensor
+ nb14_lj_energy_sum = self.zero_fp_tensor
+ nb14_cf_energy_sum = self.zero_fp_tensor
+ lj_energy_sum = self.zero_fp_tensor
+ ee_ene = self.zero_fp_tensor
+            total_energy = self.simulation_calculate_cybertron_energy(positions)
+
+ temperature = self.simulation_temperature()
+ self.rand_state = self.setup_random_state()
+ self.velocity, self.crd, _ = self.simulation_mditeration_leapfrog_liujian(self.mass_inverse,
+ self.sqrt_mass, self.crd, force,
+ self.rand_state,
+ self.random_force)
+
+ res = self.ds
+ self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
+ self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
+ else:
+
+ positions = F.expand_dims(self.crd, 0)
+            force = self.simulation_calculate_cybertron_force(positions, step)
+ if print_step == 0:
+ bond_energy_sum = self.zero_fp_tensor
+ angle_energy_sum = self.zero_fp_tensor
+ dihedral_energy_sum = self.zero_fp_tensor
+ nb14_lj_energy_sum = self.zero_fp_tensor
+ nb14_cf_energy_sum = self.zero_fp_tensor
+ lj_energy_sum = self.zero_fp_tensor
+ ee_ene = self.zero_fp_tensor
+                total_energy = self.simulation_calculate_cybertron_energy(positions)
+ else:
+ bond_energy_sum = self.zero_fp_tensor
+ angle_energy_sum = self.zero_fp_tensor
+ dihedral_energy_sum = self.zero_fp_tensor
+ nb14_lj_energy_sum = self.zero_fp_tensor
+ nb14_cf_energy_sum = self.zero_fp_tensor
+ lj_energy_sum = self.zero_fp_tensor
+ ee_ene = self.zero_fp_tensor
+ total_energy = self.zero_fp_tensor
+ temperature = self.simulation_temperature()
+ self.velocity, self.crd, _ = self.simulation_mditeration_leapfrog_liujian(self.mass_inverse,
+ self.sqrt_mass, self.crd, force,
+ self.rand_state,
+ self.random_force)
+
+ res = self.ds
+ self.nl_atom_numbers = F.depend(self.nl_atom_numbers, res)
+ self.nl_atom_serial = F.depend(self.nl_atom_serial, res)
+ return temperature, total_energy, bond_energy_sum, angle_energy_sum, dihedral_energy_sum, nb14_lj_energy_sum, \
+ nb14_cf_energy_sum, lj_energy_sum, ee_ene, res
+
+class ArgsOpt():
+ """ArgsOpt"""
+ def __init__(self):
+ self.amber_parm = '/home/workspace/mindspore_dataset/mindsponge_data/ai/cba.prmtop'
+ self.box = ''
+ self.c = '/home/workspace/mindspore_dataset/mindsponge_data/ai/cba_its_mw0_trans.rst7'
+ self.checkpoint = ''
+ self.device_id = 0
+ self.i = '/home/workspace/mindspore_dataset/mindsponge_data/ai/md.in'
+ self.o = ''
+ self.r = ''
+ self.u = False
+ self.x = ''
+ self.meta = 0
+ self.with_box = 1
+ self.np_iter = 0
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_case_mct():
+ """test_case_mct for test"""
+ args_opt = ArgsOpt()
+ args_opt.initial_coordinates_file = args_opt.c
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+ atom_types = Tensor([6, 1, 6, 1, 6, 1, 1, 6, 1, 6, 1, 6, 1, 1, 8])
+ mod = MolCT(
+ min_rbf_dis=0.1,
+ max_rbf_dis=10,
+ num_rbf=128,
+ rbf_sigma=0.2,
+ n_interactions=3,
+ dim_feature=128,
+ n_heads=8,
+ max_cycles=1,
+ use_time_embedding=True,
+ fixed_cycles=True,
+ self_dis=0.1,
+ unit_length='A',
+ use_feed_forward=False,
+ )
+ scales = 3.0
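+    # Atom-wise readout pools per-atom features into a single molecular energy in kcal/mol.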
+ readout = AtomwiseReadout(n_in=mod.dim_feature, n_interactions=mod.n_interactions, activation=mod.activation,
+ n_out=1, mol_scale=scales, unit_energy='kcal/mol')
+ net = Cybertron(mod, atom_types=atom_types, full_connect=True, readout=readout, unit_dis='A',
+ unit_energy='kcal/mol')
+
+ param_file = '/home/workspace/mindspore_dataset/mindsponge_data/ai/cba_kcal_mol_A_MolCT-best.ckpt'
+ load_checkpoint(param_file, net=net)
+
+ simulation = SimulationCybertron(args_opt, network=net)
+ compiler_time = 0
+ simulation.main_initial()
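+    # print_step == 0 marks an energy-output step; the final step always prints.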
+ for steps in range(simulation.md_info.step_limit):
+ print_step = steps % simulation.ntwx
+ if steps == simulation.md_info.step_limit - 1:
+ print_step = 0
+ temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, _, \
+ _, nb14_cf_energy_sum, lj_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step))
+
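+        # Step 0 includes graph compilation; record its end time to measure steady-state cost.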
+ if steps == 0:
+ compiler_time = time.time()
+ cv, biasp = simulation.main_print(steps, temperature, total_potential_energy, sigma_of_bond_ene,
+ sigma_of_angle_ene, Tensor(0), Tensor(0), nb14_cf_energy_sum,
+ lj_energy_sum, ee_ene)
+ assert np.allclose(round(float(temperature.asnumpy()), 3), 0.000, rtol=0.1)
+ assert np.allclose(round(float(total_potential_energy.asnumpy()), 3), 464.834, rtol=0.1)
+ assert np.allclose(round(float(cv.asnumpy()), 3), 1.449, rtol=0.1)
+ assert np.allclose(round(float(biasp.asnumpy()), 3), 0.222, rtol=0.1)
+ end = time.time()
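+    # Average per-step time over the 9 post-compilation steps must stay under 0.5 s.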
+ assert ((end - compiler_time) / 9) < 0.5
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/__init__.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/test_case_polypeptide.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/test_case_polypeptide.py
new file mode 100644
index 0000000..f35d8c4
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide/test_case_polypeptide.py
@@ -0,0 +1,63 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case polypeptide"""
+
+import time
+import numpy as np
+import pytest
+from mindspore import context, Tensor
+from mindsponge.md.simulation import Simulation
+
+class ArgsOpt():
+    """SPONGE-style command-line options for the solvated-ALA MD test."""
+ def __init__(self):
+ self.amber_parm = '/home/workspace/mindspore_dataset/mindsponge_data/ala/WATER_ALA.parm7'
+ self.box = ''
+ self.c = '/home/workspace/mindspore_dataset/mindsponge_data/ala/WATER_ALA_350_cool_290.rst7'
+ self.checkpoint = ''
+ self.device_id = 0
+ self.i = '/home/workspace/mindspore_dataset/mindsponge_data/ala/NVT_290_10ns.in'
+ self.o = ''
+ self.r = ''
+ self.u = False
+ self.x = ''
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_case_poly():
+    """System test: 10 MD steps on solvated ALA; check energies and average step time."""
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
+ args_opt = ArgsOpt()
+ simulation = Simulation(args_opt)
+ for steps in range(10):
+ print_step = steps % simulation.ntwx
+ temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \
+ nb14_lj_energy_sum, nb14_cf_energy_sum, lj_energy_sum, ee_ene, _, _, _, _ = \
+ simulation(Tensor(steps), Tensor(print_step))
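+        # Step 0 includes graph compilation; start timing after it finishes.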
+ if steps == 0:
+ start = time.time()
+ assert np.allclose(round(float(temperature.asnumpy()), 3), 0.788, rtol=0.1)
+ assert np.allclose(round(float(total_potential_energy.asnumpy()), 3), -5836.541, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_bond_ene.asnumpy()), 3), 48.745, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_angle_ene.asnumpy()), 3), 0.891, rtol=0.1)
+ assert np.allclose(round(float(sigma_of_dihedral_ene.asnumpy()), 3), 14.904, rtol=0.1)
+ assert np.allclose(round(float(nb14_lj_energy_sum.asnumpy()), 3), 9.041, rtol=0.1)
+ assert np.allclose(round(float(nb14_cf_energy_sum.asnumpy()), 3), 194.479, rtol=0.1)
+ assert np.allclose(round(float(lj_energy_sum.asnumpy()), 3), 763.169, rtol=0.1)
+ assert np.allclose(round(float(ee_ene.asnumpy()), 3), -6867.770, rtol=0.1)
+ end = time.time()
+
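+    # Average per-step time over the 9 post-compilation steps must stay under 7 ms.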
+ assert ((end - start) / 9) < 0.007
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/simulation_poly_bond.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/simulation_poly_bond.py
new file mode 100644
index 0000000..9978693
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/simulation_poly_bond.py
@@ -0,0 +1,115 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Simulation"""
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+from mindsponge import Bond
+from mindsponge import MdInformation
+
+
+class Controller:
+    """Parse SPONGE command-line options and commands from the .in control file."""
+
+ def __init__(self, args_opt):
+ self.input_file = args_opt['i']
+ self.initial_coordinates_file = args_opt['c']
+ self.amber_parm = args_opt['amber_parm']
+ self.restrt = args_opt['r']
+ self.mdcrd = args_opt['x']
+ self.mdout = args_opt['o']
+ self.mdbox = args_opt['box']
+
+ self.command_set = {}
+ self.md_task = None
+ self.commands_from_in_file()
+ self.punctuation = ","
+
+ def commands_from_in_file(self):
+        """Parse "flag = value" commands from the .in control file into self.command_set."""
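+        # Hypothetical example of an .in control file:
+        #     NVT simulation
+        #     mode = 1,
+        #     step_limit = 1000,
+        # The first line is the task name; later lines are "flag = value" pairs, and
+        # anything after the first "," on a line is ignored.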
+        with open(self.input_file, 'r') as file:
+            lines = file.readlines()
+        self.md_task = lines[0].strip()
+        for val in lines:
+            val = val.strip()
+            if val and val[0] != '#' and ("=" in val):
+                # Keep only the part before the first comma.
+                val = val[:val.index(",")] if ',' in val else val
+                assert len(val.strip().split("=")) == 2
+                flag, value = val.strip().split("=")
+                value = value.replace(" ", "")
+                flag = flag.replace(" ", "")
+                if flag not in self.command_set:
+                    self.command_set[flag] = value
+                else:
+                    print("ERROR: duplicate flag '{}' in command file".format(flag))
+
+
+class Simulation(nn.Cell):
+    """Minimal Simulation cell: convert coordinates and sum harmonic bond energies."""
+
+ def __init__(self, args_opt):
+ super(Simulation, self).__init__()
+ self.control = Controller(args_opt)
+ self.md_info = MdInformation(self.control)
+ self.mode = self.md_info.mode
+ self.bond = Bond(self.control)
+ self.atom_numbers = self.md_info.atom_numbers
+ self.residue_numbers = self.md_info.residue_numbers
+ self.bond_numbers = self.bond.bond_numbers
+ self.init_tensor()
+ self.op_define()
+
+ def init_tensor(self):
+ """init tensor"""
+ self.crd = Parameter(
+ Tensor(np.array(self.md_info.coordinate).reshape([self.atom_numbers, 3]), mstype.float32),
+ requires_grad=False)
+ self.crd_to_uint_crd_cof = Tensor(np.asarray(self.md_info.pbc.crd_to_uint_crd_cof, np.float32), mstype.float32)
+ self.uint_dr_to_dr_cof = Parameter(Tensor(self.md_info.pbc.uint_dr_to_dr_cof, mstype.float32),
+ requires_grad=False)
+ self.bond_atom_a = Tensor(np.asarray(self.bond.h_atom_a, np.int32), mstype.int32)
+ self.bond_atom_b = Tensor(np.asarray(self.bond.h_atom_b, np.int32), mstype.int32)
+ self.bond_k = Tensor(np.asarray(self.bond.h_k, np.float32), mstype.float32)
+ self.bond_r0 = Tensor(np.asarray(self.bond.h_r0, np.float32), mstype.float32)
+
+ def op_define(self):
+ """op define"""
+ self.crd_to_uint_crd = P.CrdToUintCrd(self.atom_numbers)
+ self.bond_energy = P.BondEnergy(self.bond_numbers, self.atom_numbers)
+
+    def simulation_before_calculate_force(self):
+        """Convert float coordinates to the unsigned-int representation used by the kernels."""
+ crd_to_uint_crd_cof = 0.5 * self.crd_to_uint_crd_cof
+ uint_crd = self.crd_to_uint_crd(crd_to_uint_crd_cof, self.crd)
+ return uint_crd
+
+    def simulation_calculate_energy(self, uint_crd, uint_dr_to_dr_cof):
+        """Compute the total harmonic bond energy from unsigned-int coordinates."""
+ bond_energy = self.bond_energy(uint_crd, uint_dr_to_dr_cof, self.bond_atom_a, self.bond_atom_b, self.bond_k,
+ self.bond_r0)
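+        # Reduce the per-bond energies to a single scalar (keep_dims=True).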
+ bond_energy_sum = P.ReduceSum(True)(bond_energy)
+
+ return bond_energy_sum
+
+ def construct(self):
+ """construct"""
+        uint_crd = self.simulation_before_calculate_force()
+        bond_energy_sum = self.simulation_calculate_energy(uint_crd, self.uint_dr_to_dr_cof)
+ return bond_energy_sum
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/test_case_bond.py b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/test_case_bond.py
new file mode 100644
index 0000000..1c0315e
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/mindsponge/test_polypeptide_bond/test_case_bond.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test case polypeptide bond"""
+
+import numpy as np
+import pytest
+from mindspore import context
+from simulation_poly_bond import Simulation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0, save_graphs=False)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_case_bond():
+    """System test: summed harmonic bond energy of the ALA system matches the reference value."""
+ args_opt = {'amber_parm': '/home/workspace/mindspore_dataset/polypeptide/ala.parm7', 'box': 'mdbox',
+ 'c': '/home/workspace/mindspore_dataset/polypeptide/ala.rst7', 'checkpoint': '',
+ 'device_id': 0, 'i': '/home/workspace/mindspore_dataset/polypeptide/nvt.in',
+ 'o': '', 'r': 'restrt', 'u': False, 'x': ''}
+ simulation = Simulation(args_opt)
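+    # A single construct() call returns the summed harmonic bond energy.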
+ sigma_of_bond_ene = simulation()
+
+ assert np.allclose(round(float(sigma_of_bond_ene.asnumpy()), 3), 0.037)
diff --git a/reproduce/AlphaFold2-Chinese/tests/st/runtest.sh b/reproduce/AlphaFold2-Chinese/tests/st/runtest.sh
new file mode 100644
index 0000000..6c2eb23
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/st/runtest.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+SCRIPT_BASEDIR=$(realpath "$(dirname "$0")")
+
+PROJECT_DIR=$(realpath "$SCRIPT_BASEDIR/../../")
+ST_PATH="$PROJECT_DIR/tests/st"
+
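+# The first argument selects a sub-suite (mindelec | mindsponge); with no argument the whole st tree runs.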
+if [ $# -gt 0 ]; then
+  if [ "$1" == "mindelec" ]; then
+ echo "Run st mindelec."
+ cd "$PROJECT_DIR" || exit
+ ST_PATH="$PROJECT_DIR/tests/st/mindelec/"
+ pytest "$ST_PATH"
+ echo "Test all mindelec use cases success."
+  elif [ "$1" == "mindsponge" ]; then
+ echo "Run st mindsponge."
+ cd "$PROJECT_DIR" || exit
+ ST_PATH="$PROJECT_DIR/tests/st/mindsponge/"
+ pytest "$ST_PATH"
+ echo "Test all mindsponge use cases success."
+ fi
+else
+ echo "Run all st."
+ cd "$PROJECT_DIR" || exit
+ ST_PATH="$PROJECT_DIR/tests/st/"
+ pytest "$ST_PATH"
+ echo "Test all use cases success."
+fi
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/__init__.py b/reproduce/AlphaFold2-Chinese/tests/ut/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/__init__.py b/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/__init__.py
new file mode 100644
index 0000000..b3a552b
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init"""
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/test_mindelec.py b/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/test_mindelec.py
new file mode 100644
index 0000000..8148033
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/mindelec/test_mindelec.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test mindelec."""
+
+
+def test_empty():
+ assert 1 < 2
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/__init__.py b/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/__init__.py
new file mode 100644
index 0000000..6228b71
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/test_mindsponge.py b/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/test_mindsponge.py
new file mode 100644
index 0000000..77080c1
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/mindsponge/test_mindsponge.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test mindsponge."""
+
+
+def test_empty():
+ assert 1 < 2
diff --git a/reproduce/AlphaFold2-Chinese/tests/ut/runtest.sh b/reproduce/AlphaFold2-Chinese/tests/ut/runtest.sh
new file mode 100644
index 0000000..9894c38
--- /dev/null
+++ b/reproduce/AlphaFold2-Chinese/tests/ut/runtest.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+SCRIPT_BASEDIR=$(realpath "$(dirname "$0")")
+PROJECT_DIR=$(realpath "$SCRIPT_BASEDIR/../../")
+
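+# The first argument selects a sub-suite (mindelec | mindsponge); MindElec tests need the package dir on PYTHONPATH.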
+if [ $# -gt 0 ]; then
+  if [ "$1" == "mindelec" ]; then
+ export PYTHONPATH=$PYTHONPATH:${PROJECT_DIR}/MindElec/
+ echo "export PYTHONPATH=$PYTHONPATH"
+ echo "Run ut mindelec."
+ cd "$PROJECT_DIR" || exit
+ UT_PATH="$PROJECT_DIR/tests/ut/mindelec/"
+ pytest "$UT_PATH"
+ echo "Test all mindelec use cases success."
+  elif [ "$1" == "mindsponge" ]; then
+ echo "Run ut mindsponge."
+ cd "$PROJECT_DIR" || exit
+ UT_PATH="$PROJECT_DIR/tests/ut/mindsponge/"
+ pytest "$UT_PATH"
+ echo "Test all mindsponge use cases success."
+ fi
+else
+ export PYTHONPATH=$PYTHONPATH:${PROJECT_DIR}/MindElec/
+ echo "export PYTHONPATH=$PYTHONPATH"
+ echo "Run all ut."
+ cd "$PROJECT_DIR" || exit
+ UT_PATH="$PROJECT_DIR/tests/ut/"
+ pytest "$UT_PATH"
+ echo "Test all use cases success."
+fi
\ No newline at end of file