@@ -0,0 +1,21 @@
/tmp/
*.coverage
/.idea
/.vscode
.vscode
__pycache__/
*.pyc
*.so
*.so.*
*.o
*.out
*.gch
build
*.egg-info
dist
version.py
local_script/
output
.ipynb_checkpoints
somas_meta
analyze_fail.dat
@@ -0,0 +1,479 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "AlphaFold2中文版蛋白质预测模型使用指南(基于Deepmind与Mindspore开源框架).ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G4yBrceuFbf3"
      },
      "source": [
        "<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"200\" align=\"right\" style=\"height:240px\">\n",
        "\n",
        "## AlphaFold2_CN: AlphaFold2 with MMseqs2\n",
        "\n",
        "A simple Chinese-language guide to protein structure prediction, based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignments are built with [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kOblAo-xetgx",
        "cellView": "form"
      },
      "source": [
        "#@title Enter the protein sequence (defaults to the test sequence used in the video)\n",
        "from google.colab import files\n",
        "import os.path\n",
        "import re\n",
        "import hashlib\n",
        "import random\n",
        "\n",
"def add_hash(x,y):\n", | |||
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n", | |||
"\n", | |||
"query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", | |||
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetro-oligomers). For example **PI...SK:PI...SK** for a homodimer\n", | |||
"\n", | |||
"# remove whitespaces\n", | |||
"query_sequence = \"\".join(query_sequence.split())\n", | |||
"\n", | |||
"jobname = 'test' #@param {type:\"string\"}\n", | |||
"# remove whitespaces\n", | |||
"basejobname = \"\".join(jobname.split())\n", | |||
"basejobname = re.sub(r'\\W+', '', basejobname)\n", | |||
"jobname = add_hash(basejobname, query_sequence)\n", | |||
"while os.path.isfile(f\"{jobname}.csv\"):\n", | |||
" jobname = add_hash(basejobname, ''.join(random.sample(query_sequence,len(query_sequence))))\n", | |||
"\n", | |||
"with open(f\"{jobname}.csv\", \"w\") as text_file:\n", | |||
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n", | |||
"\n", | |||
"queries_path=f\"{jobname}.csv\"\n", | |||
"\n", | |||
"# number of models to use\n", | |||
"use_amber = False #@param {type:\"boolean\"}\n", | |||
"template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n", | |||
"#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n", | |||
"\n", | |||
"if template_mode == \"pdb70\":\n", | |||
" use_templates = True\n", | |||
" custom_template_path = None\n", | |||
"elif template_mode == \"custom\":\n", | |||
" custom_template_path = f\"{jobname}_template\"\n", | |||
" os.mkdir(custom_template_path)\n", | |||
" uploaded = files.upload()\n", | |||
" use_templates = True\n", | |||
" for fn in uploaded.keys():\n", | |||
" os.rename(fn, f\"{jobname}_template/{fn}\")\n", | |||
"else:\n", | |||
" custom_template_path = None\n", | |||
" use_templates = False\n" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"#@markdown ### MSA选项(custom MSA upload, single sequence, pairing mode)\n", | |||
"msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n", | |||
"pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n", | |||
"#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = seperate MSA for each chain, \"paired\" - only use paired sequences.\n", | |||
"\n", | |||
"# decide which a3m to use\n", | |||
"if msa_mode.startswith(\"MMseqs2\"):\n", | |||
" a3m_file = f\"{jobname}.a3m\"\n", | |||
"elif msa_mode == \"custom\":\n", | |||
" a3m_file = f\"{jobname}.custom.a3m\"\n", | |||
" if not os.path.isfile(a3m_file):\n", | |||
" custom_msa_dict = files.upload()\n", | |||
" custom_msa = list(custom_msa_dict.keys())[0]\n", | |||
" header = 0\n", | |||
" import fileinput\n", | |||
" for line in fileinput.FileInput(custom_msa,inplace=1):\n", | |||
" if line.startswith(\">\"):\n", | |||
" header = header + 1\n", | |||
" if not line.rstrip():\n", | |||
" continue\n", | |||
" if line.startswith(\">\") == False and header == 1:\n", | |||
" query_sequence = line.rstrip()\n", | |||
" print(line, end='')\n", | |||
"\n", | |||
" os.rename(custom_msa, a3m_file)\n", | |||
" queries_path=a3m_file\n", | |||
" print(f\"moving {custom_msa} to {a3m_file}\")\n", | |||
"else:\n", | |||
" a3m_file = f\"{jobname}.single_sequence.a3m\"\n", | |||
" with open(a3m_file, \"w\") as text_file:\n", | |||
" text_file.write(\">1\\n%s\" % query_sequence)" | |||
], | |||
"metadata": { | |||
"cellView": "form", | |||
"id": "C2_sh2uAonJH" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"#@markdown ### 参数设置\n", | |||
"model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n", | |||
"#@markdown - \"auto\" = protein structure prediction using \"AlphaFold2-ptm\" and complex prediction \"AlphaFold-multimer-v2\". For complexes \"AlphaFold-multimer-v[1,2]\" and \"AlphaFold-ptm\" can be used.\n", | |||
"num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n", | |||
"save_to_google_drive = False #@param {type:\"boolean\"}\n", | |||
"\n", | |||
"#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n", | |||
"dpi = 200 #@param {type:\"integer\"}\n", | |||
"#@markdown - set dpi for image resolution\n", | |||
"\n", | |||
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n", | |||
"\n", | |||
"\n", | |||
"if save_to_google_drive:\n", | |||
" from pydrive.drive import GoogleDrive\n", | |||
" from pydrive.auth import GoogleAuth\n", | |||
" from google.colab import auth\n", | |||
" from oauth2client.client import GoogleCredentials\n", | |||
" auth.authenticate_user()\n", | |||
" gauth = GoogleAuth()\n", | |||
" gauth.credentials = GoogleCredentials.get_application_default()\n", | |||
" drive = GoogleDrive(gauth)\n", | |||
" print(\"You are logged into Google Drive and are good to go!\")" | |||
], | |||
"metadata": { | |||
"cellView": "form", | |||
"id": "ADDuaolKmjGW" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "iccGdbe_Pmt9", | |||
"pycharm": { | |||
"name": "#%%\n" | |||
}, | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title 环境安装\n", | |||
"%%bash -s $use_amber $use_templates\n", | |||
"\n", | |||
"set -e\n", | |||
"\n", | |||
"USE_AMBER=$1\n", | |||
"USE_TEMPLATES=$2\n", | |||
"\n", | |||
"if [ ! -f COLABFOLD_READY ]; then\n", | |||
" # install dependencies\n", | |||
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n", | |||
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n", | |||
" # high risk high gain\n", | |||
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n", | |||
" touch COLABFOLD_READY\n", | |||
"fi\n", | |||
"\n", | |||
"# setup conda\n", | |||
"if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n", | |||
" if [ ! -f CONDA_READY ]; then\n", | |||
" wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n", | |||
" bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n", | |||
" rm Miniconda3-latest-Linux-x86_64.sh\n", | |||
" touch CONDA_READY\n", | |||
" fi\n", | |||
"fi\n", | |||
"# setup template search\n", | |||
"if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n", | |||
" conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n", | |||
" touch HH_READY\n", | |||
"fi\n", | |||
"# setup openmm for amber refinement\n", | |||
"if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n", | |||
" conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n", | |||
" touch AMBER_READY\n", | |||
"fi" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "_sztQyz29DIC", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title 模型预测\n", | |||
"\n", | |||
"import sys\n", | |||
"\n", | |||
"from colabfold.download import download_alphafold_params, default_data_dir\n", | |||
"from colabfold.utils import setup_logging\n", | |||
"from colabfold.batch import get_queries, run, set_model_type\n", | |||
"K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n", | |||
"if \"1\" in K80_chk:\n", | |||
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n", | |||
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n", | |||
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n", | |||
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n", | |||
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n", | |||
"\n", | |||
"from colabfold.colabfold import plot_protein\n", | |||
"from pathlib import Path\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"\n", | |||
"\n", | |||
"# For some reason we need that to get pdbfixer to import\n", | |||
"if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n", | |||
" sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n", | |||
"\n", | |||
"def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n", | |||
" fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n", | |||
" plt.show()\n", | |||
" plt.close()\n", | |||
"\n", | |||
"result_dir=\".\"\n", | |||
"setup_logging(Path(\".\").joinpath(\"log.txt\"))\n", | |||
"queries, is_complex = get_queries(queries_path)\n", | |||
"model_type = set_model_type(is_complex, model_type)\n", | |||
"download_alphafold_params(model_type, Path(\".\"))\n", | |||
"run(\n", | |||
" queries=queries,\n", | |||
" result_dir=result_dir,\n", | |||
" use_templates=use_templates,\n", | |||
" custom_template_path=custom_template_path,\n", | |||
" use_amber=use_amber,\n", | |||
" msa_mode=msa_mode, \n", | |||
" model_type=model_type,\n", | |||
" num_models=5,\n", | |||
" num_recycles=num_recycles,\n", | |||
" model_order=[1, 2, 3, 4, 5],\n", | |||
" is_complex=is_complex,\n", | |||
" data_dir=Path(\".\"),\n", | |||
" keep_existing_results=False,\n", | |||
" recompile_padding=1.0,\n", | |||
" rank_by=\"auto\",\n", | |||
" pair_mode=pair_mode,\n", | |||
" stop_at_score=float(100),\n", | |||
" prediction_callback=prediction_callback,\n", | |||
" dpi=dpi\n", | |||
")" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "KK7X9T44pWb7", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title 展示3维结构 {run: \"auto\"}\n", | |||
"import py3Dmol\n", | |||
"import glob\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"from colabfold.colabfold import plot_plddt_legend\n", | |||
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n", | |||
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n", | |||
"show_sidechains = False #@param {type:\"boolean\"}\n", | |||
"show_mainchains = False #@param {type:\"boolean\"}\n", | |||
"\n", | |||
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n", | |||
"if use_amber:\n", | |||
" pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n", | |||
"else:\n", | |||
" pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n", | |||
"\n", | |||
"pdb_file = glob.glob(pdb_filename)\n", | |||
"\n", | |||
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n", | |||
" model_name = f\"rank_{rank_num}\"\n", | |||
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n", | |||
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n", | |||
"\n", | |||
" if color == \"lDDT\":\n", | |||
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n", | |||
" elif color == \"rainbow\":\n", | |||
" view.setStyle({'cartoon': {'color':'spectrum'}})\n", | |||
" elif color == \"chain\":\n", | |||
" chains = len(queries[0][1]) + 1 if is_complex else 1\n", | |||
" for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n", | |||
" [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n", | |||
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n", | |||
" if show_sidechains:\n", | |||
" BB = ['C','O','N']\n", | |||
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n", | |||
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", | |||
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", | |||
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n", | |||
" if show_mainchains:\n", | |||
" BB = ['C','O','N','CA']\n", | |||
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
"\n", | |||
" view.zoomTo()\n", | |||
" return view\n", | |||
"\n", | |||
"\n", | |||
"show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n", | |||
"if color == \"lDDT\":\n", | |||
" plot_plddt_legend().show() " | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "11l8k--10q0C", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title 图表展示 {run: \"auto\"}\n", | |||
"from IPython.display import display, HTML\n", | |||
"import base64\n", | |||
"from html import escape\n", | |||
"\n", | |||
"# see: https://stackoverflow.com/a/53688522\n", | |||
"def image_to_data_url(filename):\n", | |||
" ext = filename.split('.')[-1]\n", | |||
" prefix = f'data:image/{ext};base64,'\n", | |||
" with open(filename, 'rb') as f:\n", | |||
" img = f.read()\n", | |||
" return prefix + base64.b64encode(img).decode('utf-8')\n", | |||
"\n", | |||
"pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n", | |||
"cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n", | |||
"plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n", | |||
"display(HTML(f\"\"\"\n", | |||
"<style>\n", | |||
" img {{\n", | |||
" float:left;\n", | |||
" }}\n", | |||
" .full {{\n", | |||
" max-width:100%;\n", | |||
" }}\n", | |||
" .half {{\n", | |||
" max-width:50%;\n", | |||
" }}\n", | |||
" @media (max-width:640px) {{\n", | |||
" .half {{\n", | |||
" max-width:100%;\n", | |||
" }}\n", | |||
" }}\n", | |||
"</style>\n", | |||
"<div style=\"max-width:90%; padding:2em;\">\n", | |||
" <h1>Plots for {escape(jobname)}</h1>\n", | |||
" <img src=\"{pae}\" class=\"full\" />\n", | |||
" <img src=\"{cov}\" class=\"half\" />\n", | |||
" <img src=\"{plddt}\" class=\"half\" />\n", | |||
"</div>\n", | |||
"\"\"\"))\n" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "33g5IIegij5R", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title结果下载\n", | |||
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n", | |||
"\n", | |||
"if msa_mode == \"custom\":\n", | |||
" print(\"Don't forget to cite your custom MSA generation method.\")\n", | |||
"\n", | |||
"!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n", | |||
"files.download(f\"{jobname}.result.zip\")\n", | |||
"\n", | |||
"if save_to_google_drive == True and drive:\n", | |||
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n", | |||
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n", | |||
" uploaded.Upload()\n", | |||
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": { | |||
"id": "UGUBLzB3C6WN", | |||
"pycharm": { | |||
"name": "#%% md\n" | |||
} | |||
}, | |||
"source": [ | |||
"# 操作指南 <a name=\"Instructions\"></a>\n", | |||
"**Quick start**\n", | |||
"1. 把你要预测的氨基酸序列复制进输入框.\n", | |||
"2. 点击\"Runtime\" -> \"Run all\".\n", | |||
"3. 目前的模型预测包括5个模块,最终会生成一个3维结构图.\n", | |||
"\n", | |||
"**生成文件**\n", | |||
"\n", | |||
"1. PDB格式的模型结构文件.\n", | |||
"2. 模型质量图\n", | |||
"3. 模型MSA覆盖率.\n", | |||
"4. 其他.\n", | |||
"\n", | |||
"**Acknowledgments**\n", | |||
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n", | |||
"\n", | |||
"- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n", | |||
"\n", | |||
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support\n", | |||
"\n", | |||
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n", | |||
"\n", | |||
"- Do-Yoon Kim for creating the ColabFold logo.\n", | |||
"\n", | |||
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n" | |||
] | |||
} | |||
] | |||
} |
@@ -0,0 +1,478 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "AlphaFold2_CN.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G4yBrceuFbf3"
      },
      "source": [
        "<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"200\" align=\"right\" style=\"height:240px\">\n",
        "\n",
        "## ColabFold: AlphaFold2 using MMseqs2\n",
        "\n",
        "A simple Chinese-language guide to protein structure prediction, based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignments are built with [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kOblAo-xetgx",
        "cellView": "form"
      },
      "source": [
        "# Enter the protein sequence (defaults to the test sequence used in the video), then click `Runtime` -> `Run all` in the top menu\n",
        "from google.colab import files\n",
        "import os.path\n",
        "import re\n",
        "import hashlib\n",
        "import random\n",
        "\n",
        "def add_hash(x, y):\n",
        "    return x + \"_\" + hashlib.sha1(y.encode()).hexdigest()[:5]\n",
        "\n",
        "query_sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n",
        "#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
        "\n",
        "# remove whitespace\n",
        "query_sequence = \"\".join(query_sequence.split())\n",
        "\n",
        "jobname = 'test' #@param {type:\"string\"}\n",
        "# remove whitespace\n",
        "basejobname = \"\".join(jobname.split())\n",
        "basejobname = re.sub(r'\\W+', '', basejobname)\n",
        "jobname = add_hash(basejobname, query_sequence)\n",
        "while os.path.isfile(f\"{jobname}.csv\"):\n",
        "    jobname = add_hash(basejobname, ''.join(random.sample(query_sequence, len(query_sequence))))\n",
        "\n",
"with open(f\"{jobname}.csv\", \"w\") as text_file:\n", | |||
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n", | |||
"\n", | |||
"queries_path=f\"{jobname}.csv\"\n", | |||
"\n", | |||
"# number of models to use\n", | |||
"use_amber = False #@param {type:\"boolean\"}\n", | |||
"template_mode = \"none\" #@param [\"none\", \"pdb70\",\"custom\"]\n", | |||
"#@markdown - \"none\" = no template information is used, \"pdb70\" = detect templates in pdb70, \"custom\" - upload and search own templates (PDB or mmCIF format, see [notes below](#custom_templates))\n", | |||
"\n", | |||
"if template_mode == \"pdb70\":\n", | |||
" use_templates = True\n", | |||
" custom_template_path = None\n", | |||
"elif template_mode == \"custom\":\n", | |||
" custom_template_path = f\"{jobname}_template\"\n", | |||
" os.mkdir(custom_template_path)\n", | |||
" uploaded = files.upload()\n", | |||
" use_templates = True\n", | |||
" for fn in uploaded.keys():\n", | |||
" os.rename(fn, f\"{jobname}_template/{fn}\")\n", | |||
"else:\n", | |||
" custom_template_path = None\n", | |||
" use_templates = False\n" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n", | |||
"msa_mode = \"MMseqs2 (UniRef+Environmental)\" #@param [\"MMseqs2 (UniRef+Environmental)\", \"MMseqs2 (UniRef only)\",\"single_sequence\",\"custom\"]\n", | |||
"pair_mode = \"unpaired+paired\" #@param [\"unpaired+paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n", | |||
"#@markdown - \"unpaired+paired\" = pair sequences from same species + unpaired MSA, \"unpaired\" = seperate MSA for each chain, \"paired\" - only use paired sequences.\n", | |||
"\n", | |||
"# decide which a3m to use\n", | |||
"if msa_mode.startswith(\"MMseqs2\"):\n", | |||
" a3m_file = f\"{jobname}.a3m\"\n", | |||
"elif msa_mode == \"custom\":\n", | |||
" a3m_file = f\"{jobname}.custom.a3m\"\n", | |||
" if not os.path.isfile(a3m_file):\n", | |||
" custom_msa_dict = files.upload()\n", | |||
" custom_msa = list(custom_msa_dict.keys())[0]\n", | |||
" header = 0\n", | |||
" import fileinput\n", | |||
" for line in fileinput.FileInput(custom_msa,inplace=1):\n", | |||
" if line.startswith(\">\"):\n", | |||
" header = header + 1\n", | |||
" if not line.rstrip():\n", | |||
" continue\n", | |||
" if line.startswith(\">\") == False and header == 1:\n", | |||
" query_sequence = line.rstrip()\n", | |||
" print(line, end='')\n", | |||
"\n", | |||
" os.rename(custom_msa, a3m_file)\n", | |||
" queries_path=a3m_file\n", | |||
" print(f\"moving {custom_msa} to {a3m_file}\")\n", | |||
"else:\n", | |||
" a3m_file = f\"{jobname}.single_sequence.a3m\"\n", | |||
" with open(a3m_file, \"w\") as text_file:\n", | |||
" text_file.write(\">1\\n%s\" % query_sequence)" | |||
], | |||
"metadata": { | |||
"cellView": "form", | |||
"id": "C2_sh2uAonJH" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"#@markdown ### Advanced settings\n", | |||
"model_type = \"auto\" #@param [\"auto\", \"AlphaFold2-ptm\", \"AlphaFold2-multimer-v1\", \"AlphaFold2-multimer-v2\"]\n", | |||
"#@markdown - \"auto\" = protein structure prediction using \"AlphaFold2-ptm\" and complex prediction \"AlphaFold-multimer-v2\". For complexes \"AlphaFold-multimer-v[1,2]\" and \"AlphaFold-ptm\" can be used.\n", | |||
"num_recycles = 3 #@param [1,3,6,12,24,48] {type:\"raw\"}\n", | |||
"save_to_google_drive = False #@param {type:\"boolean\"}\n", | |||
"\n", | |||
"#@markdown - if the save_to_google_drive option was selected, the result zip will be uploaded to your Google Drive\n", | |||
"dpi = 200 #@param {type:\"integer\"}\n", | |||
"#@markdown - set dpi for image resolution\n", | |||
"\n", | |||
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form.\n", | |||
"\n", | |||
"\n", | |||
"if save_to_google_drive:\n", | |||
" from pydrive.drive import GoogleDrive\n", | |||
" from pydrive.auth import GoogleAuth\n", | |||
" from google.colab import auth\n", | |||
" from oauth2client.client import GoogleCredentials\n", | |||
" auth.authenticate_user()\n", | |||
" gauth = GoogleAuth()\n", | |||
" gauth.credentials = GoogleCredentials.get_application_default()\n", | |||
" drive = GoogleDrive(gauth)\n", | |||
" print(\"You are logged into Google Drive and are good to go!\")" | |||
], | |||
"metadata": { | |||
"cellView": "form", | |||
"id": "ADDuaolKmjGW" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "iccGdbe_Pmt9", | |||
"pycharm": { | |||
"name": "#%%\n" | |||
}, | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title Install dependencies\n", | |||
"%%bash -s $use_amber $use_templates\n", | |||
"\n", | |||
"set -e\n", | |||
"\n", | |||
"USE_AMBER=$1\n", | |||
"USE_TEMPLATES=$2\n", | |||
"\n", | |||
"if [ ! -f COLABFOLD_READY ]; then\n", | |||
" # install dependencies\n", | |||
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n", | |||
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold\"\n", | |||
" # high risk high gain\n", | |||
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n", | |||
" touch COLABFOLD_READY\n", | |||
"fi\n", | |||
"\n", | |||
"# setup conda\n", | |||
"if [ ${USE_AMBER} == \"True\" ] || [ ${USE_TEMPLATES} == \"True\" ]; then\n", | |||
" if [ ! -f CONDA_READY ]; then\n", | |||
" wget -qnc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n", | |||
" bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local 2>&1 1>/dev/null\n", | |||
" rm Miniconda3-latest-Linux-x86_64.sh\n", | |||
" touch CONDA_READY\n", | |||
" fi\n", | |||
"fi\n", | |||
"# setup template search\n", | |||
"if [ ${USE_TEMPLATES} == \"True\" ] && [ ! -f HH_READY ]; then\n", | |||
" conda install -y -q -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python=3.7 2>&1 1>/dev/null\n", | |||
" touch HH_READY\n", | |||
"fi\n", | |||
"# setup openmm for amber refinement\n", | |||
"if [ ${USE_AMBER} == \"True\" ] && [ ! -f AMBER_READY ]; then\n", | |||
" conda install -y -q -c conda-forge openmm=7.5.1 python=3.7 pdbfixer 2>&1 1>/dev/null\n", | |||
" touch AMBER_READY\n", | |||
"fi" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "_sztQyz29DIC", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title Run Prediction\n", | |||
"\n", | |||
"import sys\n", | |||
"\n", | |||
"from colabfold.download import download_alphafold_params, default_data_dir\n", | |||
"from colabfold.utils import setup_logging\n", | |||
"from colabfold.batch import get_queries, run, set_model_type\n", | |||
"K80_chk = !nvidia-smi | grep \"Tesla K80\" | wc -l\n", | |||
"if \"1\" in K80_chk:\n", | |||
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n", | |||
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n", | |||
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n", | |||
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n", | |||
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n", | |||
"\n", | |||
"from colabfold.colabfold import plot_protein\n", | |||
"from pathlib import Path\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"\n", | |||
"\n", | |||
"# For some reason we need that to get pdbfixer to import\n", | |||
"if use_amber and '/usr/local/lib/python3.7/site-packages/' not in sys.path:\n", | |||
" sys.path.insert(0, '/usr/local/lib/python3.7/site-packages/')\n", | |||
"\n", | |||
"def prediction_callback(unrelaxed_protein, length, prediction_result, input_features, type):\n", | |||
" fig = plot_protein(unrelaxed_protein, Ls=length, dpi=150)\n", | |||
" plt.show()\n", | |||
" plt.close()\n", | |||
"\n", | |||
"result_dir=\".\"\n", | |||
"setup_logging(Path(\".\").joinpath(\"log.txt\"))\n", | |||
"queries, is_complex = get_queries(queries_path)\n", | |||
"model_type = set_model_type(is_complex, model_type)\n", | |||
"download_alphafold_params(model_type, Path(\".\"))\n", | |||
"run(\n", | |||
" queries=queries,\n", | |||
" result_dir=result_dir,\n", | |||
" use_templates=use_templates,\n", | |||
" custom_template_path=custom_template_path,\n", | |||
" use_amber=use_amber,\n", | |||
" msa_mode=msa_mode, \n", | |||
" model_type=model_type,\n", | |||
" num_models=5,\n", | |||
" num_recycles=num_recycles,\n", | |||
" model_order=[1, 2, 3, 4, 5],\n", | |||
" is_complex=is_complex,\n", | |||
" data_dir=Path(\".\"),\n", | |||
" keep_existing_results=False,\n", | |||
" recompile_padding=1.0,\n", | |||
" rank_by=\"auto\",\n", | |||
" pair_mode=pair_mode,\n", | |||
" stop_at_score=float(100),\n", | |||
" prediction_callback=prediction_callback,\n", | |||
" dpi=dpi\n", | |||
")" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "KK7X9T44pWb7", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title Display 3D structure {run: \"auto\"}\n", | |||
"import py3Dmol\n", | |||
"import glob\n", | |||
"import matplotlib.pyplot as plt\n", | |||
"from colabfold.colabfold import plot_plddt_legend\n", | |||
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n", | |||
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n", | |||
"show_sidechains = False #@param {type:\"boolean\"}\n", | |||
"show_mainchains = False #@param {type:\"boolean\"}\n", | |||
"\n", | |||
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n", | |||
"if use_amber:\n", | |||
" pdb_filename = f\"{jobname}{jobname_prefix}_relaxed_rank_{rank_num}_model_*.pdb\"\n", | |||
"else:\n", | |||
" pdb_filename = f\"{jobname}{jobname_prefix}_unrelaxed_rank_{rank_num}_model_*.pdb\"\n", | |||
"\n", | |||
"pdb_file = glob.glob(pdb_filename)\n", | |||
"\n", | |||
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n", | |||
" model_name = f\"rank_{rank_num}\"\n", | |||
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n", | |||
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n", | |||
"\n", | |||
" if color == \"lDDT\":\n", | |||
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n", | |||
" elif color == \"rainbow\":\n", | |||
" view.setStyle({'cartoon': {'color':'spectrum'}})\n", | |||
" elif color == \"chain\":\n", | |||
" chains = len(queries[0][1]) + 1 if is_complex else 1\n", | |||
" for n,chain,color in zip(range(chains),list(\"ABCDEFGH\"),\n", | |||
" [\"lime\",\"cyan\",\"magenta\",\"yellow\",\"salmon\",\"white\",\"blue\",\"orange\"]):\n", | |||
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n", | |||
" if show_sidechains:\n", | |||
" BB = ['C','O','N']\n", | |||
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n", | |||
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", | |||
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", | |||
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n", | |||
" if show_mainchains:\n", | |||
" BB = ['C','O','N','CA']\n", | |||
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", | |||
"\n", | |||
" view.zoomTo()\n", | |||
" return view\n", | |||
"\n", | |||
"\n", | |||
"show_pdb(rank_num,show_sidechains, show_mainchains, color).show()\n", | |||
"if color == \"lDDT\":\n", | |||
" plot_plddt_legend().show() " | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "11l8k--10q0C", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title Plots {run: \"auto\"}\n", | |||
"from IPython.display import display, HTML\n", | |||
"import base64\n", | |||
"from html import escape\n", | |||
"\n", | |||
"# see: https://stackoverflow.com/a/53688522\n", | |||
"def image_to_data_url(filename):\n", | |||
" ext = filename.split('.')[-1]\n", | |||
" prefix = f'data:image/{ext};base64,'\n", | |||
" with open(filename, 'rb') as f:\n", | |||
" img = f.read()\n", | |||
" return prefix + base64.b64encode(img).decode('utf-8')\n", | |||
"\n", | |||
"pae = image_to_data_url(f\"{jobname}{jobname_prefix}_PAE.png\")\n", | |||
"cov = image_to_data_url(f\"{jobname}{jobname_prefix}_coverage.png\")\n", | |||
"plddt = image_to_data_url(f\"{jobname}{jobname_prefix}_plddt.png\")\n", | |||
"display(HTML(f\"\"\"\n", | |||
"<style>\n", | |||
" img {{\n", | |||
" float:left;\n", | |||
" }}\n", | |||
" .full {{\n", | |||
" max-width:100%;\n", | |||
" }}\n", | |||
" .half {{\n", | |||
" max-width:50%;\n", | |||
" }}\n", | |||
" @media (max-width:640px) {{\n", | |||
" .half {{\n", | |||
" max-width:100%;\n", | |||
" }}\n", | |||
" }}\n", | |||
"</style>\n", | |||
"<div style=\"max-width:90%; padding:2em;\">\n", | |||
" <h1>Plots for {escape(jobname)}</h1>\n", | |||
" <img src=\"{pae}\" class=\"full\" />\n", | |||
" <img src=\"{cov}\" class=\"half\" />\n", | |||
" <img src=\"{plddt}\" class=\"half\" />\n", | |||
"</div>\n", | |||
"\"\"\"))\n" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"metadata": { | |||
"id": "33g5IIegij5R", | |||
"cellView": "form" | |||
}, | |||
"source": [ | |||
"#@title Package and download results\n", | |||
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n", | |||
"\n", | |||
"if msa_mode == \"custom\":\n", | |||
" print(\"Don't forget to cite your custom MSA generation method.\")\n", | |||
"\n", | |||
"!zip -FSr $jobname\".result.zip\" config.json $jobname*\".json\" $jobname*\".a3m\" $jobname*\"relaxed_rank_\"*\".pdb\" \"cite.bibtex\" $jobname*\".png\"\n", | |||
"files.download(f\"{jobname}.result.zip\")\n", | |||
"\n", | |||
"if save_to_google_drive == True and drive:\n", | |||
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n", | |||
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n", | |||
" uploaded.Upload()\n", | |||
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")" | |||
], | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": { | |||
"id": "UGUBLzB3C6WN", | |||
"pycharm": { | |||
"name": "#%% md\n" | |||
} | |||
}, | |||
"source": [ | |||
"# 操作指南 <a name=\"Instructions\"></a>\n", | |||
"**Quick start**\n", | |||
"1. 把你要预测的氨基酸序列复制进输入框.\n", | |||
"2. 点击\"Runtime\" -> \"Run all\".\n", | |||
"3. 目前的模型预测包括5个模块,最终会生成一个3维结构图.\n", | |||
"\n", | |||
"**生成文件**\n", | |||
"\n", | |||
"1. PDB格式的模型结构文件.\n", | |||
"2. 模型质量图\n", | |||
"3. 模型MSA覆盖率.\n", | |||
"4. 其他.\n", | |||
"**Acknowledgments**\n", | |||
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software. \n", | |||
"\n", | |||
"- [Söding Lab](https://www.mpibpc.mpg.de/soeding) for providing the computational resources for the MMseqs2 server\n", | |||
"\n", | |||
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support\n", | |||
"\n", | |||
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n", | |||
"\n", | |||
"- Do-Yoon Kim for creating the ColabFold logo.\n", | |||
"\n", | |||
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n" | |||
] | |||
} | |||
] | |||
} |
@@ -0,0 +1,201 @@ | |||
Apache License | |||
Version 2.0, January 2004 | |||
http://www.apache.org/licenses/ | |||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | |||
1. Definitions. | |||
"License" shall mean the terms and conditions for use, reproduction, | |||
and distribution as defined by Sections 1 through 9 of this document. | |||
"Licensor" shall mean the copyright owner or entity authorized by | |||
the copyright owner that is granting the License. | |||
"Legal Entity" shall mean the union of the acting entity and all | |||
other entities that control, are controlled by, or are under common | |||
control with that entity. For the purposes of this definition, | |||
"control" means (i) the power, direct or indirect, to cause the | |||
direction or management of such entity, whether by contract or | |||
otherwise, or (ii) ownership of fifty percent (50%) or more of the | |||
outstanding shares, or (iii) beneficial ownership of such entity. | |||
"You" (or "Your") shall mean an individual or Legal Entity | |||
exercising permissions granted by this License. | |||
"Source" form shall mean the preferred form for making modifications, | |||
including but not limited to software source code, documentation | |||
source, and configuration files. | |||
"Object" form shall mean any form resulting from mechanical | |||
transformation or translation of a Source form, including but | |||
not limited to compiled object code, generated documentation, | |||
and conversions to other media types. | |||
"Work" shall mean the work of authorship, whether in Source or | |||
Object form, made available under the License, as indicated by a | |||
copyright notice that is included in or attached to the work | |||
(an example is provided in the Appendix below). | |||
"Derivative Works" shall mean any work, whether in Source or Object | |||
form, that is based on (or derived from) the Work and for which the | |||
editorial revisions, annotations, elaborations, or other modifications | |||
represent, as a whole, an original work of authorship. For the purposes | |||
of this License, Derivative Works shall not include works that remain | |||
separable from, or merely link (or bind by name) to the interfaces of, | |||
the Work and Derivative Works thereof. | |||
"Contribution" shall mean any work of authorship, including | |||
the original version of the Work and any modifications or additions | |||
to that Work or Derivative Works thereof, that is intentionally | |||
submitted to Licensor for inclusion in the Work by the copyright owner | |||
or by an individual or Legal Entity authorized to submit on behalf of | |||
the copyright owner. For the purposes of this definition, "submitted" | |||
means any form of electronic, verbal, or written communication sent | |||
to the Licensor or its representatives, including but not limited to | |||
communication on electronic mailing lists, source code control systems, | |||
and issue tracking systems that are managed by, or on behalf of, the | |||
Licensor for the purpose of discussing and improving the Work, but | |||
excluding communication that is conspicuously marked or otherwise | |||
designated in writing by the copyright owner as "Not a Contribution." | |||
"Contributor" shall mean Licensor and any individual or Legal Entity | |||
on behalf of whom a Contribution has been received by Licensor and | |||
subsequently incorporated within the Work. | |||
2. Grant of Copyright License. Subject to the terms and conditions of | |||
this License, each Contributor hereby grants to You a perpetual, | |||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable | |||
copyright license to reproduce, prepare Derivative Works of, | |||
publicly display, publicly perform, sublicense, and distribute the | |||
Work and such Derivative Works in Source or Object form. | |||
3. Grant of Patent License. Subject to the terms and conditions of | |||
this License, each Contributor hereby grants to You a perpetual, | |||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable | |||
(except as stated in this section) patent license to make, have made, | |||
use, offer to sell, sell, import, and otherwise transfer the Work, | |||
where such license applies only to those patent claims licensable | |||
by such Contributor that are necessarily infringed by their | |||
Contribution(s) alone or by combination of their Contribution(s) | |||
with the Work to which such Contribution(s) was submitted. If You | |||
institute patent litigation against any entity (including a | |||
cross-claim or counterclaim in a lawsuit) alleging that the Work | |||
or a Contribution incorporated within the Work constitutes direct | |||
or contributory patent infringement, then any patent licenses | |||
granted to You under this License for that Work shall terminate | |||
as of the date such litigation is filed. | |||
4. Redistribution. You may reproduce and distribute copies of the | |||
Work or Derivative Works thereof in any medium, with or without | |||
modifications, and in Source or Object form, provided that You | |||
meet the following conditions: | |||
(a) You must give any other recipients of the Work or | |||
Derivative Works a copy of this License; and | |||
(b) You must cause any modified files to carry prominent notices | |||
stating that You changed the files; and | |||
(c) You must retain, in the Source form of any Derivative Works | |||
that You distribute, all copyright, patent, trademark, and | |||
attribution notices from the Source form of the Work, | |||
excluding those notices that do not pertain to any part of | |||
the Derivative Works; and | |||
(d) If the Work includes a "NOTICE" text file as part of its | |||
distribution, then any Derivative Works that You distribute must | |||
include a readable copy of the attribution notices contained | |||
within such NOTICE file, excluding those notices that do not | |||
pertain to any part of the Derivative Works, in at least one | |||
of the following places: within a NOTICE text file distributed | |||
as part of the Derivative Works; within the Source form or | |||
documentation, if provided along with the Derivative Works; or, | |||
within a display generated by the Derivative Works, if and | |||
wherever such third-party notices normally appear. The contents | |||
of the NOTICE file are for informational purposes only and | |||
do not modify the License. You may add Your own attribution | |||
notices within Derivative Works that You distribute, alongside | |||
or as an addendum to the NOTICE text from the Work, provided | |||
that such additional attribution notices cannot be construed | |||
as modifying the License. | |||
You may add Your own copyright statement to Your modifications and | |||
may provide additional or different license terms and conditions | |||
for use, reproduction, or distribution of Your modifications, or | |||
for any such Derivative Works as a whole, provided Your use, | |||
reproduction, and distribution of the Work otherwise complies with | |||
the conditions stated in this License. | |||
5. Submission of Contributions. Unless You explicitly state otherwise, | |||
any Contribution intentionally submitted for inclusion in the Work | |||
by You to the Licensor shall be under the terms and conditions of | |||
this License, without any additional terms or conditions. | |||
Notwithstanding the above, nothing herein shall supersede or modify | |||
the terms of any separate license agreement you may have executed | |||
with Licensor regarding such Contributions. | |||
6. Trademarks. This License does not grant permission to use the trade | |||
names, trademarks, service marks, or product names of the Licensor, | |||
except as required for reasonable and customary use in describing the | |||
origin of the Work and reproducing the content of the NOTICE file. | |||
7. Disclaimer of Warranty. Unless required by applicable law or | |||
agreed to in writing, Licensor provides the Work (and each | |||
Contributor provides its Contributions) on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
implied, including, without limitation, any warranties or conditions | |||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | |||
PARTICULAR PURPOSE. You are solely responsible for determining the | |||
appropriateness of using or redistributing the Work and assume any | |||
risks associated with Your exercise of permissions under this License. | |||
8. Limitation of Liability. In no event and under no legal theory, | |||
whether in tort (including negligence), contract, or otherwise, | |||
unless required by applicable law (such as deliberate and grossly | |||
negligent acts) or agreed to in writing, shall any Contributor be | |||
liable to You for damages, including any direct, indirect, special, | |||
incidental, or consequential damages of any character arising as a | |||
result of this License or out of the use or inability to use the | |||
Work (including but not limited to damages for loss of goodwill, | |||
work stoppage, computer failure or malfunction, or any and all | |||
other commercial damages or losses), even if such Contributor | |||
has been advised of the possibility of such damages. | |||
9. Accepting Warranty or Additional Liability. While redistributing | |||
the Work or Derivative Works thereof, You may choose to offer, | |||
and charge a fee for, acceptance of support, warranty, indemnity, | |||
or other liability obligations and/or rights consistent with this | |||
License. However, in accepting such obligations, You may act only | |||
on Your own behalf and on Your sole responsibility, not on behalf | |||
of any other Contributor, and only if You agree to indemnify, | |||
defend, and hold each Contributor harmless for any liability | |||
incurred by, or claims asserted against, such Contributor by reason | |||
of your accepting any such warranty or additional liability. | |||
END OF TERMS AND CONDITIONS | |||
APPENDIX: How to apply the Apache License to your work. | |||
To apply the Apache License to your work, attach the following | |||
boilerplate notice, with the fields enclosed by brackets "[]" | |||
replaced with your own identifying information. (Don't include | |||
the brackets!) The text should be enclosed in the appropriate | |||
comment syntax for the file format. We also recommend that a | |||
file or class name and description of purpose be included on the | |||
same "printed page" as the copyright notice for easier | |||
identification within third-party archives. | |||
Copyright [yyyy] [name of copyright owner] | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. |
@@ -0,0 +1,221 @@ | |||
# AlphaFold2-Chinese | |||
A Chinese-language reproduction of the open-source AlphaFold2 model, based on DeepMind, ColabFold, and MindSpore:
* https://github.com/sokrypton/ColabFold | |||
* https://github.com/deepmind/alphafold | |||
* https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold | |||
<div align=center> | |||
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_64.gif" alt="T1079" width="400"/> | |||
</div> | |||
# Table of Contents
<!-- TOC -->
- [Table of Contents](#table-of-contents)
- [Model Description](#model-description)
- [Requirements](#requirements)
    - [Hardware and Framework](#hardware-and-framework)
    - [Installing MMseqs2](#installing-mmseqs2)
    - [Installing MindSpore Serving](#installing-mindspore-serving)
- [Data Preparation](#data-preparation)
    - [Databases for MSA](#databases-for-msa)
    - [Tools and Data for Templates](#tools-and-data-for-templates)
        - [Data](#data)
        - [Tools](#tools)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Inference Example](#inference-example)
    - [Inference Procedure](#inference-procedure)
        - [Inference Results](#inference-results)
- [Inference Performance](#inference-performance)
    - [TM-score Comparison](#tm-score-comparison)
    - [Prediction Comparison](#prediction-comparison)
- [References](#references)
<!-- /TOC --> | |||
## Model Description

Protein structure prediction tools compute the spatial structure of a protein efficiently in software. For a long time such computational methods lacked sufficient accuracy; the gap was only closed in 2020, when DeepMind's [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) [1][2] took first place in protein 3D structure prediction at CASP14. The model part of the inference tool open-sourced here is identical to AlphaFold2; in the multiple-sequence-alignment stage it uses [MMseqs2](https://www.biorxiv.org/content/10.1101/2021.08.15.456425v1.full.pdf) [3] for sequence search, which makes the end-to-end pipeline 2-3x faster than the original.
## Requirements
### Hardware and Framework

This code runs on Ascend processors with the [MindSpore](https://www.mindspore.cn/) AI framework. The current version must be [compiled](https://www.mindspore.cn/install/detail?path=install/r1.5/mindspore_ascend_install_source.md&highlight=%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91) from the latest upstream master code (commits after 2021-11-08). For setting up the MindSpore environment, see the [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html). After installation, run the following command to configure the environment variable:
``` shell | |||
export MS_DEV_ENABLE_FALLBACK=0 | |||
``` | |||
For the remaining Python dependencies, see [requirements.txt](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/mindsponge/fold/requirements.txt).
### Installing MMseqs2

MMseqs2 generates the multiple sequence alignments (MSAs). For installation and usage, see the [MMseqs2 User Guide](https://mmseqs.com/latest/userguide.pdf). After installation, run the following command to configure the environment variable:
``` shell | |||
export PATH=$(pwd)/mmseqs/bin/:$PATH | |||
``` | |||
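For orientation, the sketch below shows the overall shape of an MMseqs2 search; the paths and database names are placeholders, the exact commands and flags vary by MMseqs2 version, and `msa_search.sh` in this repository wraps the actual tuned pipeline:

``` shell
# Build a query database from the input FASTA (hypothetical paths).
mmseqs createdb query.fasta qdb
# Search against a preprocessed target database, e.g. uniref30_2103.
mmseqs search qdb uniref30_2103_db res tmp
# Convert the search result into an MSA database.
mmseqs result2msa qdb uniref30_2103_db res msa_db
```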
### Installing MindSpore Serving

Inference can also run in serving mode, which uses MindSpore Serving to provide an efficient inference service: when predicting multiple sequences it avoids repeated graph compilation and substantially improves throughput. For installing and configuring MindSpore Serving, see the [MindSpore Serving installation page](https://www.mindspore.cn/serving/docs/zh-CN/r1.5/serving_install.html).
## Data Preparation
### Databases for MSA

- [uniref30_2103](http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz): 375 GB (68 GB download)
- [colabfold_envdb_202108](http://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz): 949 GB (110 GB download)

For database preparation, see [colabfold](http://colabfold.mmseqs.com).
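Downloading and unpacking the archives follows the usual wget/tar pattern; a minimal sketch (the target directory is a placeholder, and the extracted data still needs the colabfold preprocessing referenced above):

``` shell
mkdir -p databases && cd databases
wget http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz
tar xzf uniref30_2103.tar.gz
wget http://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz
tar xzf colabfold_envdb_202108.tar.gz
```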
### Tools and Data for Templates
#### Data

- [pdb70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz): 56 GB (19 GB download)
- [mmcif database](https://ftp.rcsb.org/pub/pdb/data/structures/divided/mmCIF/): 206 GB (48 GB download)
- [obsolete_pdbs](http://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat): 140 KB

#### Tools
- [HHsearch](https://github.com/soedinglab/hh-suite) | |||
- [kalign](https://msa.sbc.su.se/downloads/kalign/current.tar.gz) | |||
## Script Description
### Scripts and Sample Code
```bash
├── mindscience
    ├── MindSPONGE
        ├── mindsponge
            ├── fold
                ├── README_CN.md                // Chinese documentation for fold
                ├── run.py                      // inference script
                ├── model.py                    // main model
                ├── requirements.txt            // dependencies
                ├── serving_server.py           // serving-mode server script
                ├── serving_cline.py            // serving-mode client script
                ├── fold_service
                    ├── servable_config.py      // serving-mode servable configuration
                ├── module
                    ├── basic_module.py         // basic building blocks
                    ├── evoformer_module.py     // Evoformer module
                    ├── structure_module.py     // structure module
                ├── data
                    ├── feature
                        ├── data_transforms.py      // MSA and template data processing
                        ├── feature_extraction.py   // MSA and template feature extraction
                    ├── tools
                        ├── data_process.py     // MSA and template search
                        ├── data_tools.py       // data processing utilities
                        ├── mmcif_parsing.py    // mmCIF parsing
                        ├── msa_search.sh       // shell script for the MMseqs2 MSA search
                        ├── parsers.py          // file parsers
                        ├── templates.py        // template search
                ├── config
                    ├── config.py               // parameter configuration
                    ├── global_config.py        // global parameter configuration
                ├── common
                    ├── generate_pdb.py         // PDB generation
                    ├── r3.py                   // 3D coordinate transformations
                    ├── residue_constants.py    // amino acid residue constants
                    ├── utils.py                // utility functions
```
### Inference Example
```bash
Usage: run.py [--seq_length PADDING_SEQUENCE_LENGTH]
              [--input_fasta_path INPUT_PATH][--msa_result_path MSA_RESULT_PATH]
              [--database_dir DATABASE_PATH][--database_envdb_dir DATABASE_ENVDB_PATH]
              [--hhsearch_binary_path HHSEARCH_PATH][--pdb70_database_path PDB70_PATH]
              [--template_mmcif_dir TEMPLATE_PATH][--max_template_date TEMPLATE_DATE]
              [--kalign_binary_path KALIGN_PATH][--obsolete_pdbs_path OBSOLETE_PATH]

Options:
--seq_length            sequence length after zero-padding; 256/512/1024/2048 are currently supported
--input_fasta_path      FASTA file with the protein sequence whose structure is to be predicted
--msa_result_path       path for saving the MSA results retrieved with MMseqs2
--database_dir          database used for the MSA search
--database_envdb_dir    extended database used for the MSA search
--hhsearch_binary_path  path to the hhsearch executable
--pdb70_database_path   path to the pdb70 database used by hhsearch
--template_mmcif_dir    path to the mmCIF structure templates
--max_template_date     latest allowed template release date
--kalign_binary_path    path to the kalign executable
--obsolete_pdbs_path    path to the mapping file for obsolete PDB IDs
```
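For example, a single prediction could be launched as follows (every path below is a placeholder for your local setup):

```bash
python run.py \
    --seq_length 512 \
    --input_fasta_path ./input/T1079.fasta \
    --msa_result_path ./msa_result \
    --database_dir ./databases/uniref30_2103_db \
    --database_envdb_dir ./databases/colabfold_envdb_202108_db \
    --hhsearch_binary_path /usr/local/bin/hhsearch \
    --pdb70_database_path ./databases/pdb70/pdb70 \
    --template_mmcif_dir ./databases/mmcif_files \
    --max_template_date 2021-11-01 \
    --kalign_binary_path /usr/local/bin/kalign \
    --obsolete_pdbs_path ./databases/obsolete.dat
```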
### Inference Procedure

Load the AlphaFold checkpoint (download it [here](https://download.mindspore.cn/model_zoo/research/hpc/molecular_dynamics/protein_fold_1.ckpt)) and choose a sequence-length configuration that fits your input; four standard configurations (256/512/1024/2048) are currently provided. The procedure is:
1. Configure the input parameters in `fold_service/config.py`; their meanings are described in [Inference Example](#inference-example).
2. With the parameters in place, start the server process with `serving_server.py`. On a successful start the log shows:
``` log | |||
Serving: Serving gRPC server start success, listening on 127.0.0.1:5500 | |||
Serving: Serving RESTful server start success, listening on 127.0.0.1:1500 | |||
``` | |||
3. Once the server process is up, run `serving_client.py` to perform inference; the first inference triggers graph compilation. A minimal client sketch follows.
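The client drives the server over gRPC via the MindSpore Serving Python API. In the sketch below, the servable name `fold_service` follows the directory above, while the method name `predict` and the input key are hypothetical and must match whatever `fold_service/servable_config.py` actually registers:

```python
from mindspore_serving.client import Client

# Connect to the gRPC endpoint reported in the server log.
client = Client("127.0.0.1:5500", "fold_service", "predict")

# One instance per query sequence; the input key is hypothetical.
instances = [{"fasta_path": "./input/T1079.fasta"}]
result = client.infer(instances)
print(result)
```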
#### Inference Results

Inference results are saved under `./result` as two files: the PDB file contains the predicted protein structure, and the timings file records the runtime breakdown and the confidence score.
```bash | |||
{"pre_process_time": 418.57, "model_time": 122.86, "pos_process_time": 0.14, "all_time ": 541.56, "confidence ": 94.61789646019058} | |||
``` | |||
## Inference Performance

| Parameter           | Fold (Ascend)               |
| ------------------- | --------------------------- |
| Model version       | AlphaFold                   |
| Resource            | Ascend 910                  |
| Upload date         | 2021-11-05                  |
| MindSpore version   | master                      |
| Dataset             | CASP14 T1079                |
| seq_length          | 505                         |
| confidence          | 94.62                       |
| TM-score            | 98.01%                      |
| Runtime             | 541.56 s                    |
### TM-score Comparison
- Comparison with AlphaFold2 on 34 CASP14 targets:
<div align=center> | |||
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/all_experiment_data.jpg" alt="all_data" width="600"/> | |||
</div> | |||
### Prediction Comparison
- T1079 (length 505):
<div align=center> | |||
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_64.gif" alt="T1079" width="400"/> | |||
</div> | |||
- T1044 (length 2180):
<div align=center> | |||
<img src="https://github.com/Tssck/AlphaFold2-Chinese/blob/main/docs/seq_21.jpg" alt="T1044" width="400"/> | |||
</div> | |||
## References
[1] Jumper J, Evans R, Pritzel A, et al. Applying and improving AlphaFold at CASP14[J]. Proteins: Structure, Function, and Bioinformatics, 2021. | |||
[2] Jumper J, Evans R, Pritzel A, et al. Highly accurate protein structure prediction with AlphaFold[J]. Nature, 2021, 596(7873): 583-589. | |||
[3] Mirdita M, Ovchinnikov S, Steinegger M. ColabFold-Making protein folding accessible to all[J]. BioRxiv, 2021. |
@@ -0,0 +1,118 @@ | |||
"""generate pdb file""" | |||
import dataclasses | |||
import numpy as np | |||
from common import residue_constants
@dataclasses.dataclass(frozen=True) | |||
class Protein: | |||
"""Protein structure representation.""" | |||
# Cartesian coordinates of atoms in angstroms. The atom types correspond to | |||
    # residue_constants.atom_types, i.e. the first three are N, CA, C.
atom_positions: np.ndarray # [num_res, num_atom_type, 3] | |||
# Amino-acid type for each residue represented as an integer between 0 and | |||
# 20, where 20 is 'X'. | |||
aatype: np.ndarray # [num_res] | |||
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom | |||
# is present and 0.0 if not. This should be used for loss masking. | |||
atom_mask: np.ndarray # [num_res, num_atom_type] | |||
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed. | |||
residue_index: np.ndarray # [num_res] | |||
# B-factors, or temperature factors, of each residue (in sq. angstroms units), | |||
# representing the displacement of the residue from its ground truth mean | |||
# value. | |||
b_factors: np.ndarray # [num_res, num_atom_type] | |||
def from_prediction(final_atom_mask, aatype, final_atom_positions, residue_index): | |||
"""Assembles a protein from a prediction. | |||
Args: | |||
final_atom_mask: atom mask info from structure module. | |||
aatype: amino acid type info. | |||
final_atom_positions: final atom positions from structure module | |||
residue_index: from processed_features | |||
Returns: | |||
A protein instance. | |||
""" | |||
dist_per_residue = np.zeros_like(final_atom_mask) | |||
return Protein( | |||
aatype=aatype, | |||
atom_positions=final_atom_positions, | |||
atom_mask=final_atom_mask, | |||
residue_index=residue_index + 1, | |||
b_factors=dist_per_residue) | |||
def to_pdb(prot: Protein): | |||
"""Converts a `Protein` instance to a PDB string. | |||
Args: | |||
prot: The protein to convert to PDB. | |||
Returns: | |||
PDB string. | |||
""" | |||
restypes = residue_constants.restypes + ['X'] | |||
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') | |||
atom_types = residue_constants.atom_types | |||
pdb_lines = [] | |||
atom_mask = prot.atom_mask | |||
aatype = prot.aatype | |||
atom_positions = prot.atom_positions | |||
residue_index = prot.residue_index.astype(np.int32) | |||
b_factors = prot.b_factors | |||
if (aatype > residue_constants.restype_num).any(): | |||
raise ValueError('Invalid aatypes.') | |||
pdb_lines.append('MODEL 1') | |||
atom_index = 1 | |||
chain_id = 'A' | |||
# Add all atom sites. | |||
for i in range(aatype.shape[0]): | |||
res_name_3 = res_1to3(aatype[i]) | |||
for atom_name, pos, mask, b_factor in zip( | |||
atom_types, atom_positions[i], atom_mask[i], b_factors[i]): | |||
if mask < 0.5: | |||
continue | |||
record_type = 'ATOM' | |||
name = atom_name if len(atom_name) == 4 else f' {atom_name}' | |||
alt_loc = '' | |||
insertion_code = '' | |||
occupancy = 1.00 | |||
element = atom_name[0] # Protein supports only C, N, O, S, this works. | |||
charge = '' | |||
# PDB is a columnar format, every space matters here! | |||
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' | |||
f'{res_name_3:>3} {chain_id:>1}' | |||
f'{residue_index[i]:>4}{insertion_code:>1} ' | |||
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' | |||
f'{occupancy:>6.2f}{b_factor:>6.2f} ' | |||
f'{element:>2}{charge:>2}') | |||
pdb_lines.append(atom_line) | |||
atom_index += 1 | |||
# Close the chain. | |||
chain_end = 'TER' | |||
chain_termination_line = ( | |||
f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' | |||
f'{chain_id:>1}{residue_index[-1]:>4}') | |||
pdb_lines.append(chain_termination_line) | |||
pdb_lines.append('ENDMDL') | |||
pdb_lines.append('END') | |||
pdb_lines.append('') | |||
return '\n'.join(pdb_lines) |
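# A minimal usage sketch, assuming the four arrays come from the model's
# structure-module outputs (names follow from_prediction's signature):
#
#   prot = from_prediction(final_atom_mask, aatype, final_atom_positions, residue_index)
#   with open('./result/prediction.pdb', 'w') as f:
#       f.write(to_pdb(prot))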
@@ -0,0 +1,104 @@ | |||
"""Transformations for 3D coordinates.""" | |||
import mindspore.numpy as mnp | |||
def vecs_sub(v1, v2): | |||
"""Computes v1 - v2.""" | |||
return v1 - v2 | |||
def vecs_robust_norm(v, epsilon=1e-8): | |||
"""Computes norm of vectors 'v'.""" | |||
return mnp.sqrt(mnp.sum(mnp.square(v), axis=-1) + epsilon) | |||
def vecs_robust_normalize(v, epsilon=1e-8): | |||
"""Normalizes vectors 'v'.""" | |||
norms = vecs_robust_norm(v, epsilon) | |||
return v / norms[..., None] | |||
def vecs_dot_vecs(v1, v2): | |||
"""Dot product of vectors 'v1' and 'v2'.""" | |||
return mnp.sum(v1 * v2, axis=-1) | |||
def vecs_cross_vecs(v1, v2): | |||
"""Cross product of vectors 'v1' and 'v2'.""" | |||
return mnp.concatenate(((v1[..., 1] * v2[..., 2] - v1[..., 2] * v2[..., 1])[..., None], | |||
(v1[..., 2] * v2[..., 0] - v1[..., 0] * v2[..., 2])[..., None], | |||
(v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0])[..., None]), axis=-1) | |||
def rots_from_two_vecs(e0_unnormalized, e1_unnormalized): | |||
"""Create rotation matrices from unnormalized vectors for the x and y-axes.""" | |||
# Normalize the unit vector for the x-axis, e0. | |||
e0 = vecs_robust_normalize(e0_unnormalized) | |||
# make e1 perpendicular to e0. | |||
c = vecs_dot_vecs(e1_unnormalized, e0) | |||
e1 = e1_unnormalized - c[..., None] * e0 | |||
e1 = vecs_robust_normalize(e1) | |||
# Compute e2 as cross product of e0 and e1. | |||
e2 = vecs_cross_vecs(e0, e1) | |||
rots = mnp.concatenate( | |||
(mnp.concatenate([e0[..., 0][None, ...], e1[..., 0][None, ...], e2[..., 0][None, ...]], axis=0)[None, ...], | |||
mnp.concatenate([e0[..., 1][None, ...], e1[..., 1][None, ...], e2[..., 1][None, ...]], axis=0)[None, ...], | |||
mnp.concatenate([e0[..., 2][None, ...], e1[..., 2][None, ...], e2[..., 2][None, ...]], axis=0)[None, ...]), | |||
axis=0) | |||
return rots | |||
def rigids_from_3_points( | |||
point_on_neg_x_axis, # shape (...) | |||
origin, # shape (...) | |||
point_on_xy_plane, # shape (...) | |||
): # shape (...) | |||
"""Create Rigids from 3 points. """ | |||
m = rots_from_two_vecs( | |||
e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis), | |||
e1_unnormalized=vecs_sub(point_on_xy_plane, origin)) | |||
return m, origin | |||
def invert_rots(m): | |||
"""Computes inverse of rotations 'm'.""" | |||
return mnp.transpose(m, (1, 0, 2, 3, 4)) | |||
def rots_mul_vecs(m, v): | |||
"""Apply rotations 'm' to vectors 'v'.""" | |||
return mnp.concatenate(((m[0][0] * v[..., 0] + m[0][1] * v[..., 1] + m[0][2] * v[..., 2])[..., None], | |||
(m[1][0] * v[..., 0] + m[1][1] * v[..., 1] + m[1][2] * v[..., 2])[..., None], | |||
(m[2][0] * v[..., 0] + m[2][1] * v[..., 1] + m[2][2] * v[..., 2])[..., None]), axis=-1) | |||
def invert_rigids(rot, trans):
    """Computes the group inverse of the rigid transformation (rot, trans)."""
    inv_rots = invert_rots(rot)
    t = rots_mul_vecs(inv_rots, trans)
    inv_trans = -t
    return inv_rots, inv_trans
def vecs_add(v1, v2): | |||
"""Add two vectors 'v1' and 'v2'.""" | |||
return v1 + v2 | |||
def rigids_mul_vecs(rot, trans, v):
    """Apply the rigid transform (rot, trans) to points 'v'."""
    return vecs_add(rots_mul_vecs(rot, v), trans)
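# A minimal usage sketch, assuming n, ca and c are mindspore.numpy arrays of
# backbone atom coordinates with shape [..., 3]:
#
#   rot, trans = rigids_from_3_points(point_on_neg_x_axis=n,
#                                     origin=ca,
#                                     point_on_xy_plane=c)
#   # Map a global point v into the local frame via the inverse transform.
#   inv_rot, inv_trans = invert_rigids(rot, trans)
#   v_local = rigids_mul_vecs(inv_rot, inv_trans, v)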
@@ -0,0 +1,842 @@ | |||
"""residue_constants.""" | |||
import collections | |||
import functools | |||
from typing import Mapping, List, Tuple | |||
import numpy as np | |||
from mindspore.common.tensor import Tensor | |||
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) | |||
QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], | |||
[0, -1, 0, 0], | |||
[0, 0, -1, 0], | |||
[0, 0, 0, -1]] | |||
QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], | |||
[1, 0, 0, 0], | |||
[0, 0, 0, 1], | |||
[0, 0, -1, 0]] | |||
QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], | |||
[0, 0, 0, -1], | |||
[1, 0, 0, 0], | |||
[0, 1, 0, 0]] | |||
QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], | |||
[0, 0, 1, 0], | |||
[0, -1, 0, 0], | |||
[1, 0, 0, 0]] | |||
QUAT_MULTIPLY_BY_VEC = Tensor(QUAT_MULTIPLY[:, 1:, :]) | |||
# Distance from one CA to next CA [trans configuration: omega = 180]. | |||
ca_ca = 3.80209737096 | |||
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in | |||
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have | |||
# chi angles so their chi angle lists are empty. | |||
chi_angles_atoms = { | |||
'ALA': [], | |||
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it. | |||
'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], | |||
['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], | |||
'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], | |||
'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], | |||
'CYS': [['N', 'CA', 'CB', 'SG']], | |||
'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], | |||
['CB', 'CG', 'CD', 'OE1']], | |||
'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], | |||
['CB', 'CG', 'CD', 'OE1']], | |||
'GLY': [], | |||
'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], | |||
'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], | |||
'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], | |||
'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], | |||
['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], | |||
'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], | |||
['CB', 'CG', 'SD', 'CE']], | |||
'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], | |||
'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], | |||
'SER': [['N', 'CA', 'CB', 'OG']], | |||
'THR': [['N', 'CA', 'CB', 'OG1']], | |||
'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], | |||
'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], | |||
'VAL': [['N', 'CA', 'CB', 'CG1']], | |||
} | |||
# If chi angles given in fixed-length array, this matrix determines how to mask | |||
# them for each AA type. The order is as per restype_order (see below). | |||
chi_angles_mask = [ | |||
[0.0, 0.0, 0.0, 0.0], # ALA | |||
[1.0, 1.0, 1.0, 1.0], # ARG | |||
[1.0, 1.0, 0.0, 0.0], # ASN | |||
[1.0, 1.0, 0.0, 0.0], # ASP | |||
[1.0, 0.0, 0.0, 0.0], # CYS | |||
[1.0, 1.0, 1.0, 0.0], # GLN | |||
[1.0, 1.0, 1.0, 0.0], # GLU | |||
[0.0, 0.0, 0.0, 0.0], # GLY | |||
[1.0, 1.0, 0.0, 0.0], # HIS | |||
[1.0, 1.0, 0.0, 0.0], # ILE | |||
[1.0, 1.0, 0.0, 0.0], # LEU | |||
[1.0, 1.0, 1.0, 1.0], # LYS | |||
[1.0, 1.0, 1.0, 0.0], # MET | |||
[1.0, 1.0, 0.0, 0.0], # PHE | |||
[1.0, 1.0, 0.0, 0.0], # PRO | |||
[1.0, 0.0, 0.0, 0.0], # SER | |||
[1.0, 0.0, 0.0, 0.0], # THR | |||
[1.0, 1.0, 0.0, 0.0], # TRP | |||
[1.0, 1.0, 0.0, 0.0], # TYR | |||
[1.0, 0.0, 0.0, 0.0], # VAL | |||
] | |||
# The following chi angles are pi periodic: they can be rotated by a multiple | |||
# of pi without affecting the structure. | |||
chi_pi_periodic = [ | |||
[0.0, 0.0, 0.0, 0.0], # ALA | |||
[0.0, 0.0, 0.0, 0.0], # ARG | |||
[0.0, 0.0, 0.0, 0.0], # ASN | |||
[0.0, 1.0, 0.0, 0.0], # ASP | |||
[0.0, 0.0, 0.0, 0.0], # CYS | |||
[0.0, 0.0, 0.0, 0.0], # GLN | |||
[0.0, 0.0, 1.0, 0.0], # GLU | |||
[0.0, 0.0, 0.0, 0.0], # GLY | |||
[0.0, 0.0, 0.0, 0.0], # HIS | |||
[0.0, 0.0, 0.0, 0.0], # ILE | |||
[0.0, 0.0, 0.0, 0.0], # LEU | |||
[0.0, 0.0, 0.0, 0.0], # LYS | |||
[0.0, 0.0, 0.0, 0.0], # MET | |||
[0.0, 1.0, 0.0, 0.0], # PHE | |||
[0.0, 0.0, 0.0, 0.0], # PRO | |||
[0.0, 0.0, 0.0, 0.0], # SER | |||
[0.0, 0.0, 0.0, 0.0], # THR | |||
[0.0, 0.0, 0.0, 0.0], # TRP | |||
[0.0, 1.0, 0.0, 0.0], # TYR | |||
[0.0, 0.0, 0.0, 0.0], # VAL | |||
[0.0, 0.0, 0.0, 0.0], # UNK | |||
] | |||
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, | |||
# psi and chi angles: | |||
# 0: 'backbone group', | |||
# 1: 'pre-omega-group', (empty) | |||
# 2: 'phi-group', (currently empty, because it defines only hydrogens) | |||
# 3: 'psi-group', | |||
# 4,5,6,7: 'chi1,2,3,4-group' | |||
# The atom positions are relative to the axis-end-atom of the corresponding | |||
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis | |||
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). | |||
# format: [atomname, group_idx, rel_position] | |||
rigid_group_atom_positions = { | |||
'ALA': [ | |||
['N', 0, (-0.525, 1.363, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, -0.000, -0.000)], | |||
['CB', 0, (-0.529, -0.774, -1.205)], | |||
['O', 3, (0.627, 1.062, 0.000)], | |||
], | |||
'ARG': [ | |||
['N', 0, (-0.524, 1.362, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, -0.000, -0.000)], | |||
['CB', 0, (-0.524, -0.778, -1.209)], | |||
['O', 3, (0.626, 1.062, 0.000)], | |||
['CG', 4, (0.616, 1.390, -0.000)], | |||
['CD', 5, (0.564, 1.414, 0.000)], | |||
['NE', 6, (0.539, 1.357, -0.000)], | |||
['NH1', 7, (0.206, 2.301, 0.000)], | |||
['NH2', 7, (2.078, 0.978, -0.000)], | |||
['CZ', 7, (0.758, 1.093, -0.000)], | |||
], | |||
'ASN': [ | |||
['N', 0, (-0.536, 1.357, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, -0.000, -0.000)], | |||
['CB', 0, (-0.531, -0.787, -1.200)], | |||
['O', 3, (0.625, 1.062, 0.000)], | |||
['CG', 4, (0.584, 1.399, 0.000)], | |||
['ND2', 5, (0.593, -1.188, 0.001)], | |||
['OD1', 5, (0.633, 1.059, 0.000)], | |||
], | |||
'ASP': [ | |||
['N', 0, (-0.525, 1.362, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.527, 0.000, -0.000)], | |||
['CB', 0, (-0.526, -0.778, -1.208)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
['CG', 4, (0.593, 1.398, -0.000)], | |||
['OD1', 5, (0.610, 1.091, 0.000)], | |||
['OD2', 5, (0.592, -1.101, -0.003)], | |||
], | |||
'CYS': [ | |||
['N', 0, (-0.522, 1.362, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.524, 0.000, 0.000)], | |||
['CB', 0, (-0.519, -0.773, -1.212)], | |||
['O', 3, (0.625, 1.062, -0.000)], | |||
['SG', 4, (0.728, 1.653, 0.000)], | |||
], | |||
'GLN': [ | |||
['N', 0, (-0.526, 1.361, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, 0.000, 0.000)], | |||
['CB', 0, (-0.525, -0.779, -1.207)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
['CG', 4, (0.615, 1.393, 0.000)], | |||
['CD', 5, (0.587, 1.399, -0.000)], | |||
['NE2', 6, (0.593, -1.189, -0.001)], | |||
['OE1', 6, (0.634, 1.060, 0.000)], | |||
], | |||
'GLU': [ | |||
['N', 0, (-0.528, 1.361, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, -0.000, -0.000)], | |||
['CB', 0, (-0.526, -0.781, -1.207)], | |||
['O', 3, (0.626, 1.062, 0.000)], | |||
['CG', 4, (0.615, 1.392, 0.000)], | |||
['CD', 5, (0.600, 1.397, 0.000)], | |||
['OE1', 6, (0.607, 1.095, -0.000)], | |||
['OE2', 6, (0.589, -1.104, -0.001)], | |||
], | |||
'GLY': [ | |||
['N', 0, (-0.572, 1.337, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.517, -0.000, -0.000)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
], | |||
'HIS': [ | |||
['N', 0, (-0.527, 1.360, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, 0.000, 0.000)], | |||
['CB', 0, (-0.525, -0.778, -1.208)], | |||
['O', 3, (0.625, 1.063, 0.000)], | |||
['CG', 4, (0.600, 1.370, -0.000)], | |||
['CD2', 5, (0.889, -1.021, 0.003)], | |||
['ND1', 5, (0.744, 1.160, -0.000)], | |||
['CE1', 5, (2.030, 0.851, 0.002)], | |||
['NE2', 5, (2.145, -0.466, 0.004)], | |||
], | |||
'ILE': [ | |||
['N', 0, (-0.493, 1.373, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.527, -0.000, -0.000)], | |||
['CB', 0, (-0.536, -0.793, -1.213)], | |||
['O', 3, (0.627, 1.062, -0.000)], | |||
['CG1', 4, (0.534, 1.437, -0.000)], | |||
['CG2', 4, (0.540, -0.785, -1.199)], | |||
['CD1', 5, (0.619, 1.391, 0.000)], | |||
], | |||
'LEU': [ | |||
['N', 0, (-0.520, 1.363, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, -0.000, -0.000)], | |||
['CB', 0, (-0.522, -0.773, -1.214)], | |||
['O', 3, (0.625, 1.063, -0.000)], | |||
['CG', 4, (0.678, 1.371, 0.000)], | |||
['CD1', 5, (0.530, 1.430, -0.000)], | |||
['CD2', 5, (0.535, -0.774, 1.200)], | |||
], | |||
'LYS': [ | |||
['N', 0, (-0.526, 1.362, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, 0.000, 0.000)], | |||
['CB', 0, (-0.524, -0.778, -1.208)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
['CG', 4, (0.619, 1.390, 0.000)], | |||
['CD', 5, (0.559, 1.417, 0.000)], | |||
['CE', 6, (0.560, 1.416, 0.000)], | |||
['NZ', 7, (0.554, 1.387, 0.000)], | |||
], | |||
'MET': [ | |||
['N', 0, (-0.521, 1.364, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, 0.000, 0.000)], | |||
['CB', 0, (-0.523, -0.776, -1.210)], | |||
['O', 3, (0.625, 1.062, -0.000)], | |||
['CG', 4, (0.613, 1.391, -0.000)], | |||
['SD', 5, (0.703, 1.695, 0.000)], | |||
['CE', 6, (0.320, 1.786, -0.000)], | |||
], | |||
'PHE': [ | |||
['N', 0, (-0.518, 1.363, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.524, 0.000, -0.000)], | |||
['CB', 0, (-0.525, -0.776, -1.212)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
['CG', 4, (0.607, 1.377, 0.000)], | |||
['CD1', 5, (0.709, 1.195, -0.000)], | |||
['CD2', 5, (0.706, -1.196, 0.000)], | |||
['CE1', 5, (2.102, 1.198, -0.000)], | |||
['CE2', 5, (2.098, -1.201, -0.000)], | |||
['CZ', 5, (2.794, -0.003, -0.001)], | |||
], | |||
'PRO': [ | |||
['N', 0, (-0.566, 1.351, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.527, -0.000, 0.000)], | |||
['CB', 0, (-0.546, -0.611, -1.293)], | |||
['O', 3, (0.621, 1.066, 0.000)], | |||
['CG', 4, (0.382, 1.445, 0.0)], | |||
# ['CD', 5, (0.427, 1.440, 0.0)], | |||
['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger | |||
], | |||
'SER': [ | |||
['N', 0, (-0.529, 1.360, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, -0.000, -0.000)], | |||
['CB', 0, (-0.518, -0.777, -1.211)], | |||
['O', 3, (0.626, 1.062, -0.000)], | |||
['OG', 4, (0.503, 1.325, 0.000)], | |||
], | |||
'THR': [ | |||
['N', 0, (-0.517, 1.364, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.526, 0.000, -0.000)], | |||
['CB', 0, (-0.516, -0.793, -1.215)], | |||
['O', 3, (0.626, 1.062, 0.000)], | |||
['CG2', 4, (0.550, -0.718, -1.228)], | |||
['OG1', 4, (0.472, 1.353, 0.000)], | |||
], | |||
'TRP': [ | |||
['N', 0, (-0.521, 1.363, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.525, -0.000, 0.000)], | |||
['CB', 0, (-0.523, -0.776, -1.212)], | |||
['O', 3, (0.627, 1.062, 0.000)], | |||
['CG', 4, (0.609, 1.370, -0.000)], | |||
['CD1', 5, (0.824, 1.091, 0.000)], | |||
['CD2', 5, (0.854, -1.148, -0.005)], | |||
['CE2', 5, (2.186, -0.678, -0.007)], | |||
['CE3', 5, (0.622, -2.530, -0.007)], | |||
['NE1', 5, (2.140, 0.690, -0.004)], | |||
['CH2', 5, (3.028, -2.890, -0.013)], | |||
['CZ2', 5, (3.283, -1.543, -0.011)], | |||
['CZ3', 5, (1.715, -3.389, -0.011)], | |||
], | |||
'TYR': [ | |||
['N', 0, (-0.522, 1.362, 0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.524, -0.000, -0.000)], | |||
['CB', 0, (-0.522, -0.776, -1.213)], | |||
['O', 3, (0.627, 1.062, -0.000)], | |||
['CG', 4, (0.607, 1.382, -0.000)], | |||
['CD1', 5, (0.716, 1.195, -0.000)], | |||
['CD2', 5, (0.713, -1.194, -0.001)], | |||
['CE1', 5, (2.107, 1.200, -0.002)], | |||
['CE2', 5, (2.104, -1.201, -0.003)], | |||
['OH', 5, (4.168, -0.002, -0.005)], | |||
['CZ', 5, (2.791, -0.001, -0.003)], | |||
], | |||
'VAL': [ | |||
['N', 0, (-0.494, 1.373, -0.000)], | |||
['CA', 0, (0.000, 0.000, 0.000)], | |||
['C', 0, (1.527, -0.000, -0.000)], | |||
['CB', 0, (-0.533, -0.795, -1.213)], | |||
['O', 3, (0.627, 1.062, -0.000)], | |||
['CG1', 4, (0.540, 1.429, -0.000)], | |||
['CG2', 4, (0.533, -0.776, 1.203)], | |||
], | |||
} | |||
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. | |||
residue_atoms = { | |||
'ALA': ['C', 'CA', 'CB', 'N', 'O'], | |||
'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], | |||
'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], | |||
'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], | |||
'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], | |||
'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], | |||
'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], | |||
'GLY': ['C', 'CA', 'N', 'O'], | |||
'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], | |||
'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], | |||
'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], | |||
'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], | |||
'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], | |||
'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], | |||
'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], | |||
'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], | |||
'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], | |||
'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', | |||
'CH2', 'N', 'NE1', 'O'], | |||
'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', | |||
'OH'], | |||
'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] | |||
} | |||
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) | |||
van_der_waals_radius = { | |||
'C': 1.7, | |||
'N': 1.55, | |||
'O': 1.52, | |||
'S': 1.8, | |||
} | |||
Bond = collections.namedtuple( | |||
'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) | |||
BondAngle = collections.namedtuple( | |||
'BondAngle', | |||
['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) | |||
@functools.lru_cache(maxsize=None) | |||
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], | |||
Mapping[str, List[Bond]], | |||
Mapping[str, List[BondAngle]]]: | |||
"""Load stereo_chemical_props.txt into a nice structure. | |||
Load literature values for bond lengths and bond angles and translate | |||
bond angles into the length of the opposite edge of the triangle | |||
("residue_virtual_bonds"). | |||
Returns: | |||
residue_bonds: dict that maps resname --> list of Bond tuples | |||
residue_virtual_bonds: dict that maps resname --> list of Bond tuples | |||
residue_bond_angles: dict that maps resname --> list of BondAngle tuples | |||
""" | |||
stereo_chemical_props_path = ( | |||
'alphafold/common/stereo_chemical_props.txt') | |||
with open(stereo_chemical_props_path, 'rt') as f: | |||
stereo_chemical_props = f.read() | |||
lines_iter = iter(stereo_chemical_props.splitlines()) | |||
# Load bond lengths. | |||
residue_bonds = {} | |||
next(lines_iter) # Skip header line. | |||
for line in lines_iter: | |||
if line.strip() == '-': | |||
break | |||
bond, resname, length, stddev = line.split() | |||
atom1, atom2 = bond.split('-') | |||
if resname not in residue_bonds: | |||
residue_bonds[resname] = [] | |||
residue_bonds[resname].append( | |||
Bond(atom1, atom2, float(length), float(stddev))) | |||
residue_bonds['UNK'] = [] | |||
# Load bond angles. | |||
residue_bond_angles = {} | |||
next(lines_iter) # Skip empty line. | |||
next(lines_iter) # Skip header line. | |||
for line in lines_iter: | |||
if line.strip() == '-': | |||
break | |||
bond, resname, angle_degree, stddev_degree = line.split() | |||
atom1, atom2, atom3 = bond.split('-') | |||
if resname not in residue_bond_angles: | |||
residue_bond_angles[resname] = [] | |||
residue_bond_angles[resname].append( | |||
BondAngle(atom1, atom2, atom3, | |||
float(angle_degree) / 180. * np.pi, | |||
float(stddev_degree) / 180. * np.pi)) | |||
residue_bond_angles['UNK'] = [] | |||
def make_bond_key(atom1_name, atom2_name): | |||
"""Unique key to lookup bonds.""" | |||
return '-'.join(sorted([atom1_name, atom2_name])) | |||
# Translate bond angles into distances ("virtual bonds"). | |||
residue_virtual_bonds = {} | |||
for resname, bond_angles in residue_bond_angles.items(): | |||
# Create a fast lookup dict for bond lengths. | |||
bond_cache = {} | |||
for b in residue_bonds[resname]: | |||
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b | |||
residue_virtual_bonds[resname] = [] | |||
for ba in bond_angles: | |||
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] | |||
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] | |||
# Compute distance between atom1 and atom3 using the law of cosines | |||
# c^2 = a^2 + b^2 - 2ab*cos(gamma). | |||
gamma = ba.angle_rad | |||
length = np.sqrt(bond1.length**2 + bond2.length**2 | |||
- 2 * bond1.length * bond2.length * np.cos(gamma)) | |||
# Propagation of uncertainty assuming uncorrelated errors. | |||
dl_outer = 0.5 / length | |||
dl_dgamma = (2 * bond1.length * bond2.length * | |||
np.sin(gamma)) * dl_outer | |||
dl_db1 = (2 * bond1.length - 2 * bond2.length * | |||
np.cos(gamma)) * dl_outer | |||
dl_db2 = (2 * bond2.length - 2 * bond1.length * | |||
np.cos(gamma)) * dl_outer | |||
stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + | |||
(dl_db1 * bond1.stddev)**2 + | |||
(dl_db2 * bond2.stddev)**2) | |||
residue_virtual_bonds[resname].append( | |||
Bond(ba.atom1_name, ba.atom3name, length, stddev)) | |||
return (residue_bonds, | |||
residue_virtual_bonds, | |||
residue_bond_angles) | |||
# This mapping is used when we need to store atom data in a format that requires | |||
# fixed atom data size for every residue (e.g. a numpy array). | |||
atom_types = [ | |||
'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', | |||
'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', | |||
'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', | |||
'CZ3', 'NZ', 'OXT' | |||
] | |||
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} | |||
atom_type_num = len(atom_types) # := 37. | |||
# A compact atom encoding with 14 columns | |||
# pylint: disable=line-too-long | |||
# pylint: disable=bad-whitespace | |||
restype_name_to_atom14_names = { | |||
'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], | |||
'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], | |||
'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], | |||
'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], | |||
'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], | |||
'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], | |||
'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], | |||
'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], | |||
'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], | |||
'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], | |||
'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], | |||
'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], | |||
'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], | |||
'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], | |||
'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], | |||
'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], | |||
'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], | |||
'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], | |||
'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], | |||
'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], | |||
'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], | |||
} | |||
# This is the standard residue order when coding AA type as a number. | |||
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. | |||
restypes = [ | |||
'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', | |||
'S', 'T', 'W', 'Y', 'V' | |||
] | |||
restype_order = {restype: i for i, restype in enumerate(restypes)} | |||
restype_num = len(restypes) # := 20. | |||
restypes_with_x = restypes + ['X'] | |||
restype_order_with_x = { | |||
restype: i for i, | |||
restype in enumerate(restypes_with_x)} | |||
def sequence_to_onehot( | |||
sequence: str, | |||
mapping: Mapping[str, int], | |||
map_unknown_to_x: bool = False) -> np.ndarray: | |||
"""Maps the given sequence into a one-hot encoded matrix. | |||
Args: | |||
sequence: An amino acid sequence. | |||
mapping: A dictionary mapping amino acids to integers. | |||
map_unknown_to_x: If True, any amino acid that is not in the mapping will be | |||
mapped to the unknown amino acid 'X'. If the mapping doesn't contain | |||
amino acid 'X', an error will be thrown. If False, any amino acid not in | |||
the mapping will throw an error. | |||
Returns: | |||
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of | |||
the sequence. | |||
Raises: | |||
ValueError: If the mapping doesn't contain values from 0 to | |||
num_unique_aas - 1 without any gaps. | |||
""" | |||
num_entries = max(mapping.values()) + 1 | |||
if sorted(set(mapping.values())) != list(range(num_entries)): | |||
raise ValueError( | |||
'The mapping must have values from 0 to num_unique_aas-1 ' | |||
'without any gaps. Got: %s' % | |||
sorted( | |||
mapping.values())) | |||
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) | |||
for aa_index, aa_type in enumerate(sequence): | |||
if map_unknown_to_x: | |||
if aa_type.isalpha() and aa_type.isupper(): | |||
aa_id = mapping.get(aa_type, mapping['X']) | |||
else: | |||
raise ValueError( | |||
f'Invalid character in the sequence: {aa_type}') | |||
else: | |||
aa_id = mapping[aa_type] | |||
one_hot_arr[aa_index, aa_id] = 1 | |||
return one_hot_arr | |||
restype_1to3 = { | |||
'A': 'ALA', | |||
'R': 'ARG', | |||
'N': 'ASN', | |||
'D': 'ASP', | |||
'C': 'CYS', | |||
'Q': 'GLN', | |||
'E': 'GLU', | |||
'G': 'GLY', | |||
'H': 'HIS', | |||
'I': 'ILE', | |||
'L': 'LEU', | |||
'K': 'LYS', | |||
'M': 'MET', | |||
'F': 'PHE', | |||
'P': 'PRO', | |||
'S': 'SER', | |||
'T': 'THR', | |||
'W': 'TRP', | |||
'Y': 'TYR', | |||
'V': 'VAL', | |||
} | |||
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple | |||
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains | |||
# many more, and less common, three letter names as keys and maps many of these | |||
# to the same one letter name (including 'X' and 'U' which we don't use here). | |||
restype_3to1 = {v: k for k, v in restype_1to3.items()} | |||
# Define a restype name for all unknown residues. | |||
unk_restype = 'UNK' | |||
resnames = [restype_1to3[r] for r in restypes] + [unk_restype] | |||
resname_to_idx = {resname: i for i, resname in enumerate(resnames)} | |||
# The mapping here uses hhblits convention, so that B is mapped to D, J and O | |||
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the | |||
# remaining 20 amino acids are kept in alphabetical order. | |||
# There are 2 non-amino acid codes, X (representing any amino acid) and | |||
# "-" representing a missing amino acid in an alignment. The id for these | |||
# codes is put at the end (20 and 21) so that they can easily be ignored if | |||
# desired. | |||
HHBLITS_AA_TO_ID = { | |||
'A': 0, | |||
'B': 2, | |||
'C': 1, | |||
'D': 2, | |||
'E': 3, | |||
'F': 4, | |||
'G': 5, | |||
'H': 6, | |||
'I': 7, | |||
'J': 20, | |||
'K': 8, | |||
'L': 9, | |||
'M': 10, | |||
'N': 11, | |||
'O': 20, | |||
'P': 12, | |||
'Q': 13, | |||
'R': 14, | |||
'S': 15, | |||
'T': 16, | |||
'U': 1, | |||
'V': 17, | |||
'W': 18, | |||
'X': 20, | |||
'Y': 19, | |||
'Z': 3, | |||
'-': 21, | |||
} | |||
# Partial inversion of HHBLITS_AA_TO_ID. | |||
ID_TO_HHBLITS_AA = { | |||
0: 'A', | |||
1: 'C', # Also U. | |||
2: 'D', # Also B. | |||
3: 'E', # Also Z. | |||
4: 'F', | |||
5: 'G', | |||
6: 'H', | |||
7: 'I', | |||
8: 'K', | |||
9: 'L', | |||
10: 'M', | |||
11: 'N', | |||
12: 'P', | |||
13: 'Q', | |||
14: 'R', | |||
15: 'S', | |||
16: 'T', | |||
17: 'V', | |||
18: 'W', | |||
19: 'Y', | |||
20: 'X', # Includes J and O. | |||
21: '-', | |||
} | |||
restypes_with_x_and_gap = restypes + ['X', '-'] | |||
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) | |||
for i in range(len(restypes_with_x_and_gap))) | |||
def _make_standard_atom_mask() -> np.ndarray: | |||
"""Returns [num_res_types, num_atom_types] mask array.""" | |||
# +1 to account for unknown (all 0s). | |||
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) | |||
for restype, restype_letter in enumerate(restypes): | |||
restype_name = restype_1to3[restype_letter] | |||
atom_names = residue_atoms[restype_name] | |||
for atom_name in atom_names: | |||
atom_type = atom_order[atom_name] | |||
mask[restype, atom_type] = 1 | |||
return mask | |||
STANDARD_ATOM_MASK = _make_standard_atom_mask() | |||
# A one hot representation for the first and second atoms defining the axis | |||
# of rotation for each chi-angle in each residue. | |||
def chi_angle_atom(atom_index: int) -> np.ndarray: | |||
"""Define chi-angle rigid groups via one-hot representations.""" | |||
chi_angles_index = {} | |||
one_hots = [] | |||
for k, v in chi_angles_atoms.items(): | |||
indices = [atom_types.index(s[atom_index]) for s in v] | |||
indices.extend([-1] * (4 - len(indices))) | |||
chi_angles_index[k] = indices | |||
for r in restypes: | |||
res3 = restype_1to3[r] | |||
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] | |||
one_hots.append(one_hot) | |||
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. | |||
one_hot = np.stack(one_hots, axis=0) | |||
one_hot = np.transpose(one_hot, [0, 2, 1]) | |||
return one_hot | |||
chi_atom_1_one_hot = chi_angle_atom(1) | |||
chi_atom_2_one_hot = chi_angle_atom(2) | |||
# An array like chi_angles_atoms but using indices rather than names. | |||
chi_angles_atom_indices = [[], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 23], [5, 11, 23, 32]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 5], [1, 3, 5, 16]], [[0, 1, 3, 10]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 26]], [], [[0, 1, 3, 5], [1, 3, 5, 14]], [[0, 1, 3, 6], [1, 3, 6, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11], [3, 5, 11, 19], [5, 11, 19, 35]], [[0, 1, 3, 5], [1, 3, 5, 18], [3, 5, 18, 19]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 11]], [[0, 1, 3, 8]], [[0, 1, 3, 9]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 5], [1, 3, 5, 12]], [[0, 1, 3, 6]]] | |||
chi_angles_atom_indices = np.array([ | |||
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) | |||
for chi_atoms in chi_angles_atom_indices]) | |||
# Mapping from (res_name, atom_name) pairs to the atom's chi group index | |||
# and atom index within that group. | |||
chi_groups_for_atom = collections.defaultdict(list) | |||
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): | |||
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): | |||
for atom_i, atom in enumerate(chi_group): | |||
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) | |||
chi_groups_for_atom = dict(chi_groups_for_atom) | |||
def _make_rigid_transformation_4x4(ex, ey, translation): | |||
"""Create a rigid 4x4 transformation matrix from two axes and transl.""" | |||
# Normalize ex. | |||
ex_normalized = ex / np.linalg.norm(ex) | |||
# make ey perpendicular to ex | |||
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized | |||
ey_normalized /= np.linalg.norm(ey_normalized) | |||
# compute ez as cross product | |||
eznorm = np.cross(ex_normalized, ey_normalized) | |||
m = np.stack([ex_normalized, ey_normalized, | |||
eznorm, translation]).transpose() | |||
m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0) | |||
return m | |||
# create an array with (restype, atomtype) --> rigid_group_idx | |||
# and an array with (restype, atomtype, coord) for the atom positions | |||
# and compute affine transformation matrices (4,4) from one rigid group to the | |||
# previous group | |||
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) | |||
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) | |||
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) | |||
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) | |||
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) | |||
def _make_rigid_group_constants(): | |||
"""Fill the arrays above.""" | |||
for restype, restype_letter in enumerate(restypes): | |||
resname = restype_1to3[restype_letter] | |||
for atomname, group_idx, atom_position in rigid_group_atom_positions[ | |||
resname]: | |||
atomtype = atom_order[atomname] | |||
restype_atom37_to_rigid_group[restype, atomtype] = group_idx | |||
restype_atom37_mask[restype, atomtype] = 1 | |||
restype_atom37_rigid_group_positions[restype, | |||
atomtype, :] = atom_position | |||
atom14idx = restype_name_to_atom14_names[resname].index(atomname) | |||
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx | |||
restype_atom14_mask[restype, atom14idx] = 1 | |||
restype_atom14_rigid_group_positions[restype, | |||
atom14idx, :] = atom_position | |||
for restype, restype_letter in enumerate(restypes): | |||
resname = restype_1to3[restype_letter] | |||
atom_positions = {name: np.array(pos) for name, _, pos | |||
in rigid_group_atom_positions[resname]} | |||
# backbone to backbone is the identity transform | |||
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) | |||
# pre-omega-frame to backbone (currently dummy identity matrix) | |||
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) | |||
# phi-frame to backbone | |||
mat = _make_rigid_transformation_4x4( | |||
ex=atom_positions['N'] - atom_positions['CA'], | |||
ey=np.array([1., 0., 0.]), | |||
translation=atom_positions['N']) | |||
restype_rigid_group_default_frame[restype, 2, :, :] = mat | |||
# psi-frame to backbone | |||
mat = _make_rigid_transformation_4x4( | |||
ex=atom_positions['C'] - atom_positions['CA'], | |||
ey=atom_positions['CA'] - atom_positions['N'], | |||
translation=atom_positions['C']) | |||
restype_rigid_group_default_frame[restype, 3, :, :] = mat | |||
# chi1-frame to backbone | |||
if chi_angles_mask[restype][0]: | |||
base_atom_names = chi_angles_atoms[resname][0] | |||
base_atom_positions = [atom_positions[name] | |||
for name in base_atom_names] | |||
mat = _make_rigid_transformation_4x4( | |||
ex=base_atom_positions[2] - base_atom_positions[1], | |||
ey=base_atom_positions[0] - base_atom_positions[1], | |||
translation=base_atom_positions[2]) | |||
restype_rigid_group_default_frame[restype, 4, :, :] = mat | |||
# chi2-frame to chi1-frame | |||
# chi3-frame to chi2-frame | |||
# chi4-frame to chi3-frame | |||
# luckily all rotation axes for the next frame start at (0,0,0) of the | |||
# previous frame | |||
for chi_idx in range(1, 4): | |||
if chi_angles_mask[restype][chi_idx]: | |||
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] | |||
axis_end_atom_position = atom_positions[axis_end_atom_name] | |||
mat = _make_rigid_transformation_4x4( | |||
ex=axis_end_atom_position, | |||
ey=np.array([-1., 0., 0.]), | |||
translation=axis_end_atom_position) | |||
restype_rigid_group_default_frame[restype, | |||
4 + chi_idx, :, :] = mat | |||
_make_rigid_group_constants() |
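# Sanity check (sketch): after `_make_rigid_group_constants()` has run, the
# backbone-to-backbone frame (rigid group 0) of every residue type is the
# identity, e.g.
#   assert np.allclose(restype_rigid_group_default_frame[0, 0], np.eye(4))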
@@ -0,0 +1,382 @@
"""Model config.""" | |||
import copy | |||
import ml_collections | |||
NUM_RES = 'num residues placeholder' | |||
NUM_MSA_SEQ = 'msa placeholder' | |||
NUM_EXTRA_SEQ = 'extra msa placeholder' | |||
NUM_TEMPLATES = 'num templates placeholder' | |||
def model_config(name: str) -> ml_collections.ConfigDict: | |||
"""Get the ConfigDict of a CASP14 model.""" | |||
if name not in CONFIG_DIFFS: | |||
raise ValueError(f'Invalid model name {name}.') | |||
cfg = copy.deepcopy(CONFIG) | |||
cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) | |||
return cfg | |||
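# Example usage (sketch): `model_config` deep-copies the base CONFIG and
# overlays the per-model diffs, so the pTM variant of model 1 has templates
# enabled and a non-zero predicted_aligned_error weight:
#   cfg = model_config('model_1_ptm')
#   assert cfg.model.embeddings_and_evoformer.template.enabled
#   assert cfg.model.heads.predicted_aligned_error.weight == 0.1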
CONFIG_DIFFS = { | |||
'model_1': { | |||
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 | |||
'data.common.max_extra_msa': 5120, | |||
'data.common.reduce_msa_clusters_by_max_templates': True, | |||
'data.common.use_templates': True, | |||
'model.embeddings_and_evoformer.template.embed_torsion_angles': True, | |||
'model.embeddings_and_evoformer.template.enabled': True | |||
}, | |||
'model_2': { | |||
# Jumper et al. (2021) Suppl. Table 5, Model 1.1.2 | |||
'data.common.reduce_msa_clusters_by_max_templates': True, | |||
'data.common.use_templates': True, | |||
'model.embeddings_and_evoformer.template.embed_torsion_angles': True, | |||
'model.embeddings_and_evoformer.template.enabled': True | |||
}, | |||
'model_3': { | |||
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.1 | |||
'data.common.max_extra_msa': 5120, | |||
}, | |||
'model_4': { | |||
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.2 | |||
'data.common.max_extra_msa': 5120, | |||
}, | |||
'model_5': { | |||
# Jumper et al. (2021) Suppl. Table 5, Model 1.2.3 | |||
}, | |||
# The following models are fine-tuned from the corresponding models above | |||
# with an additional predicted_aligned_error head that can produce | |||
# predicted TM-score (pTM) and predicted aligned errors. | |||
'model_1_ptm': { | |||
'data.common.max_extra_msa': 5120, | |||
'data.common.reduce_msa_clusters_by_max_templates': True, | |||
'data.common.use_templates': True, | |||
'model.embeddings_and_evoformer.template.embed_torsion_angles': True, | |||
'model.embeddings_and_evoformer.template.enabled': True, | |||
'model.heads.predicted_aligned_error.weight': 0.1 | |||
}, | |||
'model_2_ptm': { | |||
'data.common.reduce_msa_clusters_by_max_templates': True, | |||
'data.common.use_templates': True, | |||
'model.embeddings_and_evoformer.template.embed_torsion_angles': True, | |||
'model.embeddings_and_evoformer.template.enabled': True, | |||
'model.heads.predicted_aligned_error.weight': 0.1 | |||
}, | |||
'model_3_ptm': { | |||
'data.common.max_extra_msa': 5120, | |||
'model.heads.predicted_aligned_error.weight': 0.1 | |||
}, | |||
'model_4_ptm': { | |||
'data.common.max_extra_msa': 5120, | |||
'model.heads.predicted_aligned_error.weight': 0.1 | |||
}, | |||
'model_5_ptm': { | |||
'model.heads.predicted_aligned_error.weight': 0.1 | |||
} | |||
} | |||
CONFIG = ml_collections.ConfigDict({ | |||
'data': { | |||
'common': { | |||
'masked_msa': { | |||
'profile_prob': 0.1, | |||
'same_prob': 0.1, | |||
'uniform_prob': 0.1 | |||
}, | |||
'max_extra_msa': 1024, | |||
'msa_cluster_features': True, | |||
'num_recycle': 3, | |||
'reduce_msa_clusters_by_max_templates': False, | |||
'resample_msa_in_recycling': True, | |||
'template_features': [ | |||
'template_all_atom_positions', 'template_sum_probs', | |||
'template_aatype', 'template_all_atom_masks', | |||
'template_domain_names' | |||
], | |||
'unsupervised_features': [ | |||
'aatype', 'residue_index', 'sequence', 'msa', 'domain_name', | |||
'num_alignments', 'seq_length', 'between_segment_residues', | |||
'deletion_matrix' | |||
], | |||
'use_templates': False, | |||
}, | |||
'eval': { | |||
'feat': { | |||
'aatype': [NUM_RES], | |||
'all_atom_mask': [NUM_RES, None], | |||
'all_atom_positions': [NUM_RES, None, None], | |||
'alt_chi_angles': [NUM_RES, None], | |||
'atom14_alt_gt_exists': [NUM_RES, None], | |||
'atom14_alt_gt_positions': [NUM_RES, None, None], | |||
'atom14_atom_exists': [NUM_RES, None], | |||
'atom14_atom_is_ambiguous': [NUM_RES, None], | |||
'atom14_gt_exists': [NUM_RES, None], | |||
'atom14_gt_positions': [NUM_RES, None, None], | |||
'atom37_atom_exists': [NUM_RES, None], | |||
'backbone_affine_mask': [NUM_RES], | |||
'backbone_affine_tensor': [NUM_RES, None], | |||
'bert_mask': [NUM_MSA_SEQ, NUM_RES], | |||
'chi_angles': [NUM_RES, None], | |||
'chi_mask': [NUM_RES, None], | |||
'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], | |||
'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], | |||
'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], | |||
'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], | |||
'extra_msa_row_mask': [NUM_EXTRA_SEQ], | |||
'is_distillation': [], | |||
'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], | |||
'msa_mask': [NUM_MSA_SEQ, NUM_RES], | |||
'msa_row_mask': [NUM_MSA_SEQ], | |||
'pseudo_beta': [NUM_RES, None], | |||
'pseudo_beta_mask': [NUM_RES], | |||
'random_crop_to_size_seed': [None], | |||
'residue_index': [NUM_RES], | |||
'residx_atom14_to_atom37': [NUM_RES, None], | |||
'residx_atom37_to_atom14': [NUM_RES, None], | |||
'resolution': [], | |||
'rigidgroups_alt_gt_frames': [NUM_RES, None, None], | |||
'rigidgroups_group_exists': [NUM_RES, None], | |||
'rigidgroups_group_is_ambiguous': [NUM_RES, None], | |||
'rigidgroups_gt_exists': [NUM_RES, None], | |||
'rigidgroups_gt_frames': [NUM_RES, None, None], | |||
'seq_length': [], | |||
'seq_mask': [NUM_RES], | |||
'target_feat': [NUM_RES, None], | |||
'template_aatype': [NUM_TEMPLATES, NUM_RES], | |||
'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], | |||
'template_all_atom_positions': [ | |||
NUM_TEMPLATES, NUM_RES, None, None], | |||
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], | |||
'template_backbone_affine_tensor': [ | |||
NUM_TEMPLATES, NUM_RES, None], | |||
'template_mask': [NUM_TEMPLATES], | |||
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], | |||
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], | |||
'template_sum_probs': [NUM_TEMPLATES, None], | |||
'true_msa': [NUM_MSA_SEQ, NUM_RES] | |||
}, | |||
'fixed_size': True, | |||
'subsample_templates': False, # We want top templates. | |||
'masked_msa_replace_fraction': 0.15, | |||
'max_msa_clusters': 512, | |||
'max_templates': 4, | |||
'num_ensemble': 1, | |||
}, | |||
}, | |||
'model': { | |||
'embeddings_and_evoformer': { | |||
'evoformer_num_block': 48, | |||
'evoformer': { | |||
'msa_row_attention_with_pair_bias': { | |||
'dropout_rate': 0.15, | |||
'gating': True, | |||
'num_head': 8, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'msa_column_attention': { | |||
'dropout_rate': 0.0, | |||
'gating': True, | |||
'num_head': 8, | |||
'orientation': 'per_column', | |||
'shared_dropout': True | |||
}, | |||
'msa_transition': { | |||
'dropout_rate': 0.0, | |||
'num_intermediate_factor': 4, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'outer_product_mean': { | |||
'chunk_size': 128, | |||
'dropout_rate': 0.0, | |||
'num_outer_channel': 32, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'triangle_attention_starting_node': { | |||
'dropout_rate': 0.25, | |||
'gating': True, | |||
'num_head': 4, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'triangle_attention_ending_node': { | |||
'dropout_rate': 0.25, | |||
'gating': True, | |||
'num_head': 4, | |||
'orientation': 'per_column', | |||
'shared_dropout': True | |||
}, | |||
'triangle_multiplication_outgoing': { | |||
'dropout_rate': 0.25, | |||
'equation': 'ikc,jkc->ijc', | |||
'num_intermediate_channel': 128, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'triangle_multiplication_incoming': { | |||
'dropout_rate': 0.25, | |||
'equation': 'kjc,kic->ijc', | |||
'num_intermediate_channel': 128, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'pair_transition': { | |||
'dropout_rate': 0.0, | |||
'num_intermediate_factor': 4, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
} | |||
}, | |||
'extra_msa_channel': 64, | |||
'extra_msa_stack_num_block': 4, | |||
'max_relative_feature': 32, | |||
'msa_channel': 256, | |||
'pair_channel': 128, | |||
'prev_pos': { | |||
'min_bin': 3.25, | |||
'max_bin': 20.75, | |||
'num_bins': 15 | |||
}, | |||
'recycle_features': True, | |||
'recycle_pos': True, | |||
'seq_channel': 384, | |||
'template': { | |||
'attention': { | |||
'gating': False, | |||
'key_dim': 64, | |||
'num_head': 4, | |||
'value_dim': 64 | |||
}, | |||
'dgram_features': { | |||
'min_bin': 3.25, | |||
'max_bin': 50.75, | |||
'num_bins': 39 | |||
}, | |||
'embed_torsion_angles': False, | |||
'enabled': False, | |||
'template_pair_stack': { | |||
'num_block': 2, | |||
'triangle_attention_starting_node': { | |||
'dropout_rate': 0.25, | |||
'gating': True, | |||
'key_dim': 64, | |||
'num_head': 4, | |||
'orientation': 'per_row', | |||
'shared_dropout': True, | |||
'value_dim': 64 | |||
}, | |||
'triangle_attention_ending_node': { | |||
'dropout_rate': 0.25, | |||
'gating': True, | |||
'key_dim': 64, | |||
'num_head': 4, | |||
'orientation': 'per_column', | |||
'shared_dropout': True, | |||
'value_dim': 64 | |||
}, | |||
'triangle_multiplication_outgoing': { | |||
'dropout_rate': 0.25, | |||
'equation': 'ikc,jkc->ijc', | |||
'num_intermediate_channel': 64, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'triangle_multiplication_incoming': { | |||
'dropout_rate': 0.25, | |||
'equation': 'kjc,kic->ijc', | |||
'num_intermediate_channel': 64, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
}, | |||
'pair_transition': { | |||
'dropout_rate': 0.0, | |||
'num_intermediate_factor': 2, | |||
'orientation': 'per_row', | |||
'shared_dropout': True | |||
} | |||
}, | |||
'max_templates': 4, | |||
'subbatch_size': 128, | |||
'use_template_unit_vector': False, | |||
} | |||
}, | |||
'heads': { | |||
'distogram': { | |||
'first_break': 2.3125, | |||
'last_break': 21.6875, | |||
'num_bins': 64, | |||
'weight': 0.3 | |||
}, | |||
'predicted_aligned_error': { | |||
# `num_bins - 1` bins uniformly space the | |||
# [0, max_error_bin A] range. | |||
# The final bin covers [max_error_bin A, +infty] | |||
# 31A gives bins with 0.5A width. | |||
'max_error_bin': 31., | |||
'num_bins': 64, | |||
'num_channels': 128, | |||
'filter_by_resolution': True, | |||
'min_resolution': 0.1, | |||
'max_resolution': 3.0, | |||
'weight': 0.0, | |||
}, | |||
'experimentally_resolved': { | |||
'filter_by_resolution': True, | |||
'max_resolution': 3.0, | |||
'min_resolution': 0.1, | |||
'weight': 0.01 | |||
}, | |||
'structure_module': { | |||
'num_layer': 8, | |||
'fape': { | |||
'clamp_distance': 10.0, | |||
'clamp_type': 'relu', | |||
'loss_unit_distance': 10.0 | |||
}, | |||
'angle_norm_weight': 0.01, | |||
'chi_weight': 0.5, | |||
'clash_overlap_tolerance': 1.5, | |||
'compute_in_graph_metrics': True, | |||
'dropout': 0.1, | |||
'num_channel': 384, | |||
'num_head': 12, | |||
'num_layer_in_transition': 3, | |||
'num_point_qk': 4, | |||
'num_point_v': 8, | |||
'num_scalar_qk': 16, | |||
'num_scalar_v': 16, | |||
'position_scale': 10.0, | |||
'sidechain': { | |||
'atom_clamp_distance': 10.0, | |||
'num_channel': 128, | |||
'num_residual_block': 2, | |||
'weight_frac': 0.5, | |||
'length_scale': 10., | |||
}, | |||
'structural_violation_loss_weight': 1.0, | |||
'violation_tolerance_factor': 12.0, | |||
'weight': 1.0 | |||
}, | |||
'predicted_lddt': { | |||
'filter_by_resolution': True, | |||
'max_resolution': 3.0, | |||
'min_resolution': 0.1, | |||
'num_bins': 50, | |||
'num_channels': 128, | |||
'weight': 0.01 | |||
}, | |||
'masked_msa': { | |||
'num_output': 23, | |||
'weight': 2.0 | |||
}, | |||
}, | |||
'num_recycle': 3, | |||
'resample_msa_in_recycling': True | |||
}, | |||
}) |
@@ -0,0 +1,341 @@
"""Model config.""" | |||
import copy | |||
import ml_collections | |||
def global_config(length: int) -> ml_collections.ConfigDict: | |||
"""Get the global config.""" | |||
if str(length) not in GLOBAL_CONFIG: | |||
raise ValueError(f'Invalid padding sequence length {length}.') | |||
cfg = copy.deepcopy(GLOBAL_CONFIG[str(length)]) | |||
return cfg | |||
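# Example usage (sketch): lengths are looked up by their string key, so only
# the padded sizes defined below are valid:
#   cfg = global_config(512)
#   assert cfg.seq_length == 512 and cfg.extra_msa_length == 5120
# An unsupported length such as 300 raises a ValueError.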
GLOBAL_CONFIG = ml_collections.ConfigDict({ | |||
"256": { | |||
'zero_init': True, | |||
'seq_length': 256, | |||
'extra_msa_length': 5120, | |||
'template_embedding': { | |||
'slice_num': 0, | |||
}, | |||
'template_pair_stack': { | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'extra_msa_stack': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 0, | |||
}, | |||
'msa_column_global_attention': { | |||
'slice_num': 0, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'evoformer_iteration': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 0, | |||
}, | |||
'msa_column_attention': { | |||
'slice_num': 0, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
}, | |||
"512": { | |||
'zero_init': True, | |||
'seq_length': 512, | |||
'extra_msa_length': 5120, | |||
'template_embedding': { | |||
'slice_num': 0, | |||
}, | |||
'template_pair_stack': { | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'extra_msa_stack': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 4, | |||
}, | |||
'msa_column_global_attention': { | |||
'slice_num': 0, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'evoformer_iteration': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 0, | |||
}, | |||
'msa_column_attention': { | |||
'slice_num': 0, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 0, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
}, | |||
"1024": { | |||
'zero_init': True, | |||
'seq_length': 1024, | |||
'extra_msa_length': 5120, | |||
'template_embedding': { | |||
'slice_num': 4, | |||
}, | |||
'template_pair_stack': { | |||
'triangle_attention_starting_node': { | |||
'slice_num': 4, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 4, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'extra_msa_stack': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 16, | |||
}, | |||
'msa_column_global_attention': { | |||
'slice_num': 4, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 4, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 4, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
'evoformer_iteration': { | |||
'msa_transition': { | |||
'slice_num': 0, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 4, | |||
}, | |||
'msa_column_attention': { | |||
'slice_num': 4, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 0, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 4, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 4, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 0, | |||
}, | |||
}, | |||
}, | |||
"2048": { | |||
'zero_init': True, | |||
'seq_length': 2048, | |||
'extra_msa_length': 5120, | |||
'template_embedding': { | |||
'slice_num': 32, | |||
}, | |||
'template_pair_stack': { | |||
'triangle_attention_starting_node': { | |||
'slice_num': 32, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 32, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 16, | |||
}, | |||
}, | |||
'extra_msa_stack': { | |||
'msa_transition': { | |||
'slice_num': 16, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 128, | |||
}, | |||
'msa_column_global_attention': { | |||
'slice_num': 32, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 16, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 32, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 32, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 16, | |||
}, | |||
}, | |||
'evoformer_iteration': { | |||
'msa_transition': { | |||
'slice_num': 16, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 32, | |||
}, | |||
'msa_column_attention': { | |||
'slice_num': 32, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 16, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 32, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 32, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 16, | |||
}, | |||
}, | |||
}, | |||
"2304": { | |||
'zero_init': True, | |||
'seq_length': 2304, | |||
'extra_msa_length': 5120, | |||
'template_embedding': { | |||
'slice_num': 64, | |||
}, | |||
'template_pair_stack': { | |||
'triangle_attention_starting_node': { | |||
'slice_num': 64, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 64, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 2, | |||
}, | |||
}, | |||
'extra_msa_stack': { | |||
'msa_transition': { | |||
'slice_num': 2, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 64, | |||
}, | |||
'msa_column_global_attention': { | |||
'slice_num': 64, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 8, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 64, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 64, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 4, | |||
}, | |||
}, | |||
'evoformer_iteration': { | |||
'msa_transition': { | |||
'slice_num': 2, | |||
}, | |||
'msa_row_attention_with_pair_bias': { | |||
'slice_num': 64, | |||
}, | |||
'msa_column_attention': { | |||
'slice_num': 64, | |||
}, | |||
'outer_product_mean': { | |||
'slice_num': 8, | |||
}, | |||
'triangle_attention_starting_node': { | |||
'slice_num': 64, | |||
}, | |||
'triangle_attention_ending_node': { | |||
'slice_num': 64, | |||
}, | |||
'pair_transition': { | |||
'slice_num': 4, | |||
}, | |||
}, | |||
}, | |||
}) |
@@ -0,0 +1,517 @@
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""data transforms""" | |||
import numpy as np | |||
from commons import residue_constants | |||
NUM_RES = 'num residues placeholder' | |||
NUM_MSA_SEQ = 'msa placeholder' | |||
NUM_EXTRA_SEQ = 'extra msa placeholder' | |||
NUM_TEMPLATES = 'num templates placeholder' | |||
MS_MIN32 = -2147483648 | |||
MS_MAX32 = 2147483647 | |||
_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'] | |||
class SeedMaker: | |||
"""Return unique seeds.""" | |||
def __init__(self, initial_seed=0): | |||
self.next_seed = initial_seed | |||
def __call__(self): | |||
i = self.next_seed | |||
self.next_seed += 1 | |||
return i | |||
seed_maker = SeedMaker() | |||
def one_hot(depth, indices): | |||
res = np.eye(depth)[indices.reshape(-1)] | |||
return res.reshape(list(indices.shape) + [depth]) | |||
def make_random_seed(size, seed_maker_t, low=MS_MIN32, high=MS_MAX32): | |||
np.random.seed(seed_maker_t) | |||
return np.random.uniform(size=size, low=low, high=high) | |||
def curry1(f): | |||
"""Supply all arguments except the first.""" | |||
def fc(*args, **kwargs): | |||
return lambda x: f(x, *args, **kwargs) | |||
return fc | |||
@curry1 | |||
def compose(x, fs): | |||
for f in fs: | |||
x = f(x) | |||
return x | |||
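# Example (illustrative): `curry1` turns a transform into a factory that is
# configured first and applied to the protein dict later, while `compose`
# chains such transforms into a single callable:
#   pipeline = compose([make_seq_mask, make_msa_mask, make_hhblits_profile])
#   protein = pipeline(protein)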
@curry1 | |||
def randomly_replace_msa_with_unknown(protein, replace_proportion): | |||
"""Replace a proportion of the MSA with 'X'.""" | |||
msa_mask = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_proportion | |||
x_idx = 20 | |||
gap_idx = 21 | |||
msa_mask = np.logical_and(msa_mask, protein['msa'] != gap_idx) | |||
protein['msa'] = np.where(msa_mask, np.ones_like(protein['msa']) * x_idx, protein['msa']) | |||
aatype_mask = np.random.uniform(size=shape_list(protein['aatype']), low=0, high=1) < replace_proportion | |||
protein['aatype'] = np.where(aatype_mask, np.ones_like(protein['aatype']) * x_idx, protein['aatype']) | |||
return protein | |||
@curry1 | |||
def sample_msa(protein, max_seq, keep_extra): | |||
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.""" | |||
num_seq = protein['msa'].shape[0] | |||
shuffled = list(range(1, num_seq)) | |||
np.random.shuffle(shuffled) | |||
shuffled.insert(0, 0) | |||
index_order = np.array(shuffled, np.int32) | |||
num_sel = min(max_seq, num_seq) | |||
sel_seq = index_order[:num_sel] | |||
not_sel_seq = index_order[num_sel:] | |||
is_sel = num_seq - num_sel | |||
for k in _MSA_FEATURE_NAMES: | |||
if k in protein: | |||
if keep_extra and not is_sel: | |||
new_shape = list(protein[k].shape) | |||
new_shape[0] = 1 | |||
protein['extra_' + k] = np.zeros(new_shape) | |||
elif keep_extra and is_sel: | |||
protein['extra_' + k] = protein[k][not_sel_seq] | |||
if k == 'msa': | |||
protein['extra_msa'] = protein['extra_msa'].astype(np.int32) | |||
protein[k] = protein[k][sel_seq] | |||
return protein | |||
def shaped_categorical(probs): | |||
ds = shape_list(probs) | |||
num_classes = ds[-1] | |||
probs = np.reshape(probs, (-1, num_classes)) | |||
nums = list(range(num_classes)) | |||
counts = [] | |||
for prob in probs: | |||
counts.append(np.random.choice(nums, p=prob)) | |||
return np.reshape(np.array(counts, np.int32), ds[:-1]) | |||
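# Illustrative behaviour: for `probs` of shape [num_seq, num_res, num_classes],
# one class index is sampled per (sequence, residue) position, giving an int32
# array of shape [num_seq, num_res]. Each row of the flattened `probs` must sum
# to 1 for `np.random.choice` to accept it.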
@curry1 | |||
def make_masked_msa(protein, config, replace_fraction): | |||
"""Create data for BERT on raw MSA.""" | |||
random_aa = np.array([0.05] * 20 + [0., 0.], dtype=np.float32) | |||
categorical_probs = config.uniform_prob * random_aa + config.profile_prob * protein['hhblits_profile'] + \ | |||
config.same_prob * one_hot(22, protein['msa']) | |||
pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] | |||
pad_shapes[-1][1] = 1 | |||
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob | |||
assert mask_prob >= 0. | |||
categorical_probs = np.pad(categorical_probs, pad_shapes, constant_values=(mask_prob,)) | |||
mask_position = np.random.uniform(size=shape_list(protein['msa']), low=0, high=1) < replace_fraction | |||
bert_msa = shaped_categorical(categorical_probs) | |||
bert_msa = np.where(mask_position, bert_msa, protein['msa']) | |||
protein['bert_mask'] = mask_position.astype(np.int32) | |||
protein['true_msa'] = protein['msa'] | |||
protein['msa'] = bert_msa | |||
return protein | |||
@curry1 | |||
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): | |||
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" | |||
weights = np.concatenate([np.ones(21), gap_agreement_weight * np.ones(1), np.zeros(1)], 0) | |||
sample_one_hot = protein['msa_mask'][:, :, None] * one_hot(23, protein['msa']) | |||
num_seq, num_res, _ = shape_list(sample_one_hot) | |||
array_extra_msa_mask = protein['extra_msa_mask'] | |||
if array_extra_msa_mask.any(): | |||
extra_one_hot = protein['extra_msa_mask'][:, :, None] * one_hot(23, protein['extra_msa']) | |||
extra_num_seq, _, _ = shape_list(extra_one_hot) | |||
agreement = np.matmul( | |||
np.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), | |||
np.reshape(sample_one_hot * weights, [num_seq, num_res * 23]).T) | |||
protein['extra_cluster_assignment'] = np.argmax(agreement, axis=1) | |||
else: | |||
protein['extra_cluster_assignment'] = np.array([]) | |||
return protein | |||
@curry1 | |||
def summarize_clusters(protein): | |||
"""Produce profile and deletion_matrix_mean within each cluster.""" | |||
num_seq = shape_list(protein['msa'])[0] | |||
def csum(x): | |||
result = [] | |||
for i in range(num_seq): | |||
result.append(np.sum(x[np.where(protein['extra_cluster_assignment'] == i)], axis=0)) | |||
return np.array(result) | |||
mask = protein['extra_msa_mask'] | |||
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center | |||
    # Accumulate the one-hot counts of the extra MSA into their clusters.
    msa_sum = csum(
        mask[:, :, None] * one_hot(23, protein['extra_msa'].astype(np.int32)))
msa_sum += one_hot(23, protein['msa']) # Original sequence | |||
protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] | |||
del msa_sum | |||
del_sum = csum(mask * protein['extra_deletion_matrix']) | |||
del_sum += protein['deletion_matrix'] # Original sequence | |||
protein['cluster_deletion_mean'] = del_sum / mask_counts | |||
del del_sum | |||
return protein | |||
@curry1 | |||
def crop_extra_msa(protein, max_extra_msa): | |||
"""MSA features are cropped so only `max_extra_msa` sequences are kept.""" | |||
if protein['extra_msa'].any(): | |||
num_seq = protein['extra_msa'].shape[0] | |||
num_sel = np.minimum(max_extra_msa, num_seq) | |||
shuffled = list(range(num_seq)) | |||
np.random.shuffle(shuffled) | |||
select_indices = shuffled[:num_sel] | |||
for k in _MSA_FEATURE_NAMES: | |||
if 'extra_' + k in protein: | |||
protein['extra_' + k] = protein['extra_' + k][select_indices] | |||
return protein | |||
def delete_extra_msa(protein): | |||
for k in _MSA_FEATURE_NAMES: | |||
if 'extra_' + k in protein: | |||
del protein['extra_' + k] | |||
return protein | |||
@curry1 | |||
def make_msa_feat(protein): | |||
"""Create and concatenate MSA features.""" | |||
has_break = np.clip(protein['between_segment_residues'].astype(np.float32), np.array(0), np.array(1)) | |||
aatype_1hot = one_hot(21, protein['aatype']) | |||
target_feat = [np.expand_dims(has_break, axis=-1), aatype_1hot] | |||
msa_1hot = one_hot(23, protein['msa']) | |||
has_deletion = np.clip(protein['deletion_matrix'], np.array(0), np.array(1)) | |||
deletion_value = np.arctan(protein['deletion_matrix'] / 3.) * (2. / np.pi) | |||
msa_feat = [msa_1hot, np.expand_dims(has_deletion, axis=-1), np.expand_dims(deletion_value, axis=-1)] | |||
if 'cluster_profile' in protein: | |||
deletion_mean_value = (np.arctan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)) | |||
msa_feat.extend([protein['cluster_profile'], np.expand_dims(deletion_mean_value, axis=-1)]) | |||
if 'extra_deletion_matrix' in protein: | |||
protein['extra_has_deletion'] = np.clip(protein['extra_deletion_matrix'], np.array(0), np.array(1)) | |||
protein['extra_deletion_value'] = np.arctan(protein['extra_deletion_matrix'] / 3.) * (2. / np.pi) | |||
protein['msa_feat'] = np.concatenate(msa_feat, axis=-1) | |||
protein['target_feat'] = np.concatenate(target_feat, axis=-1) | |||
return protein | |||
@curry1 | |||
def select_feat(protein, feature_list): | |||
return {k: v for k, v in protein.items() if k in feature_list} | |||
@curry1 | |||
def random_crop_to_size(protein, crop_size, max_templates, shape_schema, | |||
subsample_templates=False): | |||
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" | |||
seq_length = protein['seq_length'] | |||
seq_length_int = int(seq_length) | |||
if 'template_mask' in protein: | |||
num_templates = np.array(shape_list(protein['template_mask'])[0], np.int32) | |||
else: | |||
num_templates = np.array(0, np.int32) | |||
num_res_crop_size = np.minimum(seq_length, crop_size) | |||
num_res_crop_size_int = int(num_res_crop_size) | |||
    if subsample_templates:
        templates_crop_start = int(make_random_seed(
            size=(), seed_maker_t=seed_maker(), low=0, high=num_templates + 1))
    else:
        templates_crop_start = 0
num_templates_crop_size = np.minimum(num_templates - templates_crop_start, max_templates) | |||
num_templates_crop_size_int = int(num_templates_crop_size) | |||
num_res_crop_start = int(make_random_seed(size=(), seed_maker_t=seed_maker(), low=0, | |||
high=seq_length_int - num_res_crop_size_int + 1)) | |||
templates_select_indices = np.argsort(make_random_seed(size=[num_templates], seed_maker_t=seed_maker())) | |||
for k, v in protein.items(): | |||
if k not in shape_schema or ('template' not in k and NUM_RES not in shape_schema[k]): | |||
continue | |||
if k.startswith('template') and subsample_templates: | |||
v = v[templates_select_indices] | |||
crop_sizes = [] | |||
crop_starts = [] | |||
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], shape_list(v))): | |||
is_num_res = (dim_size == NUM_RES) | |||
if i == 0 and k.startswith('template'): | |||
crop_size = num_templates_crop_size_int | |||
crop_start = templates_crop_start | |||
else: | |||
crop_start = num_res_crop_start if is_num_res else 0 | |||
crop_size = (num_res_crop_size_int if is_num_res else (-1 if dim is None else dim)) | |||
crop_sizes.append(crop_size) | |||
crop_starts.append(crop_start) | |||
        # Crop every dimension with a single tuple-of-slices index.
        slices = tuple(slice(start, start + size)
                       for start, size in zip(crop_starts, crop_sizes))
        protein[k] = v[slices]
protein['seq_length'] = num_res_crop_size | |||
return protein | |||
@curry1 | |||
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, | |||
num_res, num_templates=0): | |||
"""Guess at the MSA and sequence dimensions to make fixed size.""" | |||
pad_size_map = { | |||
NUM_RES: num_res, | |||
NUM_MSA_SEQ: msa_cluster_size, | |||
NUM_EXTRA_SEQ: extra_msa_size, | |||
NUM_TEMPLATES: num_templates, | |||
} | |||
for k, v in protein.items(): | |||
if k == 'extra_cluster_assignment': | |||
continue | |||
shape = list(v.shape) | |||
schema = shape_schema[k] | |||
assert len(shape) == len(schema), f'Rank mismatch between shape and shape schema for {k}: {shape} vs {schema}' | |||
pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] | |||
padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] | |||
        if padding:
            protein[k] = np.pad(v, padding)
        protein[k] = protein[k].reshape(pad_size)
return protein | |||
@curry1 | |||
def crop_templates(protein, max_templates): | |||
for k, v in protein.items(): | |||
if k.startswith('template_'): | |||
protein[k] = v[:max_templates] | |||
return protein | |||
def correct_msa_restypes(protein):
    """Correct MSA restypes to match the residue_constants ordering."""
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    new_order = np.array(new_order_list, dtype=protein['msa'].dtype)
    protein['msa'] = new_order[protein['msa']]
    return protein
@curry1 | |||
def add_distillation_flag(protein, distillation): | |||
protein['is_distillation'] = np.array(float(distillation), dtype=np.float32) | |||
return protein | |||
def squeeze_features(protein): | |||
"""Remove singleton and repeated dimensions in protein features.""" | |||
protein['aatype'] = np.argmax(protein['aatype'], axis=-1) | |||
for k in ['msa', 'num_alignments', 'seq_length', 'sequence', 'superfamily', 'deletion_matrix', | |||
'resolution', 'between_segment_residues', 'residue_index', 'template_all_atom_masks']: | |||
if k in protein: | |||
final_dim = shape_list(protein[k])[-1] | |||
if isinstance(final_dim, int) and final_dim == 1: | |||
protein[k] = np.squeeze(protein[k], axis=-1) | |||
for k in ['seq_length', 'num_alignments']: | |||
if k in protein: | |||
protein[k] = protein[k][0] # Remove fake sequence dimension | |||
return protein | |||
def cast_64bit_ints(protein): | |||
for k, v in protein.items(): | |||
if v.dtype == np.int64: | |||
protein[k] = v.astype(np.int32) | |||
return protein | |||
def make_seq_mask(protein): | |||
protein['seq_mask'] = np.ones(shape_list(protein['aatype']), dtype=np.float32) | |||
return protein | |||
def make_msa_mask(protein): | |||
"""Mask features are all ones, but will later be zero-padded.""" | |||
protein['msa_mask'] = np.ones(shape_list(protein['msa']), dtype=np.float32) | |||
protein['msa_row_mask'] = np.ones(shape_list(protein['msa'])[0], dtype=np.float32) | |||
return protein | |||
def make_hhblits_profile(protein): | |||
"""Compute the HHblits MSA profile if not already present.""" | |||
if 'hhblits_profile' in protein: | |||
return protein | |||
protein['hhblits_profile'] = np.mean(one_hot(22, protein['msa']), axis=0) | |||
return protein | |||
def make_random_crop_to_size_seed(protein): | |||
"""Random seed for cropping residues and templates.""" | |||
protein['random_crop_to_size_seed'] = np.array(make_random_seed([2], seed_maker_t=seed_maker()), np.int32) | |||
return protein | |||
def fix_templates_aatype(protein): | |||
"""Fixes aatype encoding of templates.""" | |||
protein['template_aatype'] = np.argmax(protein['template_aatype'], axis=-1).astype(np.int32) | |||
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE | |||
new_order = np.array(new_order_list, np.int32) | |||
protein['template_aatype'] = new_order[protein['template_aatype']] | |||
return protein | |||
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): | |||
"""Create pseudo beta features.""" | |||
is_gly = np.equal(aatype, residue_constants.restype_order['G']) | |||
ca_idx = residue_constants.atom_order['CA'] | |||
cb_idx = residue_constants.atom_order['CB'] | |||
pseudo_beta = np.where( | |||
np.tile(is_gly[..., None].astype("int32"), [1,] * len(is_gly.shape) + [3,]).astype("bool"), | |||
all_atom_positions[..., ca_idx, :], | |||
all_atom_positions[..., cb_idx, :]) | |||
if all_atom_masks is not None: | |||
pseudo_beta_mask = np.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) | |||
pseudo_beta_mask = pseudo_beta_mask.astype(np.float32) | |||
return pseudo_beta, pseudo_beta_mask | |||
return pseudo_beta | |||
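# Example (sketch): `pseudo_beta_fn` selects the CA coordinates for glycine
# (which has no CB atom) and the CB coordinates for every other residue, so
# for aatype [G, A] the returned positions of shape [2, 3] come from the CA
# row of residue 0 and the CB row of residue 1.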
@curry1 | |||
def make_pseudo_beta(protein, prefix=''): | |||
"""Create pseudo-beta (alpha for glycine) position and mask.""" | |||
assert prefix in ['', 'template_'] | |||
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = ( | |||
pseudo_beta_fn( | |||
protein['template_aatype' if prefix else 'all_atom_aatype'], | |||
protein[prefix + 'all_atom_positions'], | |||
protein['template_all_atom_masks' if prefix else 'all_atom_mask'])) | |||
return protein | |||
def make_atom14_masks(protein): | |||
"""Construct denser atom positions (14 dimensions instead of 37).""" | |||
restype_atom14_to_atom37 = [] | |||
restype_atom37_to_atom14 = [] | |||
restype_atom14_mask = [] | |||
for rt in residue_constants.restypes: | |||
atom_names = residue_constants.restype_name_to_atom14_names[residue_constants.restype_1to3[rt]] | |||
restype_atom14_to_atom37.append([(residue_constants.atom_order[name] if name else 0) for name in atom_names]) | |||
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} | |||
restype_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) | |||
for name in residue_constants.atom_types]) | |||
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) | |||
restype_atom14_to_atom37.append([0] * 14) | |||
restype_atom37_to_atom14.append([0] * 37) | |||
restype_atom14_mask.append([0.] * 14) | |||
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, np.int32) | |||
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, np.int32) | |||
restype_atom14_mask = np.array(restype_atom14_mask, np.float32) | |||
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']] | |||
residx_atom14_mask = restype_atom14_mask[protein['aatype']] | |||
protein['atom14_atom_exists'] = residx_atom14_mask | |||
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37 | |||
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']] | |||
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14 | |||
restype_atom37_mask = np.zeros([21, 37], np.float32) | |||
for restype, restype_letter in enumerate(residue_constants.restypes): | |||
restype_name = residue_constants.restype_1to3[restype_letter] | |||
atom_names = residue_constants.residue_atoms[restype_name] | |||
for atom_name in atom_names: | |||
atom_type = residue_constants.atom_order[atom_name] | |||
restype_atom37_mask[restype, atom_type] = 1 | |||
residx_atom37_mask = restype_atom37_mask[protein['aatype']] | |||
protein['atom37_atom_exists'] = residx_atom37_mask | |||
return protein | |||
def shape_list(x):
    """Return the dimensions of an array as a plain Python list."""
    return list(np.asarray(x).shape)
@@ -0,0 +1,294 @@
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""feature extraction""" | |||
import copy | |||
import numpy as np | |||
from commons import residue_constants | |||
from data.feature import data_transforms | |||
NUM_RES = "num residues placeholder" | |||
NUM_SEQ = "length msa placeholder" | |||
NUM_TEMPLATES = "num templates placeholder" | |||
FEATURES = { | |||
"aatype": (np.float32, [NUM_RES, 21]), | |||
"between_segment_residues": (np.int64, [NUM_RES, 1]), | |||
"deletion_matrix": (np.float32, [NUM_SEQ, NUM_RES, 1]), | |||
"msa": (np.int64, [NUM_SEQ, NUM_RES, 1]), | |||
"num_alignments": (np.int64, [NUM_RES, 1]), | |||
"residue_index": (np.int64, [NUM_RES, 1]), | |||
"seq_length": (np.int64, [NUM_RES, 1]), | |||
"all_atom_positions": (np.float32, [NUM_RES, residue_constants.atom_type_num, 3]), | |||
"all_atom_mask": (np.int64, [NUM_RES, residue_constants.atom_type_num]), | |||
"resolution": (np.float32, [1]), | |||
"template_domain_names": (str, [NUM_TEMPLATES]), | |||
"template_sum_probs": (np.float32, [NUM_TEMPLATES, 1]), | |||
"template_aatype": (np.float32, [NUM_TEMPLATES, NUM_RES, 22]), | |||
"template_all_atom_positions": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3]), | |||
"template_all_atom_masks": (np.float32, [NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1]), | |||
} | |||
def nonensembled_map_fns(data_config): | |||
"""Input pipeline functions which are not ensembled.""" | |||
common_cfg = data_config.common | |||
map_fns = [ | |||
data_transforms.correct_msa_restypes, | |||
data_transforms.add_distillation_flag(False), | |||
data_transforms.cast_64bit_ints, | |||
data_transforms.squeeze_features, | |||
data_transforms.randomly_replace_msa_with_unknown(0.0), | |||
data_transforms.make_seq_mask, | |||
data_transforms.make_msa_mask, | |||
data_transforms.make_hhblits_profile, | |||
data_transforms.make_random_crop_to_size_seed, | |||
] | |||
if common_cfg.use_templates: | |||
map_fns.extend([data_transforms.fix_templates_aatype, data_transforms.make_pseudo_beta('template_')]) | |||
map_fns.extend([data_transforms.make_atom14_masks,]) | |||
return map_fns | |||
def ensembled_map_fns(data_config): | |||
"""Input pipeline functions that can be ensembled and averaged.""" | |||
common_cfg = data_config.common | |||
eval_cfg = data_config.eval | |||
map_fns = [] | |||
if common_cfg.reduce_msa_clusters_by_max_templates: | |||
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates | |||
else: | |||
pad_msa_clusters = eval_cfg.max_msa_clusters | |||
max_msa_clusters = pad_msa_clusters | |||
max_extra_msa = common_cfg.max_extra_msa | |||
map_fns.append(data_transforms.sample_msa(max_msa_clusters, keep_extra=True)) | |||
if 'masked_msa' in common_cfg: | |||
map_fns.append(data_transforms.make_masked_msa(common_cfg.masked_msa, eval_cfg.masked_msa_replace_fraction)) | |||
if common_cfg.msa_cluster_features: | |||
map_fns.append(data_transforms.nearest_neighbor_clusters()) | |||
map_fns.append(data_transforms.summarize_clusters()) | |||
if max_extra_msa: | |||
map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) | |||
else: | |||
map_fns.append(data_transforms.delete_extra_msa) | |||
map_fns.append(data_transforms.make_msa_feat()) | |||
crop_feats = dict(eval_cfg.feat) | |||
if eval_cfg.fixed_size: | |||
map_fns.append(data_transforms.select_feat(list(crop_feats))) | |||
map_fns.append(data_transforms.random_crop_to_size( | |||
eval_cfg.crop_size, | |||
eval_cfg.max_templates, | |||
crop_feats, | |||
eval_cfg.subsample_templates)) | |||
map_fns.append(data_transforms.make_fixed_size( | |||
crop_feats, | |||
pad_msa_clusters, | |||
common_cfg.max_extra_msa, | |||
eval_cfg.crop_size, | |||
eval_cfg.max_templates)) | |||
else: | |||
map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) | |||
return map_fns | |||
def process_arrays_from_config(arrays, data_config): | |||
"""Apply filters and maps to an existing dataset, based on the config.""" | |||
def wrap_ensemble_fn(data, i): | |||
"""Function to be mapped over the ensemble dimension.""" | |||
d = data.copy() | |||
fns = ensembled_map_fns(data_config) | |||
fn = data_transforms.compose(fns) | |||
d['ensemble_index'] = i | |||
return fn(d) | |||
eval_cfg = data_config.eval | |||
arrays = data_transforms.compose(nonensembled_map_fns(data_config))(arrays) | |||
arrays_0 = wrap_ensemble_fn(arrays, np.array(0, np.int32)) | |||
num_ensemble = eval_cfg.num_ensemble | |||
if data_config.common.resample_msa_in_recycling: | |||
num_ensemble *= data_config.common.num_recycle + 1 | |||
result_array = {x: () for x in arrays_0.keys()} | |||
if num_ensemble > 1: | |||
for i in range(num_ensemble): | |||
arrays_t = wrap_ensemble_fn(arrays, np.array(i, np.int32)) | |||
for key in arrays_0.keys(): | |||
result_array[key] += (arrays_t[key][None],) | |||
for key in arrays_0.keys(): | |||
result_array[key] = np.concatenate(result_array[key], axis=0) | |||
else: | |||
result_array = {key: arrays_0[key][None] for key in arrays_0.keys()} | |||
return result_array | |||
def feature_shape(feature_name, | |||
num_residues, | |||
msa_length, | |||
num_templates, | |||
features=None): | |||
"""Get the shape for the given feature name.""" | |||
features = features or FEATURES | |||
if feature_name.endswith("_unnormalized"): | |||
feature_name = feature_name[:-13] | |||
unused_dtype, raw_sizes = features[feature_name] | |||
replacements = {NUM_RES: num_residues, | |||
NUM_SEQ: msa_length} | |||
if num_templates is not None: | |||
replacements[NUM_TEMPLATES] = num_templates | |||
sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] | |||
for dimension in sizes: | |||
if isinstance(dimension, str): | |||
raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( | |||
feature_name, raw_sizes, replacements)) | |||
size_r = [int(x) for x in sizes] | |||
return size_r | |||
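# Example (illustrative): with FEATURES['msa'] declared as
# (np.int64, [NUM_SEQ, NUM_RES, 1]),
#   feature_shape('msa', num_residues=100, msa_length=64, num_templates=0)
# resolves the placeholders and returns [64, 100, 1].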
def parse_reshape_logic(parsed_features, features, num_template, key=None): | |||
"""Transforms parsed serial features to the correct shape.""" | |||
num_residues = np.reshape(parsed_features['seq_length'].astype(np.int32), (-1,))[0] | |||
if "num_alignments" in parsed_features: | |||
num_msa = np.reshape(parsed_features["num_alignments"].astype(np.int32), (-1,))[0] | |||
else: | |||
num_msa = 0 | |||
if key is not None and "key" in features: | |||
parsed_features["key"] = [key] # Expand dims from () to (1,). | |||
for k, v in parsed_features.items(): | |||
new_shape = feature_shape( | |||
feature_name=k, | |||
num_residues=num_residues, | |||
msa_length=num_msa, | |||
num_templates=num_template, | |||
features=features) | |||
new_shape_size = 1 | |||
for dim in new_shape: | |||
new_shape_size *= dim | |||
        if np.size(v) != new_shape_size:
            raise ValueError("The size of feature {} ({}) cannot be reshaped "
                             "into {}".format(k, np.size(v), new_shape))
        if "template" not in k:
            if np.size(v) <= 0:
                raise ValueError("The feature {} is empty.".format(k))
parsed_features[k] = np.reshape(v, new_shape) | |||
return parsed_features | |||
def _make_features_metadata(feature_names): | |||
"""Makes a feature name to type and shape mapping from a list of names.""" | |||
required_features = ["sequence", "domain_name", "template_domain_names"] | |||
feature_names = list(set(feature_names) - set(required_features)) | |||
features_metadata = {name: FEATURES[name] for name in feature_names} | |||
return features_metadata | |||
def np_to_array_dict(np_example, features): | |||
"""Creates dict of arrays.""" | |||
features_metadata = _make_features_metadata(features) | |||
array_dict = {k: v for k, v in np_example.items() if k in features_metadata} | |||
if "template_domain_names" in np_example: | |||
num_template = len(np_example["template_domain_names"]) | |||
else: | |||
num_template = 0 | |||
array_dict = parse_reshape_logic(array_dict, features_metadata, num_template) | |||
array_dict['template_mask'] = np.ones([num_template], np.float32) | |||
return array_dict | |||
def make_data_config(config, num_res): | |||
"""Makes a data config for the input pipeline.""" | |||
cfg = copy.deepcopy(config.data) | |||
feature_names = cfg.common.unsupervised_features | |||
if cfg.common.use_templates: | |||
feature_names += cfg.common.template_features | |||
with cfg.unlocked(): | |||
cfg.eval.crop_size = num_res | |||
return cfg, feature_names | |||
def custom_padding(config, arrays, dims): | |||
"""Pad array to fixed size.""" | |||
step_size = config.seq_length | |||
res_length = arrays[0].shape[dims[0]] | |||
padding_size = step_size - res_length | |||
for i, arr in enumerate(arrays): | |||
if dims[i] == -1: | |||
continue | |||
extra_array_shape = list(arr.shape) | |||
extra_array_shape[dims[i]] = padding_size | |||
extra_array = np.zeros(extra_array_shape, dtype=arr.dtype) | |||
arrays[i] = np.concatenate((arr, extra_array), axis=dims[i]) | |||
return arrays | |||
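# Example (sketch): with config.seq_length == 256 and a [1, 180, 22] target
# feature padded along dims[i] == 1, `custom_padding` appends 76 zero rows so
# the array becomes [1, 256, 22]; entries whose dims[i] == -1 (e.g. template
# masks) are left untouched. This assumes the raw residue length never exceeds
# the configured padded length.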
def process_features(raw_features, config, global_config): | |||
"""Preprocesses NumPy feature dict using pipeline.""" | |||
num_res = int(raw_features['seq_length'][0]) | |||
cfg, feature_names = make_data_config(config, num_res=num_res) | |||
if 'deletion_matrix_int' in raw_features: | |||
raw_features['deletion_matrix'] = (raw_features.pop('deletion_matrix_int').astype(np.float32)) | |||
array_dict = np_to_array_dict(np_example=raw_features, features=feature_names) | |||
features = process_arrays_from_config(array_dict, cfg) | |||
features = {k: v for k, v in features.items() if v.dtype != 'O'} | |||
extra_msa_length = global_config.extra_msa_length | |||
ori_res_length = features["target_feat"].shape[1] | |||
aatype = features["aatype"] | |||
residue_index = features["residue_index"] | |||
for key in ["extra_msa", "extra_has_deletion", "extra_deletion_value", "extra_msa_mask"]: | |||
features[key] = features[key][:, :extra_msa_length] | |||
input_keys = ['target_feat', 'msa_feat', 'msa_mask', 'seq_mask', 'aatype', 'template_aatype', | |||
'template_all_atom_masks', 'template_all_atom_positions', 'template_mask', | |||
'template_pseudo_beta_mask', 'template_pseudo_beta', 'template_sum_probs', | |||
'extra_msa', 'extra_has_deletion', 'extra_deletion_value', 'extra_msa_mask', | |||
'residx_atom37_to_atom14', 'atom37_atom_exists', 'residue_index'] | |||
arrays = [features[key] for key in input_keys] | |||
dims = [1, 2, 2, 1, 1, 2, 2, 2, -1, 2, 2, -1, 2, 2, 2, 2, 1, 1, 1] | |||
arrays = custom_padding(global_config, arrays, dims) | |||
    # Downcast float arrays to float16 for the model input.
    arrays = [array.astype(np.float16) if array.dtype in (np.float64, np.float32)
              else array for array in arrays]
return arrays, aatype, residue_index, ori_res_length |
@@ -0,0 +1,205 @@
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Data processing."""
import os | |||
import hashlib | |||
import re | |||
import numpy as np | |||
from commons import residue_constants | |||
from data.tools.parsers import parse_fasta, parse_hhr, parse_a3m | |||
from data.tools.templates import TemplateHitFeaturizer | |||
from data.tools.data_tools import HHSearch | |||
def get_hash(x): | |||
return hashlib.sha1(x.encode()).hexdigest() | |||
def run_mmseqs2(x, path, use_env=False): | |||
'''run mmseqs2''' | |||
a3m_files = [f"{path}/uniref.a3m"] | |||
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
# gather a3m lines | |||
a3m_lines = {} | |||
    for a3m_file in a3m_files:
        update_m, m = True, None
        with open(a3m_file, "r") as f:
            for line in f:
                if line:
                    if "\x00" in line:
                        line = line.replace("\x00", "")
                        update_m = True
                    if line.startswith(">") and update_m:
                        m = int(line[2:6].rstrip())
                        update_m = False
                    if m not in a3m_lines:
                        a3m_lines[m] = []
                    a3m_lines[m].append(line)
# return results | |||
a3m_lines = ["".join(a3m_lines[key]) for key in a3m_lines] | |||
if isinstance(x, str): | |||
return a3m_lines[0] | |||
return a3m_lines | |||
def make_sequence_features( | |||
sequence: str, description: str, num_res: int): | |||
"""Constructs a feature dict of sequence features.""" | |||
features = {'aatype': residue_constants.sequence_to_onehot(sequence=sequence, | |||
mapping=residue_constants.restype_order_with_x, | |||
map_unknown_to_x=True), | |||
'between_segment_residues': np.zeros((num_res,), dtype=np.int32), | |||
'domain_name': np.array([description.encode('utf-8')], dtype=np.object_), | |||
'residue_index': np.array(range(num_res), dtype=np.int32), | |||
'seq_length': np.array([num_res] * num_res, dtype=np.int32), | |||
'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_)} | |||
return features | |||
def make_msa_features( | |||
msas, | |||
deletion_matrices): | |||
"""Constructs a feature dict of MSA features.""" | |||
if not msas: | |||
raise ValueError('At least one MSA must be provided.') | |||
int_msa = [] | |||
deletion_matrix = [] | |||
seen_sequences = set() | |||
for msa_index, msa in enumerate(msas): | |||
if not msa: | |||
raise ValueError(f'MSA {msa_index} must contain at least one sequence.') | |||
for sequence_index, sequence in enumerate(msa): | |||
if sequence in seen_sequences: | |||
continue | |||
seen_sequences.add(sequence) | |||
int_msa.append( | |||
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) | |||
deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) | |||
num_res = len(msas[0][0]) | |||
num_alignments = len(int_msa) | |||
features = {'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32), | |||
'msa': np.array(int_msa, dtype=np.int32), | |||
'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32)} | |||
return features | |||
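# Example (illustrative): duplicate sequences across the provided MSAs are
# dropped, so given matching deletion matrices,
#   make_msa_features(msas=(['MKV', 'MKV', 'MRV'],), ...)
# produces an 'msa' feature of shape [2, 3] and 'num_alignments' == [2, 2, 2].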
class DataPipeline: | |||
"""Runs the alignment tools and assembles the input features.""" | |||
def __init__(self, | |||
hhsearch_binary_path: str, | |||
pdb70_database_path: str, | |||
template_featurizer: TemplateHitFeaturizer, | |||
result_path, | |||
use_env=False): | |||
"""Constructs a feature dict for a given FASTA file.""" | |||
self.hhsearch_pdb70_runner = HHSearch( | |||
binary_path=hhsearch_binary_path, | |||
databases=[pdb70_database_path]) | |||
self.template_featurizer = template_featurizer | |||
self.result_path = result_path | |||
self.use_env = use_env | |||
def process(self, input_fasta_path): | |||
"""Runs alignment tools on the input sequence and creates features.""" | |||
with open(input_fasta_path) as f: | |||
input_fasta_str = f.read() | |||
input_seqs, input_descs = parse_fasta(input_fasta_str) | |||
if len(input_seqs) != 1: | |||
            raise ValueError(
                f'Exactly one input sequence expected in {input_fasta_path}, '
                f'got {len(input_seqs)}.')
input_sequence = input_seqs[0] | |||
input_description = input_descs[0] | |||
num_res = len(input_sequence) | |||
        # mmseqs2: sanitize the query sequence before the MSA search
sequence = input_sequence | |||
sequence = re.sub("[^A-Z:/]", "", sequence.upper()) | |||
sequence = re.sub(":+", ":", sequence) | |||
sequence = re.sub("/+", "/", sequence) | |||
sequence = re.sub("^[:/]+", "", sequence) | |||
sequence = re.sub("[:/]+$", "", sequence) | |||
ori_sequence = sequence | |||
seqs = ori_sequence.replace("/", "").split(":") | |||
a3m_lines = run_mmseqs2(seqs, path=self.result_path, use_env=self.use_env) | |||
hhsearch_result = self.hhsearch_pdb70_runner.query(a3m_lines[0]) | |||
hhsearch_hits = parse_hhr(hhsearch_result) | |||
msas, deletion_matrices = parse_a3m(a3m_lines[0]) | |||
templates_result = self.template_featurizer.get_templates( | |||
query_sequence=input_sequence, | |||
query_pdb_code=None, | |||
query_release_date=None, | |||
hhr_hits=hhsearch_hits) | |||
sequence_features = make_sequence_features( | |||
sequence=input_sequence, | |||
description=input_description, | |||
num_res=num_res) | |||
msa_features = make_msa_features( | |||
msas=(msas,), | |||
            deletion_matrices=(deletion_matrices,))
return {**sequence_features, **msa_features, **templates_result.features} | |||
def data_process(seq_name, args): | |||
"""data_process""" | |||
fasta_path = os.path.join(args.input_fasta_path, seq_name + '.fasta') | |||
    # Note: a leading '/' in the second argument would make os.path.join
    # discard msa_result_path entirely.
    result_path = os.path.join(args.msa_result_path, "result_" + str(seq_name))
    if args.database_envdb_dir:
        use_env = True
        command = (f'sh ./data/tools/msa_search.sh mmseqs {fasta_path} {result_path} '
                   f'{args.database_dir} "" {args.database_envdb_dir} "1" "0" "1"')
    else:
        use_env = False
        command = (f'sh ./data/tools/msa_search.sh mmseqs {fasta_path} {result_path} '
                   f'{args.database_dir} "" "" "0" "0" "1"')
    print('start mmseqs2 MSA')
    print('command: ', command)
    ret = os.system(command)
    if ret != 0:
        raise RuntimeError(f'mmseqs2 MSA search failed with exit status {ret}')
    print('mmseqs2 MSA successful')
print('use_env: ', use_env) | |||
hhsearch_binary_path = args.hhsearch_binary_path | |||
pdb70_database_path = args.pdb70_database_path | |||
template_mmcif_dir = args.template_mmcif_dir | |||
max_template_date = args.max_template_date | |||
kalign_binary_path = args.kalign_binary_path | |||
obsolete_pdbs_path = args.obsolete_pdbs_path | |||
template_featurizer = TemplateHitFeaturizer( | |||
mmcif_dir=template_mmcif_dir, | |||
max_template_date=max_template_date, | |||
max_hits=20, | |||
kalign_binary_path=kalign_binary_path, | |||
release_dates_path=None, | |||
obsolete_pdbs_path=obsolete_pdbs_path) | |||
data_pipeline = DataPipeline( | |||
hhsearch_binary_path=hhsearch_binary_path, | |||
pdb70_database_path=pdb70_database_path, | |||
template_featurizer=template_featurizer, | |||
result_path=result_path, | |||
use_env=use_env) | |||
feature_dict = data_pipeline.process(fasta_path) | |||
return feature_dict |
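# Illustrative sketch of calling data_process programmatically. The namespace
# mirrors the command-line flags declared in run.py; every path below is
# hypothetical.
def _example_data_process():
    import argparse
    args = argparse.Namespace(
        input_fasta_path='./fasta',               # contains test.fasta
        msa_result_path='./msa',
        database_dir='/data/uniref30_2103_db',
        database_envdb_dir='',                    # empty -> use_env=False branch
        hhsearch_binary_path='/usr/bin/hhsearch',
        pdb70_database_path='/data/pdb70/pdb70',
        template_mmcif_dir='/data/pdb_mmcif',
        max_template_date='2021-11-01',
        kalign_binary_path='/usr/bin/kalign',
        obsolete_pdbs_path='/data/obsolete.dat')
    return data_process('test', args)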
@@ -0,0 +1,428 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
'''data tools''' | |||
import glob | |||
import os | |||
import subprocess | |||
import contextlib | |||
import shutil | |||
import tempfile | |||
import time | |||
from typing import Any, Mapping, Optional, Sequence | |||
from absl import logging | |||
_HHBLITS_DEFAULT_P = 20 | |||
_HHBLITS_DEFAULT_Z = 500 | |||
def _to_a3m(sequences: Sequence[str]) -> str: | |||
"""Converts sequences to an a3m file.""" | |||
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] | |||
a3m = [] | |||
for sequence, name in zip(sequences, names): | |||
a3m.append(u'>' + name + u'\n') | |||
a3m.append(sequence + u'\n') | |||
return ''.join(a3m) | |||
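# _to_a3m simply interleaves synthetic headers with the raw sequences:
def _example_to_a3m():
    return _to_a3m(['MKTAYI', 'MKTAYL'])  # '>sequence 1\nMKTAYI\n>sequence 2\nMKTAYL\n'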
class Kalign: | |||
"""Python wrapper of the Kalign binary.""" | |||
def __init__(self, *, binary_path: str): | |||
"""Initializes the Python Kalign wrapper. | |||
Args: | |||
binary_path: The path to the Kalign binary. | |||
""" | |||
self.binary_path = binary_path | |||
def align(self, sequences: Sequence[str]) -> str: | |||
"""Aligns the sequences and returns the alignment in A3M string. | |||
Args: | |||
sequences: A list of query sequence strings. The sequences have to be at | |||
least 6 residues long (Kalign requires this). Note that the order in | |||
which you give the sequences might alter the output slightly as | |||
different alignment tree might get constructed. | |||
Returns: | |||
A string with the alignment in a3m format. | |||
Raises: | |||
RuntimeError: If Kalign fails. | |||
ValueError: If any of the sequences is less than 6 residues long. | |||
""" | |||
logging.info('Aligning %d sequences', len(sequences)) | |||
for s in sequences: | |||
if len(s) < 6: | |||
raise ValueError('Kalign requires all sequences to be at least 6 ' | |||
'residues long. Got %s (%d residues).' % (s, len(s))) | |||
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: | |||
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') | |||
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||
with open(input_fasta_path, 'w') as f: | |||
f.write(_to_a3m(sequences)) | |||
cmd = [self.binary_path, '-i', input_fasta_path, '-o', output_a3m_path, '-format', 'fasta',] | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with timing('Kalign query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', stdout.decode('utf-8'), stderr.decode('utf-8')) | |||
if retcode: | |||
raise RuntimeError( | |||
'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||
with open(output_a3m_path) as f: | |||
a3m = f.read() | |||
return a3m | |||
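# A minimal sketch of driving the Kalign wrapper; the binary location is an
# assumption and must exist on the host. Sequences must be >= 6 residues.
def _example_kalign_align():
    kalign = Kalign(binary_path='/usr/bin/kalign')
    return kalign.align(['MKTAYIAKQR', 'MKTAYIAKQK'])  # a3m-formatted alignment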
@contextlib.contextmanager | |||
def tmpdir_manager(base_dir: Optional[str] = None): | |||
"""Context manager that deletes a temporary directory on exit.""" | |||
tmpdir = tempfile.mkdtemp(dir=base_dir) | |||
try: | |||
yield tmpdir | |||
finally: | |||
shutil.rmtree(tmpdir, ignore_errors=True) | |||
@contextlib.contextmanager | |||
def timing(msg: str): | |||
logging.info('Started %s', msg) | |||
tic = time.time() | |||
yield | |||
toc = time.time() | |||
logging.info('Finished %s in %.3f seconds', msg, toc - tic) | |||
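# The timing helper just logs wall-clock duration around any block:
def _example_timing():
    with timing('toy workload'):
        sum(range(1_000_000))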
class HHBlits: | |||
"""Python wrapper of the HHblits binary.""" | |||
def __init__(self, | |||
*, | |||
binary_path: str, | |||
databases: Sequence[str], | |||
n_cpu: int = 4, | |||
n_iter: int = 3, | |||
e_value: float = 0.001, | |||
maxseq: int = 1_000_000, | |||
realign_max: int = 100_000, | |||
maxfilt: int = 100_000, | |||
min_prefilter_hits: int = 1000, | |||
all_seqs: bool = False, | |||
alt: Optional[int] = None, | |||
p: int = _HHBLITS_DEFAULT_P, | |||
z: int = _HHBLITS_DEFAULT_Z): | |||
"""Initializes the Python HHblits wrapper. | |||
Args: | |||
binary_path: The path to the HHblits executable. | |||
databases: A sequence of HHblits database paths. This should be the | |||
common prefix for the database files (i.e. up to but not including | |||
_hhm.ffindex etc.) | |||
n_cpu: The number of CPUs to give HHblits. | |||
n_iter: The number of HHblits iterations. | |||
e_value: The E-value, see HHblits docs for more details. | |||
maxseq: The maximum number of rows in an input alignment. Note that this | |||
parameter is only supported in HHBlits version 3.1 and higher. | |||
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. | |||
maxfilt: Max number of hits allowed to pass the 2nd prefilter. | |||
HHblits default: 20000. | |||
min_prefilter_hits: Min number of hits to pass prefilter. | |||
HHblits default: 100. | |||
all_seqs: Return all sequences in the MSA / Do not filter the result MSA. | |||
HHblits default: False. | |||
alt: Show up to this many alternative alignments. | |||
p: Minimum Prob for a hit to be included in the output hhr file. | |||
HHblits default: 20. | |||
z: Hard cap on number of hits reported in the hhr file. | |||
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. | |||
Raises: | |||
RuntimeError: If HHblits binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.databases = databases | |||
for database_path in self.databases: | |||
if not glob.glob(database_path + '_*'): | |||
logging.error('Could not find HHBlits database %s', database_path) | |||
raise ValueError(f'Could not find HHBlits database {database_path}') | |||
self.n_cpu = n_cpu | |||
self.n_iter = n_iter | |||
self.e_value = e_value | |||
self.maxseq = maxseq | |||
self.realign_max = realign_max | |||
self.maxfilt = maxfilt | |||
self.min_prefilter_hits = min_prefilter_hits | |||
self.all_seqs = all_seqs | |||
self.alt = alt | |||
self.p = p | |||
self.z = z | |||
def query(self, input_fasta_path: str) -> Mapping[str, Any]: | |||
"""Queries the database using HHblits.""" | |||
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: | |||
a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||
db_cmd = [] | |||
for db_path in self.databases: | |||
db_cmd.append('-d') | |||
db_cmd.append(db_path) | |||
cmd = [ | |||
self.binary_path, | |||
'-i', input_fasta_path, | |||
'-cpu', str(self.n_cpu), | |||
'-oa3m', a3m_path, | |||
'-o', '/dev/null', | |||
'-n', str(self.n_iter), | |||
'-e', str(self.e_value), | |||
'-maxseq', str(self.maxseq), | |||
'-realign_max', str(self.realign_max), | |||
'-maxfilt', str(self.maxfilt), | |||
'-min_prefilter_hits', str(self.min_prefilter_hits)] | |||
if self.all_seqs: | |||
cmd += ['-all'] | |||
if self.alt: | |||
cmd += ['-alt', str(self.alt)] | |||
if self.p != _HHBLITS_DEFAULT_P: | |||
cmd += ['-p', str(self.p)] | |||
if self.z != _HHBLITS_DEFAULT_Z: | |||
cmd += ['-Z', str(self.z)] | |||
cmd += db_cmd | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with timing('HHblits query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
# Logs have a 15k character limit, so log HHblits error line by | |||
# line. | |||
logging.error('HHblits failed. HHblits stderr begin:') | |||
for error_line in stderr.decode('utf-8').splitlines(): | |||
if error_line.strip(): | |||
logging.error(error_line.strip()) | |||
logging.error('HHblits stderr end') | |||
raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( | |||
stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) | |||
with open(a3m_path) as f: | |||
a3m = f.read() | |||
raw_output = dict( | |||
a3m=a3m, | |||
output=stdout, | |||
stderr=stderr, | |||
n_iter=self.n_iter, | |||
e_value=self.e_value) | |||
return raw_output | |||
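# Sketch of an HHBlits search; the binary path and database prefix are
# assumptions (the prefix must match files such as <prefix>_hhm.ffindex).
def _example_hhblits_query():
    hhblits = HHBlits(binary_path='/usr/bin/hhblits',
                      databases=['/data/uniclust30/uniclust30_2018_08'])
    return hhblits.query('/tmp/query.fasta')['a3m']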
class HHSearch: | |||
"""Python wrapper of the HHsearch binary.""" | |||
def __init__(self, | |||
*, | |||
binary_path: str, | |||
databases: Sequence[str], | |||
maxseq: int = 1_000_000): | |||
"""Initializes the Python HHsearch wrapper. | |||
Args: | |||
binary_path: The path to the HHsearch executable. | |||
databases: A sequence of HHsearch database paths. This should be the | |||
common prefix for the database files (i.e. up to but not including | |||
_hhm.ffindex etc.) | |||
maxseq: The maximum number of rows in an input alignment. Note that this | |||
parameter is only supported in HHBlits version 3.1 and higher. | |||
Raises: | |||
RuntimeError: If HHsearch binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.databases = databases | |||
self.maxseq = maxseq | |||
for database_path in self.databases: | |||
if not glob.glob(database_path + '_*'): | |||
logging.error( | |||
'Could not find HHsearch database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find HHsearch database {database_path}') | |||
def query(self, a3m: str) -> str: | |||
"""Queries the database using HHsearch using a given a3m.""" | |||
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: | |||
input_path = os.path.join(query_tmp_dir, 'query.a3m') | |||
hhr_path = os.path.join(query_tmp_dir, 'output.hhr') | |||
with open(input_path, 'w') as f: | |||
f.write(a3m) | |||
db_cmd = [] | |||
for db_path in self.databases: | |||
db_cmd.append('-d') | |||
db_cmd.append(db_path) | |||
cmd = [self.binary_path, | |||
'-i', input_path, | |||
'-o', hhr_path, | |||
'-maxseq', str(self.maxseq), | |||
'-cpu', '8',] + db_cmd | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with timing('HHsearch query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
# Stderr is truncated to prevent proto size errors in Beam. | |||
raise RuntimeError('HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( | |||
stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) | |||
with open(hhr_path) as f: | |||
hhr = f.read() | |||
return hhr | |||
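# Sketch of a template search against pdb70, mirroring how DataPipeline uses
# this wrapper in data_process.py; paths are hypothetical.
def _example_hhsearch_query(a3m: str) -> str:
    hhsearch = HHSearch(binary_path='/usr/bin/hhsearch',
                        databases=['/data/pdb70/pdb70'])
    return hhsearch.query(a3m)  # raw .hhr text, see parse_hhr in parsers.py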
class Jackhmmer: | |||
"""Python wrapper of the Jackhmmer binary.""" | |||
def __init__(self, | |||
*, | |||
binary_path: str, | |||
database_path: str, | |||
n_cpu: int = 8, | |||
n_iter: int = 1, | |||
e_value: float = 0.0001, | |||
z_value: Optional[int] = None, | |||
get_tblout: bool = False, | |||
filter_f1: float = 0.0005, | |||
filter_f2: float = 0.00005, | |||
filter_f3: float = 0.0000005, | |||
incdom_e: Optional[float] = None, | |||
dom_e: Optional[float] = None): | |||
"""Initializes the Python Jackhmmer wrapper. | |||
Args: | |||
binary_path: The path to the jackhmmer executable. | |||
database_path: The path to the jackhmmer database (FASTA format). | |||
n_cpu: The number of CPUs to give Jackhmmer. | |||
n_iter: The number of Jackhmmer iterations. | |||
e_value: The E-value, see Jackhmmer docs for more details. | |||
z_value: The Z-value, see Jackhmmer docs for more details. | |||
get_tblout: Whether to save tblout string. | |||
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. | |||
filter_f2: Viterbi pre-filter, set to >1.0 to turn off. | |||
filter_f3: Forward pre-filter, set to >1.0 to turn off. | |||
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next | |||
round. | |||
dom_e: Domain e-value criteria for inclusion in tblout. | |||
""" | |||
self.binary_path = binary_path | |||
self.database_path = database_path | |||
if not os.path.exists(self.database_path): | |||
logging.error( | |||
'Could not find Jackhmmer database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find Jackhmmer database {database_path}') | |||
self.n_cpu = n_cpu | |||
self.n_iter = n_iter | |||
self.e_value = e_value | |||
self.z_value = z_value | |||
self.filter_f1 = filter_f1 | |||
self.filter_f2 = filter_f2 | |||
self.filter_f3 = filter_f3 | |||
self.incdom_e = incdom_e | |||
self.dom_e = dom_e | |||
self.get_tblout = get_tblout | |||
def query(self, input_fasta_path: str) -> Mapping[str, Any]: | |||
"""Queries the database using Jackhmmer.""" | |||
with tmpdir_manager(base_dir='/tmp') as query_tmp_dir: | |||
sto_path = os.path.join(query_tmp_dir, 'output.sto') | |||
# The F1/F2/F3 are the expected proportion to pass each of the filtering | |||
# stages (which get progressively more expensive), reducing these | |||
            # speeds up the pipeline at the expense of sensitivity. They are
# currently set very low to make querying Mgnify run in a reasonable | |||
# amount of time. | |||
cmd_flags = [ | |||
# Don't pollute stdout with Jackhmmer output. | |||
'-o', '/dev/null', | |||
'-A', sto_path, | |||
'--noali', | |||
'--F1', str(self.filter_f1), | |||
'--F2', str(self.filter_f2), | |||
'--F3', str(self.filter_f3), | |||
'--incE', str(self.e_value), | |||
# Report only sequences with E-values <= x in per-sequence | |||
# output. | |||
'-E', str(self.e_value), | |||
'--cpu', str(self.n_cpu), | |||
'-N', str(self.n_iter) | |||
] | |||
if self.get_tblout: | |||
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') | |||
cmd_flags.extend(['--tblout', tblout_path]) | |||
if self.z_value: | |||
cmd_flags.extend(['-Z', str(self.z_value)]) | |||
if self.dom_e is not None: | |||
cmd_flags.extend(['--domE', str(self.dom_e)]) | |||
if self.incdom_e is not None: | |||
cmd_flags.extend(['--incdomE', str(self.incdom_e)]) | |||
cmd = [self.binary_path] + cmd_flags + [input_fasta_path, self.database_path] | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with timing(f'Jackhmmer ({os.path.basename(self.database_path)}) query'): | |||
_, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) | |||
# Get e-values for each target name | |||
tbl = '' | |||
if self.get_tblout: | |||
with open(tblout_path) as f: | |||
tbl = f.read() | |||
with open(sto_path) as f: | |||
sto = f.read() | |||
raw_output = dict(sto=sto, tbl=tbl, stderr=stderr, n_iter=self.n_iter, e_value=self.e_value) | |||
return raw_output |
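# Sketch of a Jackhmmer run returning the Stockholm alignment; the binary and
# database paths are assumptions.
def _example_jackhmmer_query():
    runner = Jackhmmer(binary_path='/usr/bin/jackhmmer',
                       database_path='/data/uniref90/uniref90.fasta',
                       get_tblout=True)
    return runner.query('/tmp/query.fasta')['sto']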
@@ -0,0 +1,393 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Parses the mmCIF file format.""" | |||
import collections | |||
import io | |||
import dataclasses | |||
from typing import Any, Mapping, Optional, Sequence, Tuple | |||
from absl import logging | |||
from Bio import PDB | |||
from Bio.Data import SCOPData | |||
# Type aliases: | |||
ChainId = str | |||
PdbHeader = Mapping[str, Any] | |||
PDBSTRUCTURE = PDB.Structure.Structure | |||
SeqRes = str | |||
MmCIFDict = Mapping[str, Sequence[str]] | |||
@dataclasses.dataclass(frozen=True) | |||
class Monomer: | |||
id: str | |||
num: int | |||
# Note - mmCIF format provides no guarantees on the type of author-assigned | |||
# sequence numbers. They need not be integers. | |||
@dataclasses.dataclass(frozen=True) | |||
class AtomSite: | |||
residue_name: str | |||
author_chain_id: str | |||
mmcif_chain_id: str | |||
author_seq_num: str | |||
mmcif_seq_num: int | |||
insertion_code: str | |||
hetatm_atom: str | |||
model_num: int | |||
# Used to map SEQRES index to a residue in the structure. | |||
@dataclasses.dataclass(frozen=True) | |||
class ResiduePosition: | |||
chain_id: str | |||
residue_number: int | |||
insertion_code: str | |||
@dataclasses.dataclass(frozen=True) | |||
class ResidueAtPosition: | |||
position: Optional[ResiduePosition] | |||
name: str | |||
is_missing: bool | |||
hetflag: str | |||
@dataclasses.dataclass(frozen=True) | |||
class MmcifObject: | |||
"""Representation of a parsed mmCIF file. | |||
Contains: | |||
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all | |||
files being processed. | |||
header: Biopython header. | |||
structure: Biopython structure. | |||
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. | |||
{'A': 'ABCDEFG'} | |||
seqres_to_structure: Dict; for each chain_id contains a mapping between | |||
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, | |||
1: ResidueAtPosition, | |||
...}} | |||
raw_string: The raw string used to construct the MmcifObject. | |||
""" | |||
file_id: str | |||
header: PdbHeader | |||
structure: PDBSTRUCTURE | |||
chain_to_seqres: Mapping[ChainId, SeqRes] | |||
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] | |||
raw_string: Any | |||
@dataclasses.dataclass(frozen=True) | |||
class ParsingResult: | |||
"""Returned by the parse function. | |||
Contains: | |||
mmcif_object: A MmcifObject, may be None if no chain could be successfully | |||
parsed. | |||
errors: A dict mapping (file_id, chain_id) to any exception generated. | |||
""" | |||
mmcif_object: Optional[MmcifObject] | |||
errors: Mapping[Tuple[str, str], Any] | |||
class ParseError(Exception): | |||
"""An error indicating that an mmCIF file could not be parsed.""" | |||
def mmcif_loop_to_list(prefix: str, | |||
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: | |||
"""Extracts loop associated with a prefix from mmCIF data as a list. | |||
Reference for loop_ in mmCIF: | |||
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html | |||
Args: | |||
prefix: Prefix shared by each of the data items in the loop. | |||
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||
_entity_poly_seq.mon_id. Should include the trailing period. | |||
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||
parser. | |||
Returns: | |||
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. | |||
""" | |||
cols = [] | |||
data = [] | |||
for key, value in parsed_info.items(): | |||
if key.startswith(prefix): | |||
cols.append(key) | |||
data.append(value) | |||
assert all([len(xs) == len(data[0]) for xs in data]), ( | |||
'mmCIF error: Not all loops are the same length: %s' % cols) | |||
return [dict(zip(cols, xs)) for xs in zip(*data)] | |||
def mmcif_loop_to_dict(prefix: str, | |||
index: str, | |||
parsed_info: MmCIFDict, | |||
) -> Mapping[str, Mapping[str, str]]: | |||
"""Extracts loop associated with a prefix from mmCIF data as a dictionary. | |||
Args: | |||
prefix: Prefix shared by each of the data items in the loop. | |||
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||
_entity_poly_seq.mon_id. Should include the trailing period. | |||
index: Which item of loop data should serve as the key. | |||
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||
parser. | |||
Returns: | |||
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, | |||
indexed by the index column. | |||
""" | |||
entries = mmcif_loop_to_list(prefix, parsed_info) | |||
return {entry[index]: entry for entry in entries} | |||
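# Toy illustration of the two loop helpers on a hand-built _mmcif_dict:
def _example_mmcif_loops():
    parsed_info = {'_entity_poly_seq.entity_id': ['1', '1'],
                   '_entity_poly_seq.num': ['1', '2'],
                   '_entity_poly_seq.mon_id': ['MET', 'ALA']}
    rows = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)  # two row dicts
    by_num = mmcif_loop_to_dict('_entity_poly_seq.', '_entity_poly_seq.num', parsed_info)
    return rows, by_num  # by_num is keyed on '1' and '2'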
def parse(*, | |||
file_id: str, | |||
mmcif_string: str, | |||
catch_all_errors: bool = True) -> ParsingResult: | |||
"""Entry point, parses an mmcif_string. | |||
Args: | |||
file_id: A string identifier for this file. Should be unique within the | |||
collection of files being processed. | |||
mmcif_string: Contents of an mmCIF file. | |||
catch_all_errors: If True, all exceptions are caught and error messages are | |||
returned as part of the ParsingResult. If False exceptions will be allowed | |||
to propagate. | |||
Returns: | |||
A ParsingResult. | |||
""" | |||
errors = {} | |||
try: | |||
parser = PDB.MMCIFParser(QUIET=True) | |||
handle = io.StringIO(mmcif_string) | |||
full_structure = parser.get_structure('', handle) | |||
first_model_structure = _get_first_model(full_structure) | |||
# Extract the _mmcif_dict from the parser, which contains useful fields not | |||
# reflected in the Biopython structure. | |||
parsed_info = parser._mmcif_dict # pylint:disable=protected-access | |||
# Ensure all values are lists, even if singletons. | |||
for key, value in parsed_info.items(): | |||
if not isinstance(value, list): | |||
parsed_info[key] = [value] | |||
header = _get_header(parsed_info) | |||
# Determine the protein chains, and their start numbers according to the | |||
# internal mmCIF numbering scheme (likely but not guaranteed to be 1). | |||
valid_chains = _get_protein_chains(parsed_info=parsed_info) | |||
if not valid_chains: | |||
return ParsingResult( | |||
None, {(file_id, ''): 'No protein chains found in this file.'}) | |||
seq_start_num = {chain_id: min([monomer.num for monomer in seq]) | |||
for chain_id, seq in valid_chains.items()} | |||
# Loop over the atoms for which we have coordinates. Populate two mappings: | |||
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used | |||
# the authors / Biopython). | |||
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). | |||
mmcif_to_author_chain_id = {} | |||
seq_to_structure_mappings = {} | |||
for atom in _get_atom_site_list(parsed_info): | |||
if atom.model_num != '1': | |||
# We only process the first model at the moment. | |||
continue | |||
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id | |||
if atom.mmcif_chain_id in valid_chains: | |||
hetflag = ' ' | |||
if atom.hetatm_atom == 'HETATM': | |||
# Water atoms are assigned a special hetflag of W in Biopython. We | |||
# need to do the same, so that this hetflag can be used to fetch | |||
# a residue from the Biopython structure by id. | |||
if atom.residue_name in ('HOH', 'WAT'): | |||
hetflag = 'W' | |||
else: | |||
hetflag = 'H_' + atom.residue_name | |||
insertion_code = atom.insertion_code | |||
if not _is_set(atom.insertion_code): | |||
insertion_code = ' ' | |||
position = ResiduePosition( | |||
chain_id=atom.author_chain_id, residue_number=int( | |||
atom.author_seq_num), insertion_code=insertion_code) | |||
seq_idx = int(atom.mmcif_seq_num) - \ | |||
seq_start_num[atom.mmcif_chain_id] | |||
current = seq_to_structure_mappings.get( | |||
atom.author_chain_id, {}) | |||
current[seq_idx] = ResidueAtPosition(position=position, | |||
name=atom.residue_name, | |||
is_missing=False, | |||
hetflag=hetflag) | |||
seq_to_structure_mappings[atom.author_chain_id] = current | |||
# Add missing residue information to seq_to_structure_mappings. | |||
for chain_id, seq_info in valid_chains.items(): | |||
author_chain = mmcif_to_author_chain_id[chain_id] | |||
current_mapping = seq_to_structure_mappings[author_chain] | |||
for idx, monomer in enumerate(seq_info): | |||
if idx not in current_mapping: | |||
current_mapping[idx] = ResidueAtPosition(position=None, | |||
name=monomer.id, | |||
is_missing=True, | |||
hetflag=' ') | |||
author_chain_to_sequence = {} | |||
for chain_id, seq_info in valid_chains.items(): | |||
author_chain = mmcif_to_author_chain_id[chain_id] | |||
seq = [] | |||
for monomer in seq_info: | |||
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') | |||
seq.append(code if len(code) == 1 else 'X') | |||
seq = ''.join(seq) | |||
author_chain_to_sequence[author_chain] = seq | |||
mmcif_object = MmcifObject( | |||
file_id=file_id, | |||
header=header, | |||
structure=first_model_structure, | |||
chain_to_seqres=author_chain_to_sequence, | |||
seqres_to_structure=seq_to_structure_mappings, | |||
raw_string=parsed_info) | |||
return ParsingResult(mmcif_object=mmcif_object, errors=errors) | |||
except Exception as e: # pylint:disable=broad-except | |||
errors[(file_id, '')] = e | |||
if not catch_all_errors: | |||
raise | |||
return ParsingResult(mmcif_object=None, errors=errors) | |||
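# Sketch of parsing one mmCIF file from disk; the path is hypothetical.
def _example_parse_mmcif():
    with open('/data/pdb_mmcif/1abc.cif') as f:
        result = parse(file_id='1abc', mmcif_string=f.read())
    if result.mmcif_object is None:
        return result.errors
    return result.mmcif_object.chain_to_seqres  # e.g. {'A': 'MKTAYI...'}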
def _get_first_model(structure: PDBSTRUCTURE) -> PDBSTRUCTURE: | |||
"""Returns the first model in a Biopython structure.""" | |||
return next(structure.get_models()) | |||
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 | |||
def get_release_date(parsed_info: MmCIFDict) -> str: | |||
"""Returns the oldest revision date.""" | |||
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] | |||
return min(revision_dates) | |||
def _get_header(parsed_info: MmCIFDict) -> PdbHeader: | |||
"""Returns a basic header containing method, release date and resolution.""" | |||
header = {} | |||
experiments = mmcif_loop_to_list('_exptl.', parsed_info) | |||
header['structure_method'] = ','.join([ | |||
experiment['_exptl.method'].lower() for experiment in experiments]) | |||
# Note: The release_date here corresponds to the oldest revision. We prefer to | |||
# use this for dataset filtering over the deposition_date. | |||
if '_pdbx_audit_revision_history.revision_date' in parsed_info: | |||
header['release_date'] = get_release_date(parsed_info) | |||
else: | |||
logging.warning('Could not determine release_date: %s', | |||
parsed_info['_entry.id']) | |||
header['resolution'] = 0.00 | |||
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', '_reflns.d_resolution_high'): | |||
if res_key in parsed_info: | |||
try: | |||
raw_resolution = parsed_info[res_key][0] | |||
header['resolution'] = float(raw_resolution) | |||
except ValueError: | |||
logging.warning( | |||
'Invalid resolution format: %s', | |||
parsed_info[res_key]) | |||
return header | |||
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: | |||
"""Returns list of atom sites; contains data not present in the structure.""" | |||
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension | |||
parsed_info['_atom_site.label_comp_id'], | |||
parsed_info['_atom_site.auth_asym_id'], | |||
parsed_info['_atom_site.label_asym_id'], | |||
parsed_info['_atom_site.auth_seq_id'], | |||
parsed_info['_atom_site.label_seq_id'], | |||
parsed_info['_atom_site.pdbx_PDB_ins_code'], | |||
parsed_info['_atom_site.group_PDB'], | |||
parsed_info['_atom_site.pdbx_PDB_model_num'], | |||
)] | |||
def _get_protein_chains(*, | |||
parsed_info: Mapping[str, | |||
Any]) -> Mapping[ChainId, | |||
Sequence[Monomer]]: | |||
"""Extracts polymer information for protein chains only. | |||
Args: | |||
parsed_info: _mmcif_dict produced by the Biopython parser. | |||
Returns: | |||
A dict mapping mmcif chain id to a list of Monomers. | |||
""" | |||
# Get polymer information for each entity in the structure. | |||
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) | |||
polymers = collections.defaultdict(list) | |||
for entity_poly_seq in entity_poly_seqs: | |||
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( | |||
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'], | |||
num=int(entity_poly_seq['_entity_poly_seq.num']))) | |||
# Get chemical compositions. Will allow us to identify which of these polymers | |||
# are proteins. | |||
chem_comps = mmcif_loop_to_dict( | |||
'_chem_comp.', '_chem_comp.id', parsed_info) | |||
# Get chains information for each entity. Necessary so that we can return a | |||
# dict keyed on chain id rather than entity. | |||
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) | |||
entity_to_mmcif_chains = collections.defaultdict(list) | |||
for struct_asym in struct_asyms: | |||
chain_id = struct_asym['_struct_asym.id'] | |||
entity_id = struct_asym['_struct_asym.entity_id'] | |||
entity_to_mmcif_chains[entity_id].append(chain_id) | |||
# Identify and return the valid protein chains. | |||
valid_chains = {} | |||
for entity_id, seq_info in polymers.items(): | |||
chain_ids = entity_to_mmcif_chains[entity_id] | |||
# Reject polymers without any peptide-like components, such as DNA/RNA. | |||
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type'] | |||
for monomer in seq_info]): | |||
for chain_id in chain_ids: | |||
valid_chains[chain_id] = seq_info | |||
return valid_chains | |||
def _is_set(data: str) -> bool: | |||
"""Returns False if data is a special mmCIF character indicating 'unset'.""" | |||
return data not in ('.', '?') |
@@ -0,0 +1,61 @@ | |||
#!/bin/bash -e | |||
MMSEQS="$1" #mmseqs | |||
QUERY="$2" #"/path/QUERY.fasta" | |||
BASE="$3" #"./result/" | |||
DB1="$4" #"uniref30_2103_db" | |||
DB2="$5" #"" | |||
DB3="$6" #"colabfold_envdb_202108_db" | |||
USE_ENV="$7" #1 | |||
USE_TEMPLATES="$8" #0 | |||
FILTER="${9}" #1 | |||
EXPAND_EVAL=inf | |||
ALIGN_EVAL=10 | |||
DIFF=3000 | |||
QSC=-20.0 | |||
MAX_ACCEPT=1000000 | |||
time=$(date)
echo "${time}" | |||
if [ "${FILTER}" = "1" ]; then | |||
# 0.1 was not used in benchmarks due to POSIX shell bug in line above | |||
# EXPAND_EVAL=0.1 | |||
ALIGN_EVAL=10 | |||
QSC=0.8 | |||
MAX_ACCEPT=100000 | |||
fi | |||
export MMSEQS_CALL_DEPTH=1 | |||
SEARCH_PARAM="--num-iterations 3 --db-load-mode 2 -a -s 8 -e 0.1 --max-seqs 10000" | |||
FILTER_PARAM="--filter-msa ${FILTER} --filter-min-enable 1000 --diff ${DIFF} --qid 0.0,0.2,0.4,0.6,0.8,1.0 --qsc 0 --max-seq-id 0.95" | |||
EXPAND_PARAM="--expansion-mode 0 -e ${EXPAND_EVAL} --expand-filter-clusters ${FILTER} --max-seq-id 0.95" | |||
mkdir -p "${BASE}" | |||
"${MMSEQS}" createdb "${QUERY}" "${BASE}/qdb" | |||
"${MMSEQS}" search "${BASE}/qdb" "${DB1}" "${BASE}/res" "${BASE}/tmp" $SEARCH_PARAM | |||
"${MMSEQS}" expandaln "${BASE}/qdb" "${DB1}.idx" "${BASE}/res" "${DB1}.idx" "${BASE}/res_exp" --db-load-mode 2 ${EXPAND_PARAM} | |||
"${MMSEQS}" mvdb "${BASE}/tmp/latest/profile_1" "${BASE}/prof_res" | |||
"${MMSEQS}" lndb "${BASE}/qdb_h" "${BASE}/prof_res_h" | |||
"${MMSEQS}" align "${BASE}/prof_res" "${DB1}.idx" "${BASE}/res_exp" "${BASE}/res_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a | |||
"${MMSEQS}" filterresult "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign" "${BASE}/res_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 | |||
"${MMSEQS}" result2msa "${BASE}/qdb" "${DB1}.idx" "${BASE}/res_exp_realign_filter" "${BASE}/uniref.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM} | |||
"${MMSEQS}" rmdb "${BASE}/res_exp_realign" | |||
"${MMSEQS}" rmdb "${BASE}/res_exp" | |||
"${MMSEQS}" rmdb "${BASE}/res" | |||
"${MMSEQS}" rmdb "${BASE}/res_exp_realign_filter" | |||
if [ "${USE_TEMPLATES}" = "1" ]; then | |||
"${MMSEQS}" search "${BASE}/prof_res" "${DB2}" "${BASE}/res_pdb" "${BASE}/tmp" --db-load-mode 2 -s 7.5 -a -e 0.1 | |||
echo "-----------------------in here" | |||
echo "${BASE}/${DB2}.m8" | |||
"${MMSEQS}" convertalis "${BASE}/prof_res" "${DB2}.idx" "${BASE}/res_pdb" "${BASE}/${DB2}.m8" --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar --db-load-mode 2 | |||
"${MMSEQS}" rmdb "${BASE}/res_pdb" | |||
fi | |||
if [ "${USE_ENV}" = "1" ]; then | |||
"${MMSEQS}" search "${BASE}/prof_res" "${DB3}" "${BASE}/res_env" "${BASE}/tmp" $SEARCH_PARAM | |||
"${MMSEQS}" expandaln "${BASE}/prof_res" "${DB3}.idx" "${BASE}/res_env" "${DB3}.idx" "${BASE}/res_env_exp" -e ${EXPAND_EVAL} --expansion-mode 0 --db-load-mode 2 | |||
"${MMSEQS}" align "${BASE}/tmp/latest/profile_1" "${DB3}.idx" "${BASE}/res_env_exp" "${BASE}/res_env_exp_realign" --db-load-mode 2 -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a | |||
"${MMSEQS}" filterresult "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign" "${BASE}/res_env_exp_realign_filter" --db-load-mode 2 --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 | |||
"${MMSEQS}" result2msa "${BASE}/qdb" "${DB3}.idx" "${BASE}/res_env_exp_realign_filter" "${BASE}/bfd.mgnify30.metaeuk30.smag30.a3m" --msa-format-mode 6 --db-load-mode 2 ${FILTER_PARAM} | |||
"${MMSEQS}" rmdb "${BASE}/res_env_exp_realign_filter" | |||
"${MMSEQS}" rmdb "${BASE}/res_env_exp_realign" | |||
fi | |||
time=$(date)
echo "${time}" |
@@ -0,0 +1,389 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
'''parsers''' | |||
import collections | |||
import re | |||
import string | |||
from typing import Iterable, List, Optional, Sequence, Tuple | |||
import dataclasses | |||
DeletionMatrix = Sequence[Sequence[int]] | |||
@dataclasses.dataclass(frozen=True) | |||
class HhrHit: | |||
"""Class representing a hit in an hhr file.""" | |||
index: int | |||
name: str | |||
prob_true: float | |||
e_value: float | |||
score: float | |||
aligned_cols: int | |||
identity: float | |||
similarity: float | |||
sum_probs: float | |||
neff: float | |||
query: str | |||
hit_sequence: str | |||
hit_dssp: str | |||
column_score_code: str | |||
confidence_scores: str | |||
indices_query: List[int] | |||
indices_hit: List[int] | |||
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: | |||
"""Parses FASTA string and returns list of strings with amino-acid sequences. | |||
Arguments: | |||
fasta_string: The string contents of a FASTA file. | |||
Returns: | |||
A tuple of two lists: | |||
* A list of sequences. | |||
* A list of sequence descriptions taken from the comment lines. In the | |||
same order as the sequences. | |||
""" | |||
sequences = [] | |||
descriptions = [] | |||
index = -1 | |||
for line in fasta_string.splitlines(): | |||
line = line.strip() | |||
if line.startswith('>'): | |||
index += 1 | |||
descriptions.append(line[1:]) # Remove the '>' at the beginning. | |||
sequences.append('') | |||
continue | |||
elif not line: | |||
continue # Skip blank lines. | |||
sequences[index] += line | |||
return sequences, descriptions | |||
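# parse_fasta keeps sequences and descriptions index-aligned:
def _example_parse_fasta():
    seqs, descs = parse_fasta('>query toy description\nMKTAYI\nAKQR\n')
    return seqs, descs  # (['MKTAYIAKQR'], ['query toy description'])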
def parse_stockholm( | |||
stockholm_string: str) -> Tuple[Sequence[str], DeletionMatrix]: | |||
"""Parses sequences and deletion matrix from stockholm format alignment. | |||
Args: | |||
stockholm_string: The string contents of a stockholm file. The first | |||
sequence in the file should be the query sequence. | |||
Returns: | |||
A tuple of: | |||
* A list of sequences that have been aligned to the query. These | |||
might contain duplicates. | |||
* The deletion matrix for the alignment as a list of lists. The element | |||
at `deletion_matrix[i][j]` is the number of residues deleted from | |||
the aligned sequence i at residue position j. | |||
""" | |||
name_to_sequence = collections.OrderedDict() | |||
for line in stockholm_string.splitlines(): | |||
line = line.strip() | |||
if not line or line.startswith(('#', '//')): | |||
continue | |||
name, sequence = line.split() | |||
if name not in name_to_sequence: | |||
name_to_sequence[name] = '' | |||
name_to_sequence[name] += sequence | |||
msa = [] | |||
deletion_matrix = [] | |||
query = '' | |||
keep_columns = [] | |||
for seq_index, sequence in enumerate(name_to_sequence.values()): | |||
if seq_index == 0: | |||
# Gather the columns with gaps from the query | |||
query = sequence | |||
keep_columns = [i for i, res in enumerate(query) if res != '-'] | |||
# Remove the columns with gaps in the query from all sequences. | |||
aligned_sequence = ''.join([sequence[c] for c in keep_columns]) | |||
msa.append(aligned_sequence) | |||
# Count the number of deletions w.r.t. query. | |||
deletion_vec = [] | |||
deletion_count = 0 | |||
for seq_res, query_res in zip(sequence, query): | |||
if seq_res != '-' or query_res != '-': | |||
if query_res == '-': | |||
deletion_count += 1 | |||
else: | |||
deletion_vec.append(deletion_count) | |||
deletion_count = 0 | |||
deletion_matrix.append(deletion_vec) | |||
return msa, deletion_matrix | |||
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: | |||
"""Parses sequences and deletion matrix from a3m format alignment. | |||
Args: | |||
a3m_string: The string contents of a a3m file. The first sequence in the | |||
file should be the query sequence. | |||
Returns: | |||
A tuple of: | |||
* A list of sequences that have been aligned to the query. These | |||
might contain duplicates. | |||
* The deletion matrix for the alignment as a list of lists. The element | |||
at `deletion_matrix[i][j]` is the number of residues deleted from | |||
the aligned sequence i at residue position j. | |||
""" | |||
sequences, _ = parse_fasta(a3m_string) | |||
deletion_matrix = [] | |||
for msa_sequence in sequences: | |||
deletion_vec = [] | |||
deletion_count = 0 | |||
for j in msa_sequence: | |||
if j.islower(): | |||
deletion_count += 1 | |||
else: | |||
deletion_vec.append(deletion_count) | |||
deletion_count = 0 | |||
deletion_matrix.append(deletion_vec) | |||
# Make the MSA matrix out of aligned (deletion-free) sequences. | |||
deletion_table = str.maketrans('', '', string.ascii_lowercase) | |||
aligned_sequences = [s.translate(deletion_table) for s in sequences] | |||
return aligned_sequences, deletion_matrix | |||
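# In a3m, lowercase letters mark insertions relative to the query; parse_a3m
# strips them and records their counts in the deletion matrix:
def _example_parse_a3m():
    msa, deletions = parse_a3m('>query\nMKTAYI\n>hit\nMKaTAYI\n')
    return msa, deletions  # (['MKTAYI', 'MKTAYI'], [[0]*6, [0, 0, 1, 0, 0, 0]])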
def _convert_sto_seq_to_a3m( | |||
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: | |||
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): | |||
if is_query_res_non_gap: | |||
yield sequence_res | |||
elif sequence_res != '-': | |||
yield sequence_res.lower() | |||
def convert_stockholm_to_a3m(stockholm_format: str, | |||
max_sequences: Optional[int] = None) -> str: | |||
"""Converts MSA in Stockholm format to the A3M format.""" | |||
descriptions = {} | |||
sequences = {} | |||
reached_max_sequences = False | |||
for line in stockholm_format.splitlines(): | |||
reached_max_sequences = max_sequences and len(sequences) >= max_sequences | |||
if line.strip() and not line.startswith(('#', '//')): | |||
# Ignore blank lines, markup and end symbols - remainder are alignment | |||
# sequence parts. | |||
seqname, aligned_seq = line.split(maxsplit=1) | |||
if seqname not in sequences: | |||
if reached_max_sequences: | |||
continue | |||
sequences[seqname] = '' | |||
sequences[seqname] += aligned_seq | |||
for line in stockholm_format.splitlines(): | |||
if line[:4] == '#=GS': | |||
# Description row - example format is: | |||
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... | |||
columns = line.split(maxsplit=3) | |||
seqname, feature = columns[1:3] | |||
value = columns[3] if len(columns) == 4 else '' | |||
if feature != 'DE': | |||
continue | |||
if reached_max_sequences and seqname not in sequences: | |||
continue | |||
descriptions[seqname] = value | |||
if len(descriptions) == len(sequences): | |||
break | |||
# Convert sto format to a3m line by line | |||
a3m_sequences = {} | |||
# query_sequence is assumed to be the first sequence | |||
query_sequence = next(iter(sequences.values())) | |||
query_non_gaps = [res != '-' for res in query_sequence] | |||
for seqname, sto_sequence in sequences.items(): | |||
a3m_sequences[seqname] = ''.join( | |||
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) | |||
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" | |||
for k in a3m_sequences) | |||
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. | |||
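# Round-trip sketch: a tiny Stockholm alignment converted to a3m text.
def _example_convert_stockholm():
    sto = ('# STOCKHOLM 1.0\n'
           'query MKTAYI\n'
           'hit   MKT-YI\n'
           '//\n')
    return convert_stockholm_to_a3m(sto)  # '>query \nMKTAYI\n>hit \nMKT-YI\n'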
def _get_hhr_line_regex_groups( | |||
regex_pattern: str, line: str) -> Sequence[Optional[str]]: | |||
match = re.match(regex_pattern, line) | |||
if match is None: | |||
raise RuntimeError(f'Could not parse query line {line}') | |||
return match.groups() | |||
def _update_hhr_residue_indices_list( | |||
sequence: str, start_index: int, indices_list: List[int]): | |||
"""Computes the relative indices for each residue with respect to the original sequence.""" | |||
counter = start_index | |||
for symbol in sequence: | |||
if symbol == '-': | |||
indices_list.append(-1) | |||
else: | |||
indices_list.append(counter) | |||
counter += 1 | |||
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> HhrHit: | |||
"""Parses the detailed HMM HMM comparison section for a single Hit. | |||
This works on .hhr files generated from both HHBlits and HHSearch. | |||
Args: | |||
detailed_lines: A list of lines from a single comparison section between 2 | |||
sequences (which each have their own HMM's) | |||
Returns: | |||
A dictionary with the information from that detailed comparison section | |||
Raises: | |||
RuntimeError: If a certain line cannot be processed | |||
""" | |||
# Parse first 2 lines. | |||
number_of_hit = int(detailed_lines[0].split()[-1]) | |||
name_hit = detailed_lines[1][1:] | |||
# Parse the summary line. | |||
pattern = ( | |||
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' | |||
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' | |||
']*Template_Neff=(.*)') | |||
match = re.match(pattern, detailed_lines[2]) | |||
if match is None: | |||
raise RuntimeError( | |||
'Could not parse section: %s. Expected this: \n%s to contain summary.' % | |||
(detailed_lines, detailed_lines[2])) | |||
(prob_true, e_value, score, aligned_cols, identity, similarity, sum_probs, | |||
neff) = [float(x) for x in match.groups()] | |||
# The next section reads the detailed comparisons. These are in a 'human | |||
# readable' format which has a fixed length. The strategy employed is to | |||
# assume that each block starts with the query sequence line, and to parse | |||
# that with a regexp in order to deduce the fixed length used for that | |||
# block. | |||
query = '' | |||
hit_sequence = '' | |||
hit_dssp = '' | |||
column_score_code = '' | |||
confidence_scores = '' | |||
indices_query = [] | |||
indices_hit = [] | |||
length_block = None | |||
for line in detailed_lines[3:]: | |||
# Parse the query sequence line | |||
        if (line.startswith('Q ') and not line.startswith('Q ss_dssp')
                and not line.startswith('Q ss_pred')
                and not line.startswith('Q Consensus')):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse | |||
# everything after that. | |||
# start sequence end total_sequence_length | |||
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' | |||
groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||
# Get the length of the parsed block using the start and finish indices, | |||
# and ensure it is the same as the actual block length. | |||
start = int(groups[0]) - 1 # Make index zero based. | |||
delta_query = groups[1] | |||
end = int(groups[2]) | |||
num_insertions = len([x for x in delta_query if x == '-']) | |||
length_block = end - start + num_insertions | |||
assert length_block == len(delta_query) | |||
# Update the query sequence and indices list. | |||
query += delta_query | |||
_update_hhr_residue_indices_list(delta_query, start, indices_query) | |||
elif line.startswith('T '): | |||
# Parse the hit dssp line. | |||
if line.startswith('T ss_dssp'): | |||
# T ss_dssp hit_dssp | |||
patt = r'T ss_dssp[\t ]*([A-Z-]*)' | |||
groups = _get_hhr_line_regex_groups(patt, line) | |||
assert len(groups[0]) == length_block | |||
hit_dssp += groups[0] | |||
# Parse the hit sequence. | |||
elif (not line.startswith('T ss_pred') and | |||
not line.startswith('T Consensus')): | |||
# Thus the first 17 characters must be 'T <hit_name> ', and we can | |||
# parse everything after that. | |||
# start sequence end total_sequence_length | |||
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' | |||
groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||
start = int(groups[0]) - 1 # Make index zero based. | |||
delta_hit_sequence = groups[1] | |||
assert length_block == len(delta_hit_sequence) | |||
# Update the hit sequence and indices list. | |||
hit_sequence += delta_hit_sequence | |||
_update_hhr_residue_indices_list( | |||
delta_hit_sequence, start, indices_hit) | |||
# Parse the column score line. | |||
elif line.startswith(' ' * 22): | |||
assert length_block | |||
column_score_code += line[22:length_block + 22] | |||
# Update confidence score. | |||
elif line.startswith('Confidence'): | |||
assert length_block | |||
confidence_scores += line[22:length_block + 22] | |||
return HhrHit( | |||
index=number_of_hit, | |||
name=name_hit, | |||
prob_true=prob_true, | |||
e_value=e_value, | |||
score=score, | |||
aligned_cols=int(aligned_cols), | |||
identity=identity, | |||
similarity=similarity, | |||
sum_probs=sum_probs, | |||
neff=neff, | |||
query=query, | |||
hit_sequence=hit_sequence, | |||
hit_dssp=hit_dssp, | |||
column_score_code=column_score_code, | |||
confidence_scores=confidence_scores, | |||
indices_query=indices_query, | |||
indices_hit=indices_hit, | |||
) | |||
def parse_hhr(hhr_string: str) -> Sequence[HhrHit]: | |||
"""Parses the content of an entire HHR file.""" | |||
lines = hhr_string.splitlines() | |||
# Each .hhr file starts with a results table, then has a sequence of hit | |||
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We | |||
# iterate through each paragraph to parse each hit. | |||
block_starts = [i for i, line in enumerate( | |||
lines) if line.startswith('No ')] | |||
hits = [] | |||
if block_starts: | |||
block_starts.append(len(lines)) # Add the end of the final block. | |||
for i in range(len(block_starts) - 1): | |||
hits.append(_parse_hhr_hit( | |||
lines[block_starts[i]:block_starts[i + 1]])) | |||
return hits |
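# Sketch of consuming an .hhr file produced by HHSearch; the path is
# hypothetical.
def _example_parse_hhr():
    with open('/tmp/output.hhr') as f:
        hits = parse_hhr(f.read())
    return [(hit.name, hit.e_value) for hit in hits]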
@@ -0,0 +1,112 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""run script""" | |||
import time | |||
import os | |||
import json | |||
import argparse | |||
import numpy as np | |||
import mindspore.context as context | |||
from mindspore.common.tensor import Tensor | |||
from mindspore import load_checkpoint | |||
from data.feature.feature_extraction import process_features | |||
from data.tools.data_process import data_process | |||
from commons.generate_pdb import to_pdb, from_prediction | |||
from commons.utils import compute_confidence | |||
from model import AlphaFold | |||
from config import config, global_config | |||
parser = argparse.ArgumentParser(description='Inputs for run.py') | |||
parser.add_argument('--seq_length', help='Padding sequence length.')
parser.add_argument('--input_fasta_path', help='Path of FASTA files folder directory to be predicted.') | |||
parser.add_argument('--msa_result_path', help='Path to save msa result.') | |||
parser.add_argument('--database_dir', help='Path of data to generate msa.') | |||
parser.add_argument('--database_envdb_dir', help='Path of expandable data to generate msa.') | |||
parser.add_argument('--hhsearch_binary_path', help='Path of hhsearch executable.') | |||
parser.add_argument('--pdb70_database_path', help='Path to pdb70.') | |||
parser.add_argument('--template_mmcif_dir', help='Path of template mmcif.') | |||
parser.add_argument('--max_template_date', help='Maximum template release date.') | |||
parser.add_argument('--kalign_binary_path', help='Path to kalign executable.') | |||
parser.add_argument('--obsolete_pdbs_path', help='Path to obsolete pdbs path.') | |||
parser.add_argument('--checkpoint_path', help='Path of the checkpoint.') | |||
parser.add_argument('--device_id', default=0, type=int, help='Device id to be used.') | |||
args = parser.parse_args() | |||
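# Example invocation (hypothetical paths; the flags mirror the arguments above):
#   python run.py --seq_length 256 --input_fasta_path ./fasta \
#       --msa_result_path ./msa --database_dir /data/uniref30_2103_db \
#       --hhsearch_binary_path /usr/bin/hhsearch \
#       --pdb70_database_path /data/pdb70/pdb70 \
#       --template_mmcif_dir /data/pdb_mmcif --max_template_date 2021-11-01 \
#       --kalign_binary_path /usr/bin/kalign \
#       --obsolete_pdbs_path /data/obsolete.dat \
#       --checkpoint_path ./alphafold.ckpt --device_id 0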
if __name__ == "__main__": | |||
context.set_context(mode=context.GRAPH_MODE, | |||
device_target="Ascend", | |||
variable_memory_max_size="31GB", | |||
device_id=args.device_id, | |||
save_graphs=False) | |||
model_name = "model_1" | |||
model_config = config.model_config(model_name) | |||
num_recycle = model_config.model.num_recycle | |||
global_config = global_config.global_config(args.seq_length) | |||
extra_msa_length = global_config.extra_msa_length | |||
fold_net = AlphaFold(model_config, global_config) | |||
load_checkpoint(args.checkpoint_path, fold_net) | |||
seq_files = os.listdir(args.input_fasta_path) | |||
for seq_file in seq_files: | |||
t1 = time.time() | |||
seq_name = seq_file.split('.')[0] | |||
input_features = data_process(seq_name, args) | |||
tensors, aatype, residue_index, ori_res_length = process_features( | |||
raw_features=input_features, config=model_config, global_config=global_config) | |||
prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16)) | |||
prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16)) | |||
prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16)) | |||
""" | |||
:param::@sequence_length | |||
""" | |||
t2 = time.time() | |||
for i in range(num_recycle+1): | |||
tensors_i = [tensor[i] for tensor in tensors] | |||
input_feats = [Tensor(tensor) for tensor in tensors_i] | |||
final_atom_positions, final_atom_mask, predicted_lddt_logits,\ | |||
prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats, | |||
prev_pos, | |||
prev_msa_first_row, | |||
prev_pair) | |||
t3 = time.time() | |||
final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length] | |||
final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length] | |||
predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length] | |||
confidence = compute_confidence(predicted_lddt_logits) | |||
unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0]) | |||
pdb_file = to_pdb(unrelaxed_protein) | |||
seq_length = aatype.shape[-1] | |||
os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True) | |||
with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f: | |||
f.write(pdb_file) | |||
t4 = time.time() | |||
timings = {"pre_process_time": round(t2 - t1, 2), | |||
"model_time": round(t3 - t2, 2), | |||
"pos_process_time": round(t4 - t3, 2), | |||
"all_time": round(t4 - t1, 2), | |||
"confidence": confidence} | |||
print(timings) | |||
with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f: | |||
f.write(json.dumps(timings)) |
@@ -0,0 +1,936 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""basic module""" | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
import mindspore.numpy as mnp | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore.common.tensor import Tensor | |||
from mindspore import Parameter | |||
import numpy as np | |||
from commons.utils import mask_mean | |||


class Attention(nn.Cell):
'''attention module''' | |||
def __init__(self, config, q_data_dim, m_data_dim, output_dim, batch_size=None): | |||
super(Attention, self).__init__() | |||
self.config = config | |||
self.q_data_dim = q_data_dim | |||
self.m_data_dim = m_data_dim | |||
self.output_dim = output_dim | |||
self.num_head = self.config.num_head | |||
self.gating = self.config.gating | |||
self.key_dim = self.config.get('key_dim', int(q_data_dim)) | |||
self.value_dim = self.config.get('value_dim', int(m_data_dim)) | |||
self.key_dim = self.key_dim // self.num_head | |||
self.value_dim = self.value_dim // self.num_head | |||
self.batch_size = batch_size | |||
self.matmul = P.MatMul(transpose_b=True) | |||
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
self.softmax = nn.Softmax() | |||
self.sigmoid = nn.Sigmoid() | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
if self.batch_size: | |||
self.linear_q_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim, | |||
self.q_data_dim]), mstype.float32)) | |||
self.linear_k_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.key_dim, | |||
self.m_data_dim]), mstype.float32)) | |||
self.linear_v_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim, | |||
self.m_data_dim]), mstype.float32)) | |||
self.linear_output_weights = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim, | |||
self.num_head * self.value_dim]), mstype.float32)) | |||
self.o_biases = Parameter(Tensor(np.zeros([self.batch_size, self.output_dim]), mstype.float32)) | |||
if self.gating: | |||
self.linear_gating_weights = Parameter(Tensor(np.zeros([self.batch_size, self.num_head * self.value_dim, | |||
self.q_data_dim]), mstype.float32)) | |||
self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_head, self.value_dim)), | |||
mstype.float32), name="gating_b") | |||
else: | |||
self.linear_q_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.q_data_dim]), | |||
mstype.float32)) | |||
self.linear_k_weights = Parameter(Tensor(np.zeros([self.num_head * self.key_dim, self.m_data_dim]), | |||
mstype.float32)) | |||
self.linear_v_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim, self.m_data_dim]), | |||
mstype.float32)) | |||
self.linear_output_weights = Parameter(Tensor(np.zeros([self.output_dim, self.num_head * self.value_dim]), | |||
mstype.float32)) | |||
self.o_biases = Parameter(Tensor(np.zeros([self.output_dim]), mstype.float32)) | |||
if self.gating: | |||
self.linear_gating_weights = Parameter(Tensor(np.zeros([self.num_head * self.value_dim, | |||
self.q_data_dim]), mstype.float32)) | |||
self.gating_biases = Parameter(Tensor(np.zeros((self.num_head, self.value_dim)), mstype.float32), | |||
name="gating_b") | |||
def construct(self, q_data, m_data, bias, index=None, nonbatched_bias=None): | |||
'''construct''' | |||
if self.batch_size: | |||
linear_q_weight = P.Gather()(self.linear_q_weights, index, 0) | |||
linear_k_weight = P.Gather()(self.linear_k_weights, index, 0) | |||
linear_v_weight = P.Gather()(self.linear_v_weights, index, 0) | |||
linear_output_weight = P.Gather()(self.linear_output_weights, index, 0) | |||
o_bias = P.Gather()(self.o_biases, index, 0) | |||
linear_gating_weight = 0 | |||
gating_bias = 0 | |||
if self.gating: | |||
linear_gating_weight = P.Gather()(self.linear_gating_weights, index, 0) | |||
gating_bias = P.Gather()(self.gating_biases, index, 0) | |||
else: | |||
linear_q_weight = self.linear_q_weights | |||
linear_k_weight = self.linear_k_weights | |||
linear_v_weight = self.linear_v_weights | |||
linear_output_weight = self.linear_output_weights | |||
o_bias = self.o_biases | |||
linear_gating_weight = 0 | |||
gating_bias = 0 | |||
if self.gating: | |||
linear_gating_weight = self.linear_gating_weights | |||
gating_bias = self.gating_biases | |||
q_data, m_data, bias = q_data.astype(mstype.float32), m_data.astype(mstype.float32), bias.astype(mstype.float32) | |||
dim_b, dim_q, dim_a = q_data.shape | |||
_, dim_k, dim_c = m_data.shape | |||
dim_h = self.num_head | |||
q_data = P.Reshape()(q_data, (-1, dim_a)) | |||
m_data = P.Reshape()(m_data, (-1, dim_c)) | |||
q = self.matmul(q_data.astype(mstype.float16), linear_q_weight.astype(mstype.float16)). \ | |||
astype(mstype.float32) * self.key_dim ** (-0.5) | |||
k = self.matmul(m_data.astype(mstype.float16), linear_k_weight.astype(mstype.float16)) | |||
v = self.matmul(m_data.astype(mstype.float16), linear_v_weight.astype(mstype.float16)) | |||
q = P.Reshape()(q, (dim_b, dim_q, dim_h, -1)) | |||
k = P.Reshape()(k, (dim_b, dim_k, dim_h, -1)) | |||
v = P.Reshape()(v, (dim_b, dim_k, dim_h, -1)) | |||
tmp_q = P.Reshape()(P.Transpose()(q.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_q, -1)) | |||
tmp_k = P.Reshape()(P.Transpose()(k.astype(mstype.float16), (0, 2, 1, 3)), (dim_b * dim_h, dim_k, -1)) | |||
logits = P.Reshape()(self.batch_matmul_trans_b(tmp_q.astype(mstype.float16), | |||
tmp_k.astype(mstype.float16)), | |||
(dim_b, dim_h, dim_q, dim_k)) + bias.astype(mstype.float16) | |||
if nonbatched_bias is not None: | |||
logits += mnp.expand_dims(nonbatched_bias, axis=0) | |||
weights = self.softmax(logits.astype(mstype.float32)) | |||
tmp_v = P.Reshape()(P.Transpose()(v, (0, 2, 3, 1)), (dim_b * dim_h, -1, dim_k)) | |||
tmp_weights = P.Reshape()(weights, (dim_b * dim_h, dim_q, -1)) | |||
weighted_avg = P.Transpose()(P.Reshape()(self.batch_matmul_trans_b(tmp_weights.astype(mstype.float16), | |||
tmp_v.astype(mstype.float16)), | |||
(dim_b, dim_h, dim_q, -1)), | |||
(0, 2, 1, 3)).astype(mstype.float32) | |||
if self.gating: | |||
gate_values = P.Reshape()( | |||
self.matmul(q_data.astype(mstype.float16), linear_gating_weight.astype(mstype.float16)), | |||
(dim_b, dim_q, dim_h, -1)) + gating_bias[None, None, ...].astype(mstype.float16) | |||
gate_values = self.sigmoid(gate_values.astype(mstype.float32)) | |||
weighted_avg = P.Reshape()(weighted_avg * gate_values, (dim_b * dim_q, -1)) | |||
weighted_avg = P.Reshape()(weighted_avg, (dim_b * dim_q, -1)) | |||
output = P.Reshape()( | |||
self.matmul(weighted_avg.astype(mstype.float16), linear_output_weight.astype(mstype.float16)), | |||
(dim_b, dim_q, -1)) + o_bias[None, ...].astype(mstype.float16) | |||
return output | |||
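

def _example_attention():
    # Editor's sketch, not part of the original module: minimal use of the
    # Attention cell above, intended for the Ascend/GPU targets this code is
    # written for. The ConfigDict is an assumption (any object exposing
    # .num_head, .gating and .get() works); with the zero-initialized weights
    # used before checkpoint loading the output is all zeros, so the point of
    # this sketch is the expected shapes.
    from ml_collections import ConfigDict
    cfg = ConfigDict({'num_head': 4, 'gating': True})
    attn = Attention(cfg, q_data_dim=64, m_data_dim=64, output_dim=64)
    q_data = Tensor(np.ones((2, 8, 64)), mstype.float32)    # (batch, query, channels)
    m_data = Tensor(np.ones((2, 8, 64)), mstype.float32)    # (batch, key, channels)
    bias = Tensor(np.zeros((2, 4, 8, 8)), mstype.float32)   # (batch, head, query, key)
    return attn(q_data, m_data, bias)                       # -> (2, 8, 64)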


class MSARowAttentionWithPairBias(nn.Cell):
'''MSA row attention''' | |||
def __init__(self, config, msa_act_dim, pair_act_dim, batch_size=None, slice_num=0): | |||
super(MSARowAttentionWithPairBias, self).__init__() | |||
self.config = config | |||
self.num_head = self.config.num_head | |||
self.batch_size = batch_size | |||
self.norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.matmul = P.MatMul(transpose_b=True) | |||
self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size) | |||
self.msa_act_dim = msa_act_dim | |||
self.pair_act_dim = pair_act_dim | |||
self.slice_num = slice_num | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32)) | |||
self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim,]), mstype.float32)) | |||
self.feat_2d_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32)) | |||
self.feat_2d_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.pair_act_dim,]), mstype.float32)) | |||
self.feat_2d_weights = Parameter( | |||
Tensor(np.zeros([self.batch_size, self.num_head, self.pair_act_dim]), mstype.float32)) | |||
def construct(self, msa_act, msa_mask, pair_act, index): | |||
'''construct''' | |||
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) | |||
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) | |||
feat_2d_norm_gamma = P.Gather()(self.feat_2d_norm_gammas, index, 0) | |||
feat_2d_norm_beta = P.Gather()(self.feat_2d_norm_betas, index, 0) | |||
feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0) | |||
q, k, _ = pair_act.shape | |||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |||
msa_act, _, _ = self.norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32), | |||
query_norm_beta.astype(mstype.float32)) | |||
pair_act, _, _ = self.norm(pair_act.astype(mstype.float32), feat_2d_norm_gamma.astype(mstype.float32), | |||
feat_2d_norm_beta.astype(mstype.float32)) | |||
pair_act = P.Reshape()(pair_act, (-1, pair_act.shape[-1])) | |||
nonbatched_bias = P.Transpose()( | |||
P.Reshape()(self.matmul(pair_act.astype(mstype.float16), feat_2d_weight.astype(mstype.float16)), | |||
(q, k, self.num_head)), (2, 0, 1)) | |||
if self.slice_num: | |||
msa_act_ori_shape = P.Shape()(msa_act) | |||
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:] | |||
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16) | |||
bias_shape = P.Shape()(bias) | |||
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:]) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
msa_act_tuple = () | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1]) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index, nonbatched_bias) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
msa_act = P.Concat()(msa_act_tuple) | |||
msa_act = P.Reshape()(msa_act, msa_act_ori_shape) | |||
return msa_act | |||
msa_act = self.attn_mod(msa_act, msa_act, bias, index, nonbatched_bias) | |||
return msa_act | |||
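

def _example_mask_to_bias():
    # Editor's sketch: the (1e9 * (mask - 1.)) expression used above converts
    # a 0/1 mask into an additive pre-softmax bias - kept positions get 0,
    # masked positions get -1e9, so their attention weights vanish.
    msa_mask = np.array([1., 1., 0.])
    bias = 1e9 * (msa_mask - 1.)                  # [0, 0, -1e9]
    logits = np.array([0.5, 0.2, 3.0]) + bias
    weights = np.exp(logits - logits.max())
    return weights / weights.sum()                # last weight is ~0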


class MSAColumnAttention(nn.Cell):
'''MSA column attention''' | |||
def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0): | |||
super(MSAColumnAttention, self).__init__() | |||
self.config = config | |||
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.attn_mod = Attention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size) | |||
self.batch_size = batch_size | |||
self.slice_num = slice_num | |||
self.msa_act_dim = msa_act_dim | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
self.query_norm_gammas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32)) | |||
self.query_norm_betas = Parameter(Tensor(np.zeros([self.batch_size, self.msa_act_dim]), mstype.float32)) | |||
def construct(self, msa_act, msa_mask, index): | |||
'''construct''' | |||
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) | |||
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
msa_mask = mnp.swapaxes(msa_mask, -1, -2) | |||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |||
msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32), query_norm_gamma.astype(mstype.float32), | |||
query_norm_beta.astype(mstype.float32)) | |||
if self.slice_num: | |||
msa_act_ori_shape = P.Shape()(msa_act) | |||
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:] | |||
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16) | |||
bias_shape = P.Shape()(bias) | |||
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:]) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
msa_act_tuple = () | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1]) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, bias_slice, index) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
msa_act = P.Concat()(msa_act_tuple) | |||
msa_act = P.Reshape()(msa_act, msa_act_ori_shape) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
return msa_act | |||
msa_act = self.attn_mod(msa_act, msa_act, bias, index) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
return msa_act | |||
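

def _example_column_transpose():
    # Editor's sketch: MSAColumnAttention reuses the row-wise Attention cell
    # by swapping the sequence and residue axes, attending over columns, and
    # swapping back, exactly as the mnp.swapaxes calls above do.
    msa_act = np.zeros((8, 120, 64))              # (num_seq, num_res, channels)
    per_column = np.swapaxes(msa_act, -2, -3)     # (num_res, num_seq, channels)
    restored = np.swapaxes(per_column, -2, -3)
    return per_column.shape, restored.shape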


class GlobalAttention(nn.Cell):
'''global attention''' | |||
def __init__(self, config, key_dim, value_dim, output_dim, batch_size=None): | |||
super(GlobalAttention, self).__init__() | |||
self.config = config | |||
self.key_dim = key_dim | |||
self.ori_key_dim = key_dim | |||
self.value_dim = value_dim | |||
self.ori_value_dim = value_dim | |||
self.num_head = self.config.num_head | |||
self.key_dim = self.key_dim // self.num_head | |||
self.value_dim = self.value_dim // self.num_head | |||
self.output_dim = output_dim | |||
self.matmul_trans_b = P.MatMul(transpose_b=True) | |||
self.batch_matmul = P.BatchMatMul() | |||
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
self.matmul = P.MatMul() | |||
self.softmax = nn.Softmax() | |||
self.sigmoid = nn.Sigmoid() | |||
self.gating = self.config.gating | |||
self.batch_size = batch_size | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.linear_q_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.ori_key_dim, self.num_head, self.key_dim)), mstype.float32)) | |||
self.linear_k_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.key_dim)), mstype.float32)) | |||
self.linear_v_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.ori_value_dim, self.value_dim)), mstype.float32)) | |||
self.linear_output_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.output_dim, self.num_head * self.value_dim)), mstype.float32)) | |||
self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.output_dim)), mstype.float32)) | |||
if self.gating: | |||
self.linear_gating_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.num_head * self.value_dim, self.ori_key_dim)), mstype.float32)) | |||
self.gating_biases = Parameter(Tensor(np.zeros((self.batch_size, self.ori_key_dim)), mstype.float32)) | |||
def construct(self, q_data, m_data, q_mask, bias, index): | |||
'''construct''' | |||
q_weights = P.Gather()(self.linear_q_weights, index, 0) | |||
k_weights = P.Gather()(self.linear_k_weights, index, 0) | |||
v_weights = P.Gather()(self.linear_v_weights, index, 0) | |||
output_weights = P.Gather()(self.linear_output_weights, index, 0) | |||
output_bias = P.Gather()(self.o_biases, index, 0) | |||
gating_weights = 0 | |||
gating_bias = 0 | |||
if self.gating: | |||
gating_weights = P.Gather()(self.linear_gating_weights, index, 0) | |||
gating_bias = P.Gather()(self.gating_biases, index, 0) | |||
b, _, _ = m_data.shape | |||
v_weights = v_weights[None, ...] | |||
v_weights = mnp.broadcast_to(v_weights, (b, self.value_dim * self.num_head, self.value_dim)) | |||
v = self.batch_matmul(m_data.astype(mstype.float16), v_weights.astype(mstype.float16)) | |||
q_avg = mask_mean(q_mask, q_data, axis=1) | |||
q_weights = P.Reshape()(q_weights, (-1, self.num_head * self.key_dim)) | |||
q = P.Reshape()(self.matmul(q_avg.astype(mstype.float16), q_weights.astype(mstype.float16)), | |||
(-1, self.num_head, self.key_dim)) * (self.key_dim ** (-0.5)) | |||
k_weights = k_weights[None, ...] | |||
k_weights = mnp.broadcast_to(k_weights, (b, self.value_dim * self.num_head, self.key_dim)) | |||
k = self.batch_matmul(m_data.astype(mstype.float16), k_weights.astype(mstype.float16)) | |||
bias = (1e9 * (q_mask[:, None, :, 0] - 1.)) | |||
logits = self.batch_matmul_trans_b(q.astype(mstype.float16), k.astype(mstype.float16)) + bias.astype( | |||
mstype.float16) | |||
weights = self.softmax(logits.astype(mstype.float32)) | |||
weighted_avg = self.batch_matmul(weights.astype(mstype.float16), v.astype(mstype.float16)) | |||
if self.gating: | |||
# gate_values = self.linear_gating(q_data).astype(mstype.float32) | |||
q_data_shape = P.Shape()(q_data) | |||
if len(q_data_shape) != 2: | |||
q_data = P.Reshape()(q_data, (-1, q_data_shape[-1])) | |||
out_shape = q_data_shape[:-1] + (-1,) | |||
gate_values = P.Reshape()(self.matmul_trans_b(q_data.astype(mstype.float16), | |||
gating_weights.astype(mstype.float16)) + | |||
gating_bias.astype(mstype.float16), out_shape) | |||
gate_values = P.Reshape()(self.sigmoid(gate_values.astype(mstype.float32)), | |||
(b, -1, self.num_head, self.value_dim)) | |||
weighted_avg = P.Reshape()(weighted_avg[:, None] * gate_values, (-1, self.num_head * self.value_dim)) | |||
weighted_avg_shape = P.Shape()(weighted_avg) | |||
if len(weighted_avg_shape) != 2: | |||
weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) | |||
output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16), | |||
output_weights.astype(mstype.float16)) | |||
+ output_bias.astype(mstype.float16), (b, -1, self.output_dim)) | |||
else: | |||
weighted_avg = P.Reshape()(weighted_avg, (-1, self.num_head * self.value_dim)) | |||
# output = self.linear_gating(weighted_avg) | |||
weighted_avg_shape = P.Shape()(weighted_avg) | |||
if len(weighted_avg_shape) != 2: | |||
weighted_avg = P.Reshape()(weighted_avg, (-1, weighted_avg_shape[-1])) | |||
out_shape = weighted_avg_shape[:-1] + (-1,) | |||
output = P.Reshape()(self.matmul_trans_b(weighted_avg.astype(mstype.float16), | |||
output_weights.astype(mstype.float16)) + | |||
output_bias.astype(mstype.float16), out_shape) | |||
output = output[:, None] | |||
return output | |||
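

def _example_mask_mean():
    # Editor's sketch: GlobalAttention pools its queries with mask_mean from
    # commons.utils (not shown in this file); a masked average along the
    # sequence axis, like this numpy equivalent, is assumed.
    mask = np.array([[1.], [1.], [0.]])           # (num_seq, 1)
    value = np.array([[2.], [4.], [100.]])        # (num_seq, channels)
    return (mask * value).sum(axis=0) / (mask.sum(axis=0) + 1e-10)  # -> [3.]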


class MSAColumnGlobalAttention(nn.Cell):
'''MSA column global attention''' | |||
def __init__(self, config, msa_act_dim, batch_size=None, slice_num=0): | |||
super(MSAColumnGlobalAttention, self).__init__() | |||
self.config = config | |||
self.attn_mod = GlobalAttention(self.config, msa_act_dim, msa_act_dim, msa_act_dim, batch_size) | |||
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.batch_size = batch_size | |||
self.slice_num = slice_num | |||
self.msa_act_dim = msa_act_dim | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) | |||
self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.msa_act_dim)), mstype.float32)) | |||
def construct(self, msa_act, msa_mask, index): | |||
'''construct''' | |||
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) | |||
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
msa_mask = mnp.swapaxes(msa_mask, -1, -2) | |||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |||
msa_act, _, _ = self.query_norm(msa_act.astype(mstype.float32), | |||
query_norm_gamma.astype(mstype.float32), | |||
query_norm_beta.astype(mstype.float32)) | |||
msa_mask = mnp.expand_dims(msa_mask, axis=-1) | |||
if self.slice_num: | |||
msa_act_ori_shape = P.Shape()(msa_act) | |||
slice_shape = (self.slice_num, -1) + msa_act_ori_shape[1:] | |||
msa_act = P.Reshape()(msa_act, slice_shape).astype(mstype.float16) | |||
bias_shape = P.Shape()(bias) | |||
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:]) | |||
msa_mask_shape = P.Shape()(msa_mask) | |||
msa_mask = P.Reshape()(msa_mask, slice_shape[:2] + msa_mask_shape[1:]) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
msa_act_tuple = () | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
msa_act_slice = P.Gather()(msa_act, slice_idx_tensor, 0) | |||
msa_act_slice = F.depend(msa_act_slice, msa_act_tuple[-1]) | |||
msa_mask_slice = P.Gather()(msa_mask, slice_idx_tensor, 0) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
msa_act_slice = self.attn_mod(msa_act_slice, msa_act_slice, msa_mask_slice, bias_slice, index) | |||
msa_act_slice = P.Reshape()(msa_act_slice, ((1,) + P.Shape()(msa_act_slice))) | |||
msa_act_tuple = msa_act_tuple + (msa_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
msa_act = P.Concat()(msa_act_tuple) | |||
msa_act = P.Reshape()(msa_act, msa_act_ori_shape) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
return msa_act | |||
msa_act = self.attn_mod(msa_act, msa_act, msa_mask, bias, index) | |||
msa_act = mnp.swapaxes(msa_act, -2, -3) | |||
return msa_act | |||


class Transition(nn.Cell):
'''transition''' | |||
def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0): | |||
super(Transition, self).__init__() | |||
self.config = config | |||
self.input_layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.matmul = P.MatMul(transpose_b=True) | |||
self.layer_norm_dim = layer_norm_dim | |||
self.num_intermediate = int(layer_norm_dim * self.config.num_intermediate_factor) | |||
self.batch_size = batch_size | |||
self.slice_num = slice_num | |||
self.relu = nn.ReLU() | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.input_layer_norm_gammas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.input_layer_norm_betas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.transition1_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.num_intermediate, self.layer_norm_dim)), mstype.float32)) | |||
self.transition1_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_intermediate)), mstype.float32)) | |||
self.transition2_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.num_intermediate)), mstype.float32)) | |||
self.transition2_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
def construct(self, act, index): | |||
'''construct''' | |||
input_layer_norm_gamma = P.Gather()(self.input_layer_norm_gammas, index, 0) | |||
input_layer_norm_beta = P.Gather()(self.input_layer_norm_betas, index, 0) | |||
transition1_weight = P.Gather()(self.transition1_weights, index, 0) | |||
transition1_bias = P.Gather()(self.transition1_biases, index, 0) | |||
transition2_weight = P.Gather()(self.transition2_weights, index, 0) | |||
transition2_bias = P.Gather()(self.transition2_biases, index, 0) | |||
act, _, _ = self.input_layer_norm(act.astype(mstype.float32), input_layer_norm_gamma.astype(mstype.float32), | |||
input_layer_norm_beta.astype(mstype.float32)) | |||
if self.slice_num: | |||
act_ori_shape = P.Shape()(act) | |||
slice_shape = (self.slice_num, -1) + act_ori_shape[1:] | |||
act = P.Reshape()(act, slice_shape).astype(mstype.float16) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
act_tuple = () | |||
act_slice = P.Gather()(act, slice_idx_tensor, 0) | |||
act_shape = P.Shape()(act_slice) | |||
if len(act_shape) != 2: | |||
act_slice = P.Reshape()(act_slice, (-1, act_shape[-1])) | |||
act_slice = self.relu( | |||
P.BiasAdd()(self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)), | |||
transition1_bias.astype(mstype.float16)).astype(mstype.float32)) | |||
act_slice = P.BiasAdd()( | |||
self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)), | |||
transition2_bias.astype(mstype.float16)) | |||
act_slice = P.Reshape()(act_slice, act_shape) | |||
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice))) | |||
act_tuple = act_tuple + (act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
act_slice = P.Gather()(act, slice_idx_tensor, 0) | |||
act_slice = F.depend(act_slice, act_tuple[-1]) | |||
act_shape = P.Shape()(act_slice) | |||
if len(act_shape) != 2: | |||
act_slice = P.Reshape()(act_slice, (-1, act_shape[-1])) | |||
act_slice = self.relu(P.BiasAdd()( | |||
self.matmul(act_slice.astype(mstype.float16), transition1_weight.astype(mstype.float16)), | |||
transition1_bias.astype(mstype.float16)).astype(mstype.float32)) | |||
act_slice = P.BiasAdd()( | |||
self.matmul(act_slice.astype(mstype.float16), transition2_weight.astype(mstype.float16)), | |||
transition2_bias.astype(mstype.float16)) | |||
act_slice = P.Reshape()(act_slice, act_shape) | |||
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice))) | |||
act_tuple = act_tuple + (act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
act = P.Concat()(act_tuple) | |||
act = P.Reshape()(act, act_ori_shape) | |||
return act | |||
act_shape = P.Shape()(act) | |||
if len(act_shape) != 2: | |||
act = P.Reshape()(act, (-1, act_shape[-1])) | |||
act = self.relu( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition1_weight.astype(mstype.float16)), | |||
transition1_bias.astype(mstype.float16)).astype(mstype.float32)) | |||
act = P.BiasAdd()(self.matmul(act.astype(mstype.float16), transition2_weight.astype(mstype.float16)), | |||
transition2_bias.astype(mstype.float16)) | |||
act = P.Reshape()(act, act_shape) | |||
return act | |||
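

def _example_transition():
    # Editor's sketch: setting aside the layer norm and the slice_num memory
    # path, Transition is a position-wise two-layer MLP,
    # relu(x @ W1.T + b1) @ W2.T + b2, widened by num_intermediate_factor.
    dim, factor = 8, 4
    x = np.random.randn(10, dim)
    w1, b1 = np.random.randn(dim * factor, dim), np.zeros(dim * factor)
    w2, b2 = np.random.randn(dim, dim * factor), np.zeros(dim)
    hidden = np.maximum(x @ w1.T + b1, 0.)
    return hidden @ w2.T + b2                     # same shape as x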


class OuterProductMean(nn.Cell):
    '''outer product mean'''
def __init__(self, config, act_dim, num_output_channel, batch_size=None, slice_num=0): | |||
super(OuterProductMean, self).__init__() | |||
self.num_output_channel = num_output_channel | |||
self.config = config | |||
self.layer_norm_input = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.matmul_trans_b = P.MatMul(transpose_b=True) | |||
self.matmul = P.MatMul() | |||
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
self.act_dim = act_dim | |||
self.batch_size = batch_size | |||
self.slice_num = slice_num | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.layer_norm_input_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) | |||
self.layer_norm_input_betas = Parameter(Tensor(np.zeros((self.batch_size, self.act_dim)), mstype.float32)) | |||
self.left_projection_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32)) | |||
self.left_projection_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32)) | |||
self.right_projection_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel, self.act_dim)), mstype.float32)) | |||
self.right_projection_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_outer_channel)), mstype.float32)) | |||
self.linear_output_weights = Parameter(Tensor(np.zeros( | |||
(self.batch_size, self.num_output_channel, self.config.num_outer_channel * | |||
self.config.num_outer_channel)), mstype.float32)) | |||
self.o_biases = Parameter(Tensor(np.zeros((self.batch_size, self.num_output_channel)), mstype.float32)) | |||
def construct(self, act, mask, index): | |||
'''construct''' | |||
layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0) | |||
layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0) | |||
left_projection_weight = P.Gather()(self.left_projection_weights, index, 0) | |||
left_projection_bias = P.Gather()(self.left_projection_biases, index, 0) | |||
right_projection_weight = P.Gather()(self.right_projection_weights, index, 0) | |||
right_projection_bias = P.Gather()(self.right_projection_biases, index, 0) | |||
linear_output_weight = P.Gather()(self.linear_output_weights, index, 0) | |||
linear_output_bias = P.Gather()(self.o_biases, index, 0) | |||
mask = mask[..., None] | |||
act, _, _ = self.layer_norm_input(act.astype(mstype.float32), | |||
layer_norm_input_gamma.astype(mstype.float32), | |||
layer_norm_input_beta.astype(mstype.float32)) | |||
act_shape = P.Shape()(act) | |||
if len(act_shape) != 2: | |||
act = P.Reshape()(act, (-1, act_shape[-1])) | |||
out_shape = act_shape[:-1] + (-1,) | |||
left_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16), | |||
left_projection_weight.astype(mstype.float16)), | |||
left_projection_bias.astype(mstype.float16)), out_shape) | |||
right_act = mask * P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16), | |||
right_projection_weight.astype(mstype.float16)), | |||
right_projection_bias.astype(mstype.float16)), out_shape) | |||
_, d, e = right_act.shape | |||
if self.slice_num: | |||
left_act_shape = P.Shape()(left_act) | |||
slice_shape = (left_act_shape[0],) + (self.slice_num, -1) + (left_act_shape[-1],) | |||
left_act = P.Reshape()(left_act, slice_shape) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
act_tuple = () | |||
left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1) | |||
a, b, c = left_act_slice.shape | |||
left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a)) | |||
right_act = P.Reshape()(right_act, (a, -1)) | |||
act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16), | |||
right_act.astype(mstype.float16)), | |||
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e)) | |||
act_slice_shape = P.Shape()(act_slice) | |||
            if len(act_slice_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1])) | |||
act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16), | |||
linear_output_weight.astype(mstype.float16)), | |||
linear_output_bias.astype(mstype.float16)), (d, b, -1)) | |||
act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2]) | |||
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice))) | |||
act_tuple = act_tuple + (act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
left_act_slice = P.Gather()(left_act, slice_idx_tensor, 1) | |||
left_act_slice = F.depend(left_act_slice, act_tuple[-1]) | |||
a, b, c = left_act_slice.shape | |||
left_act_slice = P.Reshape()(mnp.transpose(left_act_slice.astype(mstype.float16), [2, 1, 0]), (-1, a)) | |||
right_act = P.Reshape()(right_act, (a, -1)) | |||
act_slice = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act_slice.astype(mstype.float16), | |||
right_act.astype(mstype.float16)), | |||
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e)) | |||
act_slice_shape = P.Shape()(act_slice) | |||
                if len(act_slice_shape) != 2:
act_slice = P.Reshape()(act_slice, (-1, act_slice_shape[-1])) | |||
act_slice = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act_slice.astype(mstype.float16), | |||
linear_output_weight.astype(mstype.float16)), | |||
linear_output_bias.astype(mstype.float16)), (d, b, -1)) | |||
act_slice = mnp.transpose(act_slice.astype(mstype.float16), [1, 0, 2]) | |||
act_slice = P.Reshape()(act_slice, ((1,) + P.Shape()(act_slice))) | |||
act_tuple = act_tuple + (act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
act = P.Concat()(act_tuple) | |||
act_shape = P.Shape()(act) | |||
act = P.Reshape()(act, (-1, act_shape[-2], act_shape[-1])) | |||
epsilon = 1e-3 | |||
tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0)) | |||
norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16), | |||
tmp_mask.astype(mstype.float16)), | |||
(1, 2, 0)).astype(mstype.float32) | |||
act /= epsilon + norm | |||
return act | |||
a, b, c = left_act.shape | |||
left_act = P.Reshape()(mnp.transpose(left_act.astype(mstype.float16), [2, 1, 0]), (-1, a)) | |||
right_act = P.Reshape()(right_act, (a, -1)) | |||
act = P.Reshape()(P.Transpose()(P.Reshape()(self.matmul(left_act.astype(mstype.float16), | |||
right_act.astype(mstype.float16)), | |||
(c, b, d, e)), (2, 1, 0, 3)), (d, b, c * e)) | |||
act_shape = P.Shape()(act) | |||
if len(act_shape) != 2: | |||
act = P.Reshape()(act, (-1, act_shape[-1])) | |||
act = P.Reshape()(P.BiasAdd()(self.matmul_trans_b(act.astype(mstype.float16), | |||
linear_output_weight.astype(mstype.float16)), | |||
linear_output_bias.astype(mstype.float16)), (d, b, -1)) | |||
act = mnp.transpose(act.astype(mstype.float16), [1, 0, 2]) | |||
epsilon = 1e-3 | |||
tmp_mask = P.Transpose()(mask.astype(mstype.float16), (2, 1, 0)) | |||
norm = P.Transpose()(self.batch_matmul_trans_b(tmp_mask.astype(mstype.float16), | |||
tmp_mask.astype(mstype.float16)), | |||
(1, 2, 0)).astype(mstype.float32) | |||
act /= epsilon + norm | |||
return act | |||
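

def _example_outer_product_mean():
    # Editor's sketch: the transposed matmuls above are equivalent to taking,
    # for every residue pair (i, j), the outer product of the left and right
    # projections summed over MSA sequences, normalized by the number of
    # unmasked sequence pairs (plus the same 1e-3 epsilon).
    num_seq, num_res, c = 5, 7, 3
    left = np.random.randn(num_seq, num_res, c)
    right = np.random.randn(num_seq, num_res, c)
    mask = np.ones((num_seq, num_res))
    outer = np.einsum('sic,sjd->ijcd', left * mask[..., None], right * mask[..., None])
    norm = np.einsum('si,sj->ij', mask, mask)
    return outer / (1e-3 + norm[..., None, None])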


class TriangleMultiplication(nn.Cell):
'''triangle multiplication''' | |||
def __init__(self, config, layer_norm_dim, batch_size): | |||
super(TriangleMultiplication, self).__init__() | |||
self.config = config | |||
self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.matmul = P.MatMul(transpose_b=True) | |||
self.sigmoid = nn.Sigmoid() | |||
self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
equation = ["ikc,jkc->ijc", "kjc,kic->ijc"] | |||
if self.config.equation not in equation: | |||
print("TriangleMultiplication Not Suppl") | |||
if self.config.equation == "ikc,jkc->ijc": | |||
self.equation = True | |||
elif self.config.equation == "kjc,kic->ijc": | |||
self.equation = False | |||
else: | |||
self.equation = None | |||
self.batch_size = batch_size | |||
self.layer_norm_dim = layer_norm_dim | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.layer_norm_input_gammas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.layer_norm_input_betas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.left_projection_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)), | |||
mstype.float32)) | |||
self.left_projection_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32)) | |||
self.right_projection_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)), | |||
mstype.float32)) | |||
self.right_projection_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32)) | |||
self.left_gate_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)), | |||
mstype.float32)) | |||
self.left_gate_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32)) | |||
self.right_gate_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel, self.layer_norm_dim)), | |||
mstype.float32)) | |||
self.right_gate_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_intermediate_channel)), mstype.float32)) | |||
self.center_layer_norm_gammas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.center_layer_norm_betas = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.output_projection_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) | |||
self.output_projection_biases = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.gating_linear_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.layer_norm_dim, self.layer_norm_dim)), mstype.float32)) | |||
self.gating_linear_biases = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
def construct(self, act, mask, index): | |||
'''construct''' | |||
layer_norm_input_gamma = P.Gather()(self.layer_norm_input_gammas, index, 0) | |||
layer_norm_input_beta = P.Gather()(self.layer_norm_input_betas, index, 0) | |||
left_projection_weight = P.Gather()(self.left_projection_weights, index, 0) | |||
left_projection_bias = P.Gather()(self.left_projection_biases, index, 0) | |||
right_projection_weight = P.Gather()(self.right_projection_weights, index, 0) | |||
right_projection_bias = P.Gather()(self.right_projection_biases, index, 0) | |||
left_gate_weight = P.Gather()(self.left_gate_weights, index, 0) | |||
left_gate_bias = P.Gather()(self.left_gate_biases, index, 0) | |||
right_gate_weight = P.Gather()(self.right_gate_weights, index, 0) | |||
right_gate_bias = P.Gather()(self.right_gate_biases, index, 0) | |||
center_layer_norm_gamma = P.Gather()(self.center_layer_norm_gammas, index, 0) | |||
center_layer_norm_beta = P.Gather()(self.center_layer_norm_betas, index, 0) | |||
output_projection_weight = P.Gather()(self.output_projection_weights, index, 0) | |||
output_projection_bias = P.Gather()(self.output_projection_biases, index, 0) | |||
gating_linear_weight = P.Gather()(self.gating_linear_weights, index, 0) | |||
gating_linear_bias = P.Gather()(self.gating_linear_biases, index, 0) | |||
mask = mask[..., None] | |||
act, _, _ = self.layer_norm(act.astype(mstype.float32), | |||
layer_norm_input_gamma.astype(mstype.float32), | |||
layer_norm_input_beta.astype(mstype.float32)) | |||
input_act = act | |||
act_shape = P.Shape()(act) | |||
if len(act_shape) != 2: | |||
act = P.Reshape()(act, (-1, act_shape[-1])) | |||
out_shape = act_shape[:-1] + (-1,) | |||
left_projection = mask * P.Reshape()( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_projection_weight.astype(mstype.float16)), | |||
left_projection_bias.astype(mstype.float16)), out_shape) | |||
left_projection = left_projection.astype(mstype.float16) | |||
act = F.depend(act, left_projection) | |||
left_gate_values = self.sigmoid(P.Reshape()( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), left_gate_weight.astype(mstype.float16)), | |||
left_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32)) | |||
left_proj_act = left_projection * left_gate_values | |||
act = F.depend(act, left_proj_act) | |||
right_projection = mask * P.Reshape()( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_projection_weight.astype(mstype.float16)), | |||
right_projection_bias.astype(mstype.float16)), out_shape) | |||
right_projection = right_projection.astype(mstype.float16) | |||
act = F.depend(act, right_projection) | |||
right_gate_values = self.sigmoid(P.Reshape()( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), right_gate_weight.astype(mstype.float16)), | |||
right_gate_bias.astype(mstype.float16)), out_shape).astype(mstype.float32)) | |||
right_proj_act = right_projection * right_gate_values | |||
left_proj_act = F.depend(left_proj_act, right_proj_act) | |||
if self.equation is not None: | |||
if self.equation: | |||
left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 0, 1)) | |||
right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 0, 1)) | |||
act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp) | |||
act = P.Transpose()(act, (1, 2, 0)).astype(mstype.float32) | |||
else: | |||
left_proj_act_tmp = P.Transpose()(left_proj_act.astype(mstype.float16), (2, 1, 0)) | |||
right_proj_act_tmp = P.Transpose()(right_proj_act.astype(mstype.float16), (2, 1, 0)) | |||
act = self.batch_matmul_trans_b(left_proj_act_tmp, right_proj_act_tmp) | |||
act = P.Transpose()(act, (2, 1, 0)).astype(mstype.float32) | |||
act, _, _ = self.layer_norm(act.astype(mstype.float32), | |||
center_layer_norm_gamma.astype(mstype.float32), | |||
center_layer_norm_beta.astype(mstype.float32)) | |||
act_shape = P.Shape()(act) | |||
if len(act_shape) != 2: | |||
act = P.Reshape()(act, (-1, act_shape[-1])) | |||
out_shape = act_shape[:-1] + (-1,) | |||
act = P.Reshape()( | |||
P.BiasAdd()(self.matmul(act.astype(mstype.float16), output_projection_weight.astype(mstype.float16)), | |||
output_projection_bias.astype(mstype.float16)), out_shape) | |||
input_act_shape = P.Shape()(input_act) | |||
if len(input_act_shape) != 2: | |||
input_act = P.Reshape()(input_act, (-1, input_act_shape[-1])) | |||
out_shape = input_act_shape[:-1] + (-1,) | |||
gate_values = self.sigmoid(P.Reshape()( | |||
P.BiasAdd()(self.matmul(input_act.astype(mstype.float16), gating_linear_weight.astype(mstype.float16)), | |||
gating_linear_bias.astype(mstype.float16)), out_shape).astype(mstype.float32)) | |||
act = act * gate_values | |||
return act | |||
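

def _example_triangle_equations():
    # Editor's sketch: the two supported equation strings correspond to the
    # "outgoing" ('ikc,jkc->ijc') and "incoming" ('kjc,kic->ijc') triangle
    # updates; the transposed batched matmuls above compute the same
    # contractions with float16 inputs.
    num_res, c = 6, 4
    left = np.random.randn(num_res, num_res, c)
    right = np.random.randn(num_res, num_res, c)
    outgoing = np.einsum('ikc,jkc->ijc', left, right)
    incoming = np.einsum('kjc,kic->ijc', left, right)
    return outgoing.shape, incoming.shape         # both (6, 6, 4)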


class TriangleAttention(nn.Cell):
'''triangle attention''' | |||
def __init__(self, config, layer_norm_dim, batch_size=None, slice_num=0): | |||
super(TriangleAttention, self).__init__() | |||
self.config = config | |||
self.orientation_is_per_column = (self.config.orientation == 'per_column') | |||
self.init_factor = Tensor(1. / np.sqrt(layer_norm_dim), mstype.float32) | |||
self.query_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=1e-5) | |||
self.matmul = P.MatMul(transpose_b=True) | |||
self.attn_mod = Attention(self.config, layer_norm_dim, layer_norm_dim, layer_norm_dim, batch_size) | |||
self.batch_size = batch_size | |||
self.slice_num = slice_num | |||
self.layer_norm_dim = layer_norm_dim | |||
self.idx = Tensor(0, mstype.int32) | |||
self._init_parameter() | |||
def _init_parameter(self): | |||
'''init parameter''' | |||
self.query_norm_gammas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.query_norm_betas = Parameter(Tensor(np.zeros((self.batch_size, self.layer_norm_dim)), mstype.float32)) | |||
self.feat_2d_weights = Parameter( | |||
Tensor(np.zeros((self.batch_size, self.config.num_head, self.layer_norm_dim)), mstype.float32)) | |||
def construct(self, pair_act, pair_mask, index): | |||
'''construct''' | |||
query_norm_gamma = P.Gather()(self.query_norm_gammas, index, 0) | |||
query_norm_beta = P.Gather()(self.query_norm_betas, index, 0) | |||
feat_2d_weight = P.Gather()(self.feat_2d_weights, index, 0) | |||
if self.orientation_is_per_column: | |||
pair_act = mnp.swapaxes(pair_act, -2, -3) | |||
pair_mask = mnp.swapaxes(pair_mask, -1, -2) | |||
bias = (1e9 * (pair_mask - 1.))[:, None, None, :] | |||
pair_act, _, _ = self.query_norm(pair_act.astype(mstype.float32), | |||
query_norm_gamma.astype(mstype.float32), | |||
query_norm_beta.astype(mstype.float32)) | |||
q, k, _ = pair_act.shape | |||
nonbatched_bias = self.matmul(P.Reshape()(pair_act.astype(mstype.float16), (-1, pair_act.shape[-1])), | |||
feat_2d_weight.astype(mstype.float16)) | |||
nonbatched_bias = P.Transpose()(P.Reshape()(nonbatched_bias, (q, k, -1)), (2, 0, 1)) | |||
if self.slice_num: | |||
pair_act_ori_shape = P.Shape()(pair_act) | |||
slice_shape = (self.slice_num, -1) + pair_act_ori_shape[1:] | |||
pair_act = P.Reshape()(pair_act, slice_shape).astype(mstype.float16) | |||
bias_shape = P.Shape()(bias) | |||
bias = P.Reshape()(bias, slice_shape[:2] + bias_shape[1:]) | |||
slice_idx = 0 | |||
slice_idx_tensor = self.idx | |||
pair_act_tuple = () | |||
pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias) | |||
pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice))) | |||
pair_act_tuple = pair_act_tuple + (pair_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
while slice_idx < self.slice_num: | |||
pair_act_slice = P.Gather()(pair_act, slice_idx_tensor, 0) | |||
pair_act_slice = F.depend(pair_act_slice, pair_act_tuple[-1]) | |||
bias_slice = P.Gather()(bias, slice_idx_tensor, 0) | |||
pair_act_slice = self.attn_mod(pair_act_slice, pair_act_slice, bias_slice, index, nonbatched_bias) | |||
pair_act_slice = P.Reshape()(pair_act_slice, ((1,) + P.Shape()(pair_act_slice))) | |||
pair_act_tuple = pair_act_tuple + (pair_act_slice,) | |||
slice_idx += 1 | |||
slice_idx_tensor += 1 | |||
pair_act = P.Concat()(pair_act_tuple) | |||
pair_act = P.Reshape()(pair_act, pair_act_ori_shape) | |||
if self.orientation_is_per_column: | |||
pair_act = mnp.swapaxes(pair_act, -2, -3) | |||
return pair_act | |||
pair_act = self.attn_mod(pair_act, pair_act, bias, index, nonbatched_bias) | |||
if self.orientation_is_per_column: | |||
pair_act = mnp.swapaxes(pair_act, -2, -3) | |||
return pair_act |
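

def _example_slice_chunking():
    # Editor's sketch: every slice_num > 0 branch in this module follows the
    # same memory-saving pattern - fold the leading axis into slice_num
    # chunks, process one chunk at a time (F.depend pins the execution order
    # in graph mode), then concatenate and restore the original shape.
    def attend(chunk):
        return chunk                              # stand-in for the attention call
    slice_num = 4
    act = np.random.randn(8, 16, 32)
    chunks = act.reshape((slice_num, -1) + act.shape[1:])
    out = np.concatenate([attend(chunks[i]) for i in range(slice_num)])
    return out.reshape(act.shape)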
@@ -0,0 +1,304 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Evoformer module""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
import mindspore.numpy as mnp | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore.common.tensor import Tensor | |||
from mindspore import Parameter | |||
from commons import residue_constants | |||
from commons.utils import dgram_from_positions, batch_make_transform_from_reference, batch_quat_affine, \ | |||
batch_invert_point, batch_rot_to_quat | |||
from module.basic_module import Attention, MSARowAttentionWithPairBias, MSAColumnAttention, MSAColumnGlobalAttention, \ | |||
Transition, OuterProductMean, TriangleMultiplication, TriangleAttention | |||


class EvoformerIteration(nn.Cell):
'''evoformer iteration''' | |||
def __init__(self, config, msa_act_dim, pair_act_dim, is_extra_msa, batch_size, global_config): | |||
super(EvoformerIteration, self).__init__() | |||
self.config = config | |||
self.is_extra_msa = is_extra_msa | |||
if not self.is_extra_msa: | |||
self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( | |||
self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size, | |||
global_config.evoformer_iteration.msa_row_attention_with_pair_bias.slice_num) | |||
self.attn_mod = MSAColumnAttention(self.config.msa_column_attention, msa_act_dim, batch_size, | |||
global_config.evoformer_iteration.msa_column_attention.slice_num) | |||
self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size, | |||
global_config.evoformer_iteration.msa_transition.slice_num) | |||
self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim, | |||
batch_size, | |||
global_config.evoformer_iteration.outer_product_mean.slice_num) | |||
self.triangle_attention_starting_node = \ | |||
TriangleAttention(self.config.triangle_attention_starting_node, | |||
pair_act_dim, batch_size, | |||
global_config.evoformer_iteration.triangle_attention_starting_node.slice_num) | |||
self.triangle_attention_ending_node = \ | |||
TriangleAttention(self.config.triangle_attention_ending_node, | |||
pair_act_dim, batch_size, | |||
global_config.evoformer_iteration.triangle_attention_ending_node.slice_num) | |||
self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size, | |||
global_config.evoformer_iteration.pair_transition.slice_num) | |||
else: | |||
self.msa_row_attention_with_pair_bias = MSARowAttentionWithPairBias( | |||
self.config.msa_row_attention_with_pair_bias, msa_act_dim, pair_act_dim, batch_size, | |||
global_config.extra_msa_stack.msa_row_attention_with_pair_bias.slice_num) | |||
self.attn_mod = \ | |||
MSAColumnGlobalAttention(self.config.msa_column_attention, msa_act_dim, batch_size, | |||
global_config.extra_msa_stack.msa_column_global_attention.slice_num) | |||
self.msa_transition = Transition(self.config.msa_transition, msa_act_dim, batch_size, | |||
global_config.extra_msa_stack.msa_transition.slice_num) | |||
self.outer_product_mean = OuterProductMean(self.config.outer_product_mean, msa_act_dim, pair_act_dim, | |||
batch_size, | |||
global_config.extra_msa_stack.outer_product_mean.slice_num) | |||
self.triangle_attention_starting_node = \ | |||
TriangleAttention(self.config.triangle_attention_starting_node, | |||
pair_act_dim, batch_size, | |||
global_config.extra_msa_stack.triangle_attention_starting_node.slice_num) | |||
self.triangle_attention_ending_node = \ | |||
TriangleAttention(self.config.triangle_attention_ending_node, | |||
pair_act_dim, batch_size, | |||
global_config.extra_msa_stack.triangle_attention_ending_node.slice_num) | |||
self.pair_transition = Transition(self.config.pair_transition, pair_act_dim, batch_size, | |||
global_config.extra_msa_stack.pair_transition.slice_num) | |||
self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing, | |||
pair_act_dim, batch_size) | |||
self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming, | |||
pair_act_dim, batch_size) | |||
def construct(self, msa_act, pair_act, msa_mask, pair_mask, index): | |||
'''construct''' | |||
msa_act = msa_act + self.msa_row_attention_with_pair_bias(msa_act, msa_mask, pair_act, index) | |||
msa_act = msa_act + self.attn_mod(msa_act, msa_mask, index) | |||
msa_act = msa_act + self.msa_transition(msa_act, index) | |||
pair_act = pair_act + self.outer_product_mean(msa_act, msa_mask, index) | |||
pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.pair_transition(pair_act, index) | |||
return msa_act.astype(mstype.float16), pair_act.astype(mstype.float16) | |||
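

def _example_evoformer_wiring():
    # Editor's sketch of the residual wiring in EvoformerIteration.construct:
    # each sub-module contributes an additive update to its stream, and
    # OuterProductMean is where MSA information crosses into the pair stack.
    def zero_update(*streams):
        return np.zeros_like(streams[-1])         # stand-in sub-module
    msa = np.ones((4, 6, 8))                      # (num_seq, num_res, msa_channels)
    pair = np.ones((6, 6, 16))                    # (num_res, num_res, pair_channels)
    msa = msa + zero_update(pair, msa)            # e.g. row attention with pair bias
    pair = pair + zero_update(msa, pair)          # e.g. outer product mean
    return msa.shape, pair.shape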


class TemplatePairStack(nn.Cell):
'''template pair stack''' | |||
def __init__(self, config, global_config=None): | |||
super(TemplatePairStack, self).__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.num_block = self.config.num_block | |||
# self.seq_length = global_config.step_size | |||
self.triangle_attention_starting_node = \ | |||
TriangleAttention(self.config.triangle_attention_starting_node, | |||
64, self.num_block, | |||
global_config.template_pair_stack.triangle_attention_starting_node.slice_num) | |||
self.triangle_attention_ending_node = \ | |||
TriangleAttention(self.config.triangle_attention_ending_node, | |||
64, self.num_block, | |||
global_config.template_pair_stack.triangle_attention_ending_node.slice_num) | |||
# Hard Code | |||
self.pair_transition = Transition(self.config.pair_transition, 64, self.num_block, | |||
global_config.template_pair_stack.pair_transition.slice_num) | |||
self.triangle_multiplication_outgoing = TriangleMultiplication(self.config.triangle_multiplication_outgoing, | |||
layer_norm_dim=64, batch_size=self.num_block) | |||
self.triangle_multiplication_incoming = TriangleMultiplication(self.config.triangle_multiplication_incoming, | |||
layer_norm_dim=64, batch_size=self.num_block) | |||
def construct(self, pair_act, pair_mask, index): | |||
if not self.num_block: | |||
return pair_act | |||
pair_act = pair_act + self.triangle_attention_starting_node(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_attention_ending_node(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_multiplication_outgoing(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.triangle_multiplication_incoming(pair_act, pair_mask, index) | |||
pair_act = pair_act + self.pair_transition(pair_act, index) | |||
return pair_act.astype(mstype.float16) | |||


class SingleTemplateEmbedding(nn.Cell):
'''single template embedding''' | |||
def __init__(self, config, global_config=None): | |||
super(SingleTemplateEmbedding, self).__init__() | |||
self.config = config | |||
# self.seq_length = global_config.step_size | |||
self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) | |||
self.embedding2d = nn.Dense(88, self.num_channels).to_float(mstype.float16) | |||
self.template_pair_stack = TemplatePairStack(self.config.template_pair_stack, global_config) | |||
self.num_bins = self.config.dgram_features.num_bins | |||
self.min_bin = self.config.dgram_features.min_bin | |||
self.max_bin = self.config.dgram_features.max_bin | |||
self.one_hot = nn.OneHot(depth=22, axis=-1) | |||
self.n, self.ca, self.c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] | |||
self.use_template_unit_vector = self.config.use_template_unit_vector | |||
layer_norm_dim = 64 | |||
self.output_layer_norm = nn.LayerNorm([layer_norm_dim,], epsilon=1e-5) | |||
self.idx_num_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) | |||
self.idx_batch_loop = Parameter(Tensor(0, mstype.int32), requires_grad=False) | |||
# self.num_block = Tensor(self.template_pair_stack.num_block, mstype.int32) | |||
# self.batch_block = Tensor(4, mstype.int32) | |||
self.num_block = self.template_pair_stack.num_block | |||
self.batch_block = 4 | |||
self._act = Parameter( | |||
Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 64]).astype(np.float16)), | |||
requires_grad=False) | |||
def construct(self, query_embedding, mask_2d, template_aatype, template_all_atom_masks, template_all_atom_positions, | |||
template_pseudo_beta_mask, template_pseudo_beta): | |||
'''construct''' | |||
num_res = template_aatype[0, ...].shape[0] | |||
template_mask_2d_temp = P.Cast()(template_pseudo_beta_mask[:, :, None] * template_pseudo_beta_mask[:, None, :], | |||
query_embedding.dtype) | |||
template_dgram_temp = dgram_from_positions(template_pseudo_beta, self.num_bins, self.min_bin, self.max_bin) | |||
template_dgram_temp = P.Cast()(template_dgram_temp, query_embedding.dtype) | |||
to_concat_temp = (template_dgram_temp, template_mask_2d_temp[:, :, :, None]) | |||
aatype_temp = self.one_hot(template_aatype) | |||
to_concat_temp = to_concat_temp + (mnp.tile(aatype_temp[:, None, :, :], (1, num_res, 1, 1)), | |||
mnp.tile(aatype_temp[:, :, None, :], (1, 1, num_res, 1))) | |||
rot_temp, trans_temp = batch_make_transform_from_reference(template_all_atom_positions[:, :, self.n], | |||
template_all_atom_positions[:, :, self.ca], | |||
template_all_atom_positions[:, :, self.c]) | |||
_, rotation_tmp, translation_tmp = batch_quat_affine( | |||
batch_rot_to_quat(rot_temp, unstack_inputs=True), translation=trans_temp, rotation=rot_temp, | |||
unstack_inputs=True) | |||
points_tmp = mnp.expand_dims(translation_tmp, axis=-2) | |||
affine_vec_tmp = batch_invert_point(points_tmp, rotation_tmp, translation_tmp, extra_dims=1) | |||
inv_distance_scalar_tmp = P.Rsqrt()(1e-6 + mnp.sum(mnp.square(affine_vec_tmp), axis=1)) | |||
template_mask_tmp = (template_all_atom_masks[:, :, self.n] * | |||
template_all_atom_masks[:, :, self.ca] * | |||
template_all_atom_masks[:, :, self.c]) | |||
template_mask_2d_tmp = template_mask_tmp[:, :, None] * template_mask_tmp[:, None, :] | |||
inv_distance_scalar_tmp = inv_distance_scalar_tmp * template_mask_2d_tmp.astype(inv_distance_scalar_tmp.dtype) | |||
unit_vector_tmp = P.Transpose()((affine_vec_tmp * inv_distance_scalar_tmp[:, None, ...]), (0, 2, 3, 1)) | |||
template_mask_2d_tmp = P.Cast()(template_mask_2d_tmp, query_embedding.dtype) | |||
if not self.use_template_unit_vector: | |||
unit_vector_tmp = mnp.zeros_like(unit_vector_tmp) | |||
to_concat_temp = to_concat_temp + (unit_vector_tmp, template_mask_2d_tmp[..., None],) | |||
act_tmp = mnp.concatenate(to_concat_temp, axis=-1) | |||
act_tmp = act_tmp * template_mask_2d_tmp[..., None] | |||
act_tmp = self.embedding2d(act_tmp) | |||
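        # Process templates one at a time through the shared TemplatePairStack.
        # The plain Python ints drive (and unroll) the while loops at graph
        # compile time, while the integer Parameters are fed into ops: the
        # batch index gathers one template's activations and the block index
        # appears to select that iteration's weights inside the stack.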
idx_batch_loop = self.idx_batch_loop | |||
output = [] | |||
idx_batch_loop_int = 0 | |||
self.idx_num_block = 0 | |||
while idx_batch_loop_int < self.batch_block: | |||
self.idx_num_block = 0 | |||
idx_num_block_int = 0 | |||
self._act = P.Gather()(act_tmp, idx_batch_loop, 0) | |||
while idx_num_block_int < self.num_block: | |||
self._act = self.template_pair_stack(self._act, mask_2d, self.idx_num_block) | |||
self.idx_num_block += 1 | |||
idx_num_block_int += 1 | |||
temp_act = P.Reshape()(self._act, ((1,) + P.Shape()(self._act))) | |||
output.append(temp_act) | |||
idx_batch_loop += 1 | |||
idx_batch_loop_int += 1 | |||
act_tmp_loop = P.Concat()(output) | |||
act_tmp = self.output_layer_norm(act_tmp_loop.astype(mstype.float32)) | |||
return act_tmp | |||
class TemplateEmbedding(nn.Cell): | |||
'''template embedding''' | |||
def __init__(self, config, slice_num, global_config=None): | |||
super(TemplateEmbedding, self).__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.num_channels = (self.config.template_pair_stack.triangle_attention_ending_node.value_dim) | |||
self.template_embedder = SingleTemplateEmbedding(self.config, self.global_config) | |||
self.template_pointwise_attention = Attention(self.config.attention, q_data_dim=128, m_data_dim=64, | |||
output_dim=128) | |||
self.slice_num = slice_num | |||
if slice_num == 0: | |||
slice_num = 1 | |||
self._flat_query_slice = Parameter( | |||
Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 1, 128)), | |||
dtype=mstype.float16), requires_grad=False) | |||
self._flat_templates_slice = Parameter( | |||
Tensor(np.zeros((int(global_config.seq_length * global_config.seq_length / slice_num), 4, 64)), | |||
dtype=mstype.float16), requires_grad=False) | |||
def construct(self, query_embedding, template_aatype, template_all_atom_masks, template_all_atom_positions, | |||
template_mask, template_pseudo_beta_mask, template_pseudo_beta, mask_2d): | |||
'''construct''' | |||
num_templates = template_mask.shape[0] | |||
num_channels = self.num_channels | |||
num_res = query_embedding.shape[0] | |||
query_num_channels = query_embedding.shape[-1] | |||
template_mask = P.Cast()(template_mask, query_embedding.dtype) | |||
mask_2d = F.depend(mask_2d, query_embedding) | |||
template_pair_representation = self.template_embedder(query_embedding, mask_2d, template_aatype, | |||
template_all_atom_masks, template_all_atom_positions, | |||
template_pseudo_beta_mask, | |||
template_pseudo_beta) | |||
template_pair_representation = template_pair_representation.astype(mstype.float32) | |||
flat_query = mnp.reshape(query_embedding, [num_res * num_res, 1, query_num_channels]) | |||
flat_templates = mnp.reshape( | |||
mnp.transpose(template_pair_representation.astype(mstype.float16), [1, 2, 0, 3]), | |||
[num_res * num_res, num_templates, num_channels]).astype(mstype.float32) | |||
bias = (1e9 * (template_mask[None, None, None, :] - 1.)) | |||
flat_query, flat_templates, bias = flat_query.astype(mstype.float32), flat_templates.astype( | |||
mstype.float32), bias.astype(mstype.float32) | |||
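        # Optional memory optimisation: when slice_num > 0, the num_res * num_res
        # flattened query is split into slice_num chunks and the pointwise
        # attention over templates is run chunk by chunk, trading extra loop
        # iterations for a smaller peak activation footprint.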
if self.slice_num: | |||
slice_shape = (self.slice_num, -1) | |||
flat_query_shape = P.Shape()(flat_query) | |||
flat_query = P.Reshape()(flat_query, slice_shape + flat_query_shape[1:]).astype(mstype.float16) | |||
flat_templates_shape = P.Shape()(flat_templates) | |||
flat_templates = P.Reshape()(flat_templates, slice_shape + flat_templates_shape[1:]).astype(mstype.float16) | |||
slice_idx = 0 | |||
embedding_tuple = () | |||
while slice_idx < self.slice_num: | |||
self._flat_query_slice = flat_query[slice_idx] | |||
self._flat_templates_slice = flat_templates[slice_idx] | |||
embedding_slice = self.template_pointwise_attention(self._flat_query_slice, self._flat_templates_slice, | |||
bias, index=None, nonbatched_bias=None) | |||
embedding_slice = P.Reshape()(embedding_slice, ((1,) + P.Shape()(embedding_slice))) | |||
embedding_tuple = embedding_tuple + (embedding_slice,) | |||
slice_idx += 1 | |||
embedding = P.Concat()(embedding_tuple) | |||
embedding = embedding.astype(mstype.float32) | |||
embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels]) | |||
# No gradients if no templates. | |||
template_mask = template_mask.astype(embedding.dtype) | |||
embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype) | |||
return embedding | |||
embedding = self.template_pointwise_attention(flat_query, flat_templates, bias, index=None, | |||
nonbatched_bias=None) | |||
embedding = embedding.astype(mstype.float32) | |||
embedding = mnp.reshape(embedding, [num_res, num_res, query_num_channels]) | |||
# No gradients if no templates. | |||
template_mask = template_mask.astype(embedding.dtype) | |||
embedding = embedding * (mnp.sum(template_mask) > 0.).astype(embedding.dtype) | |||
return embedding |
@@ -0,0 +1,235 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""AlphaFold Model""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
import mindspore.numpy as mnp | |||
from mindspore.common.tensor import Tensor | |||
from mindspore import Parameter | |||
from mindspore.ops import functional as F | |||
from commons import residue_constants | |||
from commons.utils import get_chi_atom_indices, pseudo_beta_fn, dgram_from_positions, atom37_to_torsion_angles | |||
from module.evoformer_module import TemplateEmbedding, EvoformerIteration | |||
from module.structure_module import StructureModule, PredictedLDDTHead | |||
class AlphaFold(nn.Cell): | |||
"""AlphaFold Model""" | |||
def __init__(self, config, global_config): | |||
super(AlphaFold, self).__init__() | |||
self.config = config.model.embeddings_and_evoformer | |||
self.preprocess_1d = nn.Dense(22, self.config.msa_channel).to_float(mstype.float16) | |||
self.preprocess_msa = nn.Dense(49, self.config.msa_channel).to_float(mstype.float16) | |||
self.left_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16) | |||
self.right_single = nn.Dense(22, self.config.pair_channel).to_float(mstype.float16) | |||
self.prev_pos_linear = nn.Dense(15, self.config.pair_channel).to_float(mstype.float16) | |||
self.pair_activations = nn.Dense(65, self.config.pair_channel).to_float(mstype.float16) | |||
self.prev_msa_first_row_norm = nn.LayerNorm([256,], epsilon=1e-5) | |||
self.prev_pair_norm = nn.LayerNorm([128,], epsilon=1e-5) | |||
self.one_hot = nn.OneHot(depth=self.config.max_relative_feature * 2 + 1, axis=-1) | |||
self.extra_msa_activations = nn.Dense(25, self.config.extra_msa_channel).to_float(mstype.float16) | |||
self.template_single_embedding = nn.Dense(57, self.config.msa_channel).to_float(mstype.float16) | |||
self.template_projection = nn.Dense(self.config.msa_channel, self.config.msa_channel).to_float(mstype.float16) | |||
self.single_activations = nn.Dense(self.config.msa_channel, self.config.seq_channel).to_float(mstype.float16) | |||
self.relu = nn.ReLU() | |||
self.recycle_pos = self.config.recycle_pos | |||
self.recycle_features = self.config.recycle_features | |||
self.template_enable = self.config.template.enabled | |||
self.max_relative_feature = self.config.max_relative_feature | |||
self.template_enabled = self.config.template.enabled | |||
self.template_embed_torsion_angles = self.config.template.embed_torsion_angles | |||
self.num_bins = self.config.prev_pos.num_bins | |||
self.min_bin = self.config.prev_pos.min_bin | |||
self.max_bin = self.config.prev_pos.max_bin | |||
self.extra_msa_one_hot = nn.OneHot(depth=23, axis=-1) | |||
self.template_aatype_one_hot = nn.OneHot(depth=22, axis=-1) | |||
self.template_embedding = TemplateEmbedding(self.config.template, | |||
global_config.template_embedding.slice_num, | |||
global_config=global_config) | |||
self.extra_msa_stack_iteration = EvoformerIteration(self.config.evoformer, | |||
msa_act_dim=64, | |||
pair_act_dim=128, | |||
is_extra_msa=True, | |||
batch_size=self.config.extra_msa_stack_num_block, | |||
global_config=global_config) | |||
self.evoformer_iteration = EvoformerIteration(self.config.evoformer, | |||
msa_act_dim=256, | |||
pair_act_dim=128, | |||
is_extra_msa=False, | |||
batch_size=self.config.evoformer_num_block, | |||
global_config=global_config) | |||
self.structure_module = StructureModule(config.model.heads.structure_module, | |||
self.config.seq_channel, | |||
self.config.pair_channel, | |||
global_config=global_config) | |||
self.module_lddt = PredictedLDDTHead(config.model.heads.predicted_lddt, | |||
global_config, | |||
self.config.seq_channel) | |||
self._init_tensor(global_config) | |||
def _init_tensor(self, global_config): | |||
"initialization of tensors and parameters" | |||
self.chi_atom_indices = Tensor(get_chi_atom_indices(), mstype.int32) | |||
chi_angles_mask = list(residue_constants.chi_angles_mask) | |||
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) | |||
self.chi_angles_mask = Tensor(chi_angles_mask, mstype.float32) | |||
self.mirror_psi_mask = Tensor(np.asarray([1., 1., -1., 1., 1., 1., 1.])[None, None, :, None], mstype.float32) | |||
self.chi_pi_periodic = Tensor(residue_constants.chi_pi_periodic, mstype.float32) | |||
indices0 = np.arange(4).reshape((-1, 1, 1, 1, 1)).astype("int64") # 4 batch | |||
indices0 = indices0.repeat(global_config.seq_length, axis=1) # seq_length sequence length | |||
indices0 = indices0.repeat(4, axis=2) # 4 chis | |||
self.indices0 = Tensor(indices0.repeat(4, axis=3)) # 4 atoms | |||
indices1 = np.arange(global_config.seq_length).reshape((1, -1, 1, 1, 1)).astype("int64") | |||
indices1 = indices1.repeat(4, axis=0) | |||
indices1 = indices1.repeat(4, axis=2) | |||
self.indices1 = Tensor(indices1.repeat(4, axis=3)) | |||
self.idx_extra_msa_stack = Parameter(Tensor(0, mstype.int32), requires_grad=False) | |||
self.extra_msa_stack_num_block = self.config.extra_msa_stack_num_block | |||
self.idx_evoformer_block = Parameter(Tensor(0, mstype.int32), requires_grad=False) | |||
self.evoformer_num_block = Tensor(self.config.evoformer_num_block, mstype.int32) | |||
def construct(self, target_feat, msa_feat, msa_mask, seq_mask, aatype, | |||
template_aatype, template_all_atom_masks, template_all_atom_positions, | |||
template_mask, template_pseudo_beta_mask, template_pseudo_beta, | |||
_, extra_msa, extra_has_deletion, | |||
extra_deletion_value, extra_msa_mask, | |||
atom14_atom_exists, atom37_atom_exists, residue_index, | |||
prev_pos, prev_msa_first_row, prev_pair): | |||
"""construct""" | |||
preprocess_1d = self.preprocess_1d(target_feat) | |||
preprocess_msa = self.preprocess_msa(msa_feat) | |||
msa_activations1 = mnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa | |||
left_single = self.left_single(target_feat) | |||
right_single = self.right_single(target_feat) | |||
pair_activations = left_single[:, None] + right_single[None] | |||
mask_2d = seq_mask[:, None] * seq_mask[None, :] | |||
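        # Recycling: feed back the previous iteration's outputs, as in
        # AlphaFold. prev_pos re-enters as a pseudo-beta distogram added to the
        # pair activations; prev_msa_first_row and prev_pair are layer-normed
        # and added to the first MSA row and the pair activations respectively.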
if self.recycle_pos: | |||
prev_pseudo_beta = pseudo_beta_fn(aatype, prev_pos, None) | |||
dgram = dgram_from_positions(prev_pseudo_beta, self.num_bins, self.min_bin, self.max_bin) | |||
pair_activations += self.prev_pos_linear(dgram) | |||
# return pair_activations, msa_activations1 | |||
prev_msa_first_row = F.depend(prev_msa_first_row, pair_activations) | |||
if self.recycle_features: | |||
prev_msa_first_row = self.prev_msa_first_row_norm(prev_msa_first_row) | |||
msa_activations1 = mnp.concatenate( | |||
(mnp.expand_dims(prev_msa_first_row + msa_activations1[0, ...], 0), | |||
msa_activations1[1:, ...]), 0) | |||
pair_activations += self.prev_pair_norm(prev_pair.astype(mstype.float32)) | |||
if self.max_relative_feature: | |||
offset = residue_index[:, None] - residue_index[None, :] | |||
rel_pos = self.one_hot(mnp.clip(offset + self.max_relative_feature, 0, 2 * self.max_relative_feature)) | |||
pair_activations += self.pair_activations(rel_pos) | |||
template_pair_representation = 0 | |||
if self.template_enable: | |||
template_pair_representation = self.template_embedding(pair_activations, template_aatype, | |||
template_all_atom_masks, template_all_atom_positions, | |||
template_mask, template_pseudo_beta_mask, | |||
template_pseudo_beta, mask_2d) | |||
pair_activations += template_pair_representation | |||
msa_1hot = self.extra_msa_one_hot(extra_msa) | |||
extra_msa_feat = mnp.concatenate((msa_1hot, extra_has_deletion[..., None], extra_deletion_value[..., None]), | |||
axis=-1) | |||
extra_msa_activations = self.extra_msa_activations(extra_msa_feat) | |||
msa_act = extra_msa_activations | |||
pair_act = pair_activations | |||
msa_act = msa_act.astype(mstype.float32) | |||
pair_act = pair_act.astype(mstype.float32) | |||
extra_msa_mask = extra_msa_mask.astype(mstype.float32) | |||
mask_2d = mask_2d.astype(mstype.float32) | |||
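        # Extra-MSA stack: the same EvoformerIteration cell is invoked
        # extra_msa_stack_num_block times; the integer Parameter passed as the
        # last argument appears to index that iteration's weights within the
        # batched parameters (batch_size=num_block in the constructor).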
self.idx_extra_msa_stack = 0 | |||
idx_extra_msa_stack_int = 0 | |||
while idx_extra_msa_stack_int < self.extra_msa_stack_num_block: | |||
msa_act, pair_act = \ | |||
self.extra_msa_stack_iteration(msa_act, pair_act, extra_msa_mask, mask_2d, self.idx_extra_msa_stack) | |||
self.idx_extra_msa_stack += 1 | |||
idx_extra_msa_stack_int += 1 | |||
msa_act = F.depend(msa_act, self.idx_extra_msa_stack) | |||
pair_act = F.depend(pair_act, self.idx_extra_msa_stack) | |||
msa_activations2 = None | |||
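        # Template torsion-angle features: when enabled, per-template torsion
        # angles (plus their alternative chi angles and mask) are embedded and
        # appended as extra MSA rows, matching the "embed torsion angles" path
        # of AlphaFold's template embedding.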
if self.template_enabled and self.template_embed_torsion_angles: | |||
num_templ, num_res = template_aatype.shape | |||
aatype_one_hot = self.template_aatype_one_hot(template_aatype) | |||
torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask = atom37_to_torsion_angles( | |||
template_aatype, template_all_atom_positions, template_all_atom_masks, self.chi_atom_indices, | |||
self.chi_angles_mask, self.mirror_psi_mask, self.chi_pi_periodic, self.indices0, self.indices1) | |||
template_features = mnp.concatenate([aatype_one_hot, | |||
mnp.reshape(torsion_angles_sin_cos, [num_templ, num_res, 14]), | |||
mnp.reshape(alt_torsion_angles_sin_cos, [num_templ, num_res, 14]), | |||
torsion_angles_mask], axis=-1) | |||
template_activations = self.template_single_embedding(template_features) | |||
template_activations = self.relu(template_activations.astype(mstype.float32)) | |||
template_activations = self.template_projection(template_activations) | |||
msa_activations2 = mnp.concatenate([msa_activations1, template_activations], axis=0) | |||
torsion_angle_mask = torsion_angles_mask[:, :, 2] | |||
torsion_angle_mask = torsion_angle_mask.astype(msa_mask.dtype) | |||
msa_mask = mnp.concatenate([msa_mask, torsion_angle_mask], axis=0) | |||
msa_activations2 = msa_activations2.astype(mstype.float16) | |||
pair_activations = pair_act.astype(mstype.float16) | |||
msa_mask = msa_mask.astype(mstype.float16) | |||
mask_2d = mask_2d.astype(mstype.float16) | |||
# return msa_activations2, pair_activations, msa_mask, mask_2d | |||
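        # Main Evoformer trunk. Multiplying the index Parameter by zero resets
        # it without rebinding it to a Python int, which appears to keep the
        # while condition a Tensor comparison (a graph-mode loop); each
        # iteration again selects its own weights via the index.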
self.idx_evoformer_block = self.idx_evoformer_block * 0 | |||
while self.idx_evoformer_block < self.evoformer_num_block: | |||
msa_activations2, pair_activations = \ | |||
self.evoformer_iteration(msa_activations2, | |||
pair_activations, | |||
msa_mask, | |||
mask_2d, | |||
self.idx_evoformer_block) | |||
self.idx_evoformer_block += 1 | |||
single_activations = self.single_activations(msa_activations2[0]) | |||
msa_first_row = msa_activations2[0] | |||
# return single_activations, msa, msa_first_row | |||
final_atom_positions, final_atom_mask, rp_structure_module = \ | |||
self.structure_module(single_activations, | |||
pair_activations, | |||
seq_mask, | |||
aatype, | |||
atom14_atom_exists, | |||
atom37_atom_exists) | |||
predicted_lddt_logits = self.module_lddt(rp_structure_module) | |||
prev_pos = final_atom_positions.astype(mstype.float16) | |||
prev_msa_first_row = msa_first_row.astype(mstype.float16) | |||
prev_pair = pair_activations.astype(mstype.float16) | |||
final_atom_positions = final_atom_positions.astype(mstype.float16) | |||
final_atom_mask = final_atom_mask.astype(mstype.float16) | |||
return final_atom_positions, final_atom_mask, predicted_lddt_logits, prev_pos, prev_msa_first_row, prev_pair |
@@ -0,0 +1,443 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""structure module""" | |||
import numpy as np | |||
import mindspore.ops as ops | |||
import mindspore.common.dtype as mstype | |||
import mindspore.numpy as mnp | |||
from mindspore import Parameter, ms_function, Tensor | |||
from mindspore import nn | |||
from commons import residue_constants | |||
from commons.utils import generate_new_affine, to_tensor, from_tensor, vecs_to_tensor, atom14_to_atom37, \ | |||
get_exp_atom_pos, get_exp_frames, pre_compose, scale_translation, to_tensor_new, l2_normalize, \ | |||
torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, apply_to_point, _invert_point | |||
class InvariantPointAttention(nn.Cell): | |||
"""Invariant Point attention module.""" | |||
def __init__(self, config, global_config, pair_dim): | |||
"""Initialize. | |||
Args: | |||
config: Structure Module Config | |||
global_config: Global Config of Model. | |||
pair_dim: pair representation dimension. | |||
""" | |||
super().__init__() | |||
self._dist_epsilon = 1e-8 | |||
self.config = config | |||
self.num_head = config.num_head | |||
self.num_scalar_qk = config.num_scalar_qk | |||
self.num_scalar_v = config.num_scalar_v | |||
self.num_point_v = config.num_point_v | |||
self.num_point_qk = config.num_point_qk | |||
self.num_channel = config.num_channel | |||
self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 +\ | |||
self.num_head * pair_dim | |||
self.global_config = global_config | |||
self.q_scalar = nn.Dense(config.num_channel, self.num_head*self.num_scalar_qk).to_float(mstype.float16) | |||
self.kv_scalar = nn.Dense(config.num_channel, self.num_head*(self.num_scalar_qk + self.num_scalar_v) | |||
).to_float(mstype.float16) | |||
self.q_point_local = nn.Dense(config.num_channel, self.num_head * 3 * self.num_point_qk | |||
).to_float(mstype.float16) | |||
self.kv_point_local = nn.Dense(config.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v) | |||
).to_float(mstype.float16) | |||
self.soft_max = nn.Softmax() | |||
self.soft_plus = ops.Softplus() | |||
self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights") | |||
self.attention_2d = nn.Dense(pair_dim, self.num_head).to_float(mstype.float16) | |||
self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros' | |||
).to_float(mstype.float16) | |||
self.scalar_weights = np.sqrt(1.0 / (3 * 16)) | |||
self.point_weights = np.sqrt(1.0 / (3 * 18)) | |||
self.attention_2d_weights = np.sqrt(1.0 / 3) | |||
def construct(self, inputs_1d, inputs_2d, mask, rotation, translation): | |||
"""Compute geometry-aware attention. | |||
Args: | |||
inputs_1d: (N, C) 1D input embedding that is the basis for the | |||
scalar queries. | |||
inputs_2d: (N, M, C') 2D input embedding, used for biases and values. | |||
mask: (N, 1) mask to indicate which elements of inputs_1d participate | |||
in the attention. | |||
rotation: describe the orientation of every element in inputs_1d | |||
translation: describe the position of every element in inputs_1d | |||
Returns: | |||
Transformation of the input embedding. | |||
""" | |||
num_residues, _ = inputs_1d.shape | |||
# Improve readability by removing a large number of 'self's. | |||
num_head = self.num_head | |||
num_scalar_qk = self.num_scalar_qk | |||
num_point_qk = self.num_point_qk | |||
num_scalar_v = self.num_scalar_v | |||
num_point_v = self.num_point_v | |||
        # Construct scalar queries of shape:
        # [num_query_residues, num_head, num_points]
q_scalar = self.q_scalar(inputs_1d) | |||
q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk]) | |||
# Construct scalar keys/values of shape: | |||
# [num_target_residues, num_head, num_points] | |||
kv_scalar = self.kv_scalar(inputs_1d) | |||
kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk]) | |||
k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1) | |||
# Construct query points of shape: | |||
# [num_residues, num_head, num_point_qk] | |||
# First construct query points in local frame. | |||
q_point_local = self.q_point_local(inputs_1d) | |||
q_point_local = mnp.stack(mnp.split(q_point_local, 3, axis=-1), axis=0) | |||
# Project query points into global frame. | |||
q_point_global = apply_to_point(rotation, translation, q_point_local) | |||
# Reshape query point for later use. | |||
q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk)) | |||
q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk)) | |||
q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk)) | |||
# Construct key and value points. | |||
# Key points have shape [num_residues, num_head, num_point_qk] | |||
# Value points have shape [num_residues, num_head, num_point_v] | |||
# Construct key and value points in local frame. | |||
kv_point_local = self.kv_point_local(inputs_1d) | |||
kv_point_local = mnp.split(kv_point_local, 3, axis=-1) | |||
# Project key and value points into global frame. | |||
kv_point_global = apply_to_point(rotation, translation, kv_point_local) | |||
kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v))) | |||
kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v))) | |||
kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v))) | |||
# Split key and value points. | |||
k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk,], axis=-1) | |||
k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk,], axis=-1) | |||
k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk,], axis=-1) | |||
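        # Softplus keeps the learned per-head point weights positive; combined
        # with the sqrt(1/(3*18)) constant this follows the IPA weighting that
        # balances the scalar, point, and pair logit terms at initialization.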
trainable_point_weights = self.soft_plus(self.trainable_point_weights) | |||
point_weights = self.point_weights * mnp.expand_dims(trainable_point_weights, axis=1) | |||
v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)] | |||
q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)] | |||
k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)] | |||
dist2 = mnp.square(q_point[0][:, :, None, :] - k_point[0][:, None, :, :]) + \ | |||
mnp.square(q_point[1][:, :, None, :] - k_point[1][:, None, :, :]) + \ | |||
mnp.square(q_point[2][:, :, None, :] - k_point[2][:, None, :, :]) | |||
attn_qk_point = -0.5 * mnp.sum( | |||
point_weights[:, None, None, :] * dist2, axis=-1) | |||
v = mnp.swapaxes(v_scalar, -2, -3) | |||
q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3) | |||
k = mnp.swapaxes(k_scalar, -2, -3) | |||
attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1)) | |||
attn_logits = attn_qk_scalar + attn_qk_point | |||
attention_2d = self.attention_2d(inputs_2d) | |||
attention_2d = mnp.transpose(attention_2d, [2, 0, 1]) | |||
attention_2d = self.attention_2d_weights * attention_2d | |||
attn_logits += attention_2d | |||
mask_2d = mask * mnp.swapaxes(mask, -1, -2) | |||
attn_logits -= 1e5 * (1. - mask_2d) | |||
# [num_head, num_query_residues, num_target_residues] | |||
attn = self.soft_max(attn_logits) | |||
        # [num_head, num_query_residues, num_scalar_v]
result_scalar = ops.matmul(attn, v) | |||
result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3), | |||
mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3), | |||
mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3) | |||
] | |||
result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]), | |||
mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]), | |||
mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])] | |||
result_scalar = mnp.swapaxes(result_scalar, -2, -3) | |||
result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v]) | |||
result_point_local = _invert_point(result_point_global, rotation, translation) | |||
output_feature1 = result_scalar | |||
output_feature20 = result_point_local[0] | |||
output_feature21 = result_point_local[1] | |||
output_feature22 = result_point_local[2] | |||
output_feature3 = mnp.sqrt(self._dist_epsilon + | |||
mnp.square(result_point_local[0]) + | |||
mnp.square(result_point_local[1]) + | |||
mnp.square(result_point_local[2])) | |||
result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d) | |||
num_out = num_head * result_attention_over_2d.shape[-1] | |||
output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out]) | |||
final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21, | |||
output_feature22, output_feature3, output_feature4], axis=-1) | |||
final_result = self.output_projection(final_act) | |||
return final_result | |||
class MultiRigidSidechain(nn.Cell): | |||
"""Class to make side chain atoms.""" | |||
def __init__(self, config, global_config, single_repr_dim): | |||
super().__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.input_projection = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.input_projection_1 = nn.Dense(single_repr_dim, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.relu = nn.ReLU() | |||
self.resblock1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal').to_float(mstype.float16) | |||
self.resblock2 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros').to_float(mstype.float16) | |||
self.resblock1_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.resblock2_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='zeros' | |||
).to_float(mstype.float16) | |||
self.unnormalized_angles = nn.Dense(config.num_channel, 14, weight_init='normal').to_float(mstype.float16) | |||
self.print = ops.Print() | |||
self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group) | |||
self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions) | |||
self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask) | |||
self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame) | |||
def construct(self, rotation, translation, act, initial_act, aatype): | |||
"""Predict side chains using rotation and translation representations. | |||
Args: | |||
rotation: The rotation matrices. | |||
            translation: The translation vectors.
            act: updated single-representation activations from the structure module
initial_act: initial act representations (input of structure module) | |||
aatype: Amino acid type representations | |||
Returns: | |||
angles, positions and new frames | |||
""" | |||
act1 = self.input_projection(self.relu(act.astype(mstype.float32))) | |||
init_act1 = self.input_projection_1(self.relu(initial_act.astype(mstype.float32))) | |||
# Sum the activation list (equivalent to concat then Linear). | |||
act = act1 + init_act1 | |||
# Mapping with some residual blocks. | |||
# for _ in range(self.config.num_residual_block): | |||
# resblock1 | |||
old_act = act | |||
act = self.resblock1(self.relu(act.astype(mstype.float32))) | |||
act = self.resblock2(self.relu(act.astype(mstype.float32))) | |||
act += old_act | |||
# resblock2 | |||
old_act = act | |||
act = self.resblock1_1(self.relu(act.astype(mstype.float32))) | |||
act = self.resblock2_1(self.relu(act.astype(mstype.float32))) | |||
act += old_act | |||
# Map activations to torsion angles. Shape: (num_res, 14). | |||
num_res = act.shape[0] | |||
unnormalized_angles = self.unnormalized_angles(self.relu(act.astype(mstype.float32))) | |||
unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2]) | |||
angles = l2_normalize(unnormalized_angles, axis=-1) | |||
backb_to_global = [rotation[0][0], rotation[0][1], rotation[0][2], | |||
rotation[1][0], rotation[1][1], rotation[1][2], | |||
rotation[2][0], rotation[2][1], rotation[2][2], | |||
translation[0], translation[1], translation[2]] | |||
all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles, | |||
self.restype_rigid_group_default_frame) | |||
pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, | |||
self.restype_atom14_to_rigid_group, | |||
self.restype_atom14_rigid_group_positions, | |||
self.restype_atom14_mask) | |||
atom_pos = pred_positions | |||
frames = all_frames_to_global | |||
return angles, unnormalized_angles, atom_pos, frames | |||
class FoldIteration(nn.Cell): | |||
"""A single iteration of the main structure module loop.""" | |||
def __init__(self, config, global_config, pair_dim, single_repr_dim): | |||
super().__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.drop_out = nn.Dropout(keep_prob=0.9) | |||
self.attention_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5) | |||
self.transition_layer_norm = nn.LayerNorm([config.num_channel,], epsilon=1e-5) | |||
self.transition = nn.Dense(config.num_channel, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.transition_1 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.transition_2 = nn.Dense(config.num_channel, config.num_channel, weight_init='normal' | |||
).to_float(mstype.float16) | |||
self.relu = nn.ReLU() | |||
self.affine_update = nn.Dense(config.num_channel, 6, weight_init='zeros').to_float(mstype.float16) | |||
self.attention_module = InvariantPointAttention(self.config, self.global_config, pair_dim) | |||
self.mu_side_chain = MultiRigidSidechain(config.sidechain, global_config, single_repr_dim) | |||
self.print = ops.Print() | |||
def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype): | |||
        '''construct'''
# Attention | |||
attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation) | |||
act += attn | |||
act = self.drop_out(act) | |||
act = self.attention_layer_norm(act.astype(mstype.float32)) | |||
# Transition | |||
input_act = act | |||
act = self.transition(act) | |||
act = self.relu(act.astype(mstype.float32)) | |||
act = self.transition_1(act) | |||
act = self.relu(act.astype(mstype.float32)) | |||
act = self.transition_2(act) | |||
act += input_act | |||
act = self.drop_out(act) | |||
act = self.transition_layer_norm(act.astype(mstype.float32)) | |||
# This block corresponds to | |||
# Jumper et al. (2021) Alg. 23 "Backbone update" | |||
# Affine update | |||
affine_update = self.affine_update(act) | |||
quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update) | |||
_, rotation1, translation1 = scale_translation(quaternion, translation, rotation, 10.0) | |||
angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames =\ | |||
self.mu_side_chain(rotation1, translation1, act, initial_act, aatype) | |||
affine_output = to_tensor_new(quaternion, translation) | |||
return act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ | |||
atom_pos, frames | |||
class StructureModule(nn.Cell): | |||
"""StructureModule as a network head.""" | |||
def __init__(self, config, single_repr_dim, pair_dim, global_config=None, compute_loss=True): | |||
super(StructureModule, self).__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.compute_loss = compute_loss | |||
self.fold_iteration = FoldIteration(self.config, global_config, pair_dim, single_repr_dim) | |||
self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5) | |||
self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel).to_float(mstype.float16) | |||
self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5) | |||
self.num_layer = config.num_layer | |||
self.indice0 = Tensor( | |||
np.arange(global_config.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32")) | |||
@ms_function | |||
def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None): | |||
'''construct''' | |||
sequence_mask = seq_mask[:, None] | |||
act = self.single_layer_norm(single.astype(mstype.float32)) | |||
initial_act = act | |||
act = self.initial_projection(act) | |||
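        # Presumably the identity ("black-hole") initialization from AlphaFold:
        # every residue frame starts as the identity rotation at the origin.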
quaternion, rotation, translation = generate_new_affine(sequence_mask) | |||
aff_to_tensor = to_tensor(quaternion, mnp.transpose(translation)) | |||
act_2d = self.pair_layer_norm(pair.astype(mstype.float32)) | |||
        # fold iteration loop
quaternion, rotation, translation = from_tensor(aff_to_tensor) | |||
act_new, atom_pos, _, _, _, _ =\ | |||
self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) | |||
atom14_pred_positions = vecs_to_tensor(atom_pos)[-1] | |||
atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, | |||
residx_atom37_to_atom14, | |||
atom37_atom_exists, | |||
self.indice0) | |||
final_atom_positions = atom37_pred_positions | |||
final_atom_mask = atom37_atom_exists | |||
rp_structure_module = act_new | |||
return final_atom_positions, final_atom_mask, rp_structure_module | |||
def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, | |||
aatype): | |||
'''iteration operation''' | |||
affine_init = () | |||
angles_sin_cos_init = () | |||
um_angles_sin_cos_init = () | |||
atom_pos = () | |||
frames = () | |||
for _ in range(self.num_layer): | |||
act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \ | |||
atom_pos, frames = \ | |||
self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype) | |||
affine_init = affine_init + (affine_output[None, ...],) | |||
angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],) | |||
um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],) | |||
atom_pos = get_exp_atom_pos(atom_pos) | |||
frames = get_exp_frames(frames) | |||
affine_output_new = mnp.concatenate(affine_init, axis=0) | |||
angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0) | |||
um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0) | |||
return act, atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames | |||
class PredictedLDDTHead(nn.Cell): | |||
"""Head to predict the per-residue LDDT to be used as a confidence measure.""" | |||
def __init__(self, config, global_config, seq_channel): | |||
super().__init__() | |||
self.config = config | |||
self.global_config = global_config | |||
self.input_layer_norm = nn.LayerNorm([seq_channel,], epsilon=1e-5) | |||
self.act_0 = nn.Dense(seq_channel, self.config.num_channels, weight_init='zeros' | |||
).to_float(mstype.float16) | |||
self.act_1 = nn.Dense(self.config.num_channels, self.config.num_channels, weight_init='zeros' | |||
).to_float(mstype.float16) | |||
self.logits = nn.Dense(self.config.num_channels, self.config.num_bins, weight_init='zeros' | |||
).to_float(mstype.float16) | |||
self.relu = nn.ReLU() | |||
def construct(self, rp_structure_module): | |||
"""Builds ExperimentallyResolvedHead module.""" | |||
act = rp_structure_module | |||
act = self.input_layer_norm(act.astype(mstype.float32)) | |||
act = self.act_0(act) | |||
act = self.relu(act.astype(mstype.float32)) | |||
act = self.act_1(act) | |||
act = self.relu(act.astype(mstype.float32)) | |||
logits = self.logits(act) | |||
return logits |
@@ -0,0 +1,5 @@ | |||
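# Install with: pip install -r requirements.txt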
absl-py==0.13.0 | |||
biopython==1.79
ml-collections==0.1.0 | |||
scipy==1.7.0 | |||
numpy==1.19.5 |
@@ -0,0 +1,34 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""config for serving mode""" | |||
import ml_collections | |||
config = ml_collections.ConfigDict({ | |||
"seq_length": 256, | |||
"device_id": 0, | |||
"port": 5500, | |||
"ckpt_path": "/CHECKPOINT_PATH", | |||
"input_fasta_path": "INPUT_FASTA_PATH", | |||
"msa_result_path": "MSA_RESULT_PATH", | |||
"database_dir": "DATABASE_DIR", | |||
"database_envdb_dir": "DATABASE_ENVDB_DIR", | |||
"hhsearch_binary_path": "HHSEARCH_BINARY_PATH", | |||
"pdb70_database_path": 'PDB&)_DATABASE_PATH', | |||
"template_mmcif_dir": 'TEMPLATE_MMCIF_DIR', | |||
"max_template_date": "MAX_TEMPLATE_DATE", | |||
"kalign_binary_path": 'KALIGN_BINARY_PATH', | |||
"obsolete_pdbs_path": 'OBSOLETE_PDBS_PATH', | |||
}) |
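# A hypothetical sketch of how the placeholders above would be filled in
# before serving (paths are illustrative, not shipped defaults):
#   config.ckpt_path = "/data/ckpt/alphafold_model_1.ckpt"
#   config.input_fasta_path = "/data/fasta/"
#   config.msa_result_path = "/data/msa_result/"
#   config.database_dir = "/data/uniref30/"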
@@ -0,0 +1,104 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""serving config for mindspore serving""" | |||
import time | |||
import os | |||
import json | |||
import numpy as np | |||
from mindspore_serving.server import register | |||
import mindspore.context as context | |||
from mindspore.common.tensor import Tensor | |||
from mindspore import load_checkpoint | |||
from data.feature.feature_extraction import process_features | |||
from data.tools.data_process import data_process | |||
from commons.utils import compute_confidence | |||
from commons.generate_pdb import to_pdb, from_prediction | |||
from model import AlphaFold | |||
from config import config, global_config | |||
from fold_service.config import config as serving_config | |||
context.set_context(mode=context.GRAPH_MODE, | |||
device_target="Ascend", | |||
variable_memory_max_size="31GB", | |||
device_id=serving_config.device_id, | |||
save_graphs=False) | |||
model_name = "model_1" | |||
model_config = config.model_config(model_name) | |||
num_recycle = model_config.model.num_recycle | |||
global_config = global_config.global_config(serving_config.seq_length) | |||
extra_msa_length = global_config.extra_msa_length | |||
fold_net = AlphaFold(model_config, global_config) | |||
load_checkpoint(serving_config.ckpt_path, fold_net) | |||
def fold_model(input_fasta_path): | |||
"""defining fold model""" | |||
seq_files = os.listdir(input_fasta_path) | |||
for seq_file in seq_files: | |||
print(seq_file) | |||
t1 = time.time() | |||
seq_name = seq_file.split('.')[0] | |||
input_features = data_process(seq_name, serving_config) | |||
tensors, aatype, residue_index, ori_res_length = process_features( | |||
raw_features=input_features, config=model_config, global_config=global_config) | |||
prev_pos = Tensor(np.zeros([global_config.seq_length, 37, 3]).astype(np.float16)) | |||
prev_msa_first_row = Tensor(np.zeros([global_config.seq_length, 256]).astype(np.float16)) | |||
prev_pair = Tensor(np.zeros([global_config.seq_length, global_config.seq_length, 128]).astype(np.float16)) | |||
t2 = time.time() | |||
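        # Recycling loop: run num_recycle + 1 forward passes, feeding each
        # pass's prev_pos / prev_msa_first_row / prev_pair back into the next.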
for i in range(num_recycle+1): | |||
tensors_i = [tensor[i] for tensor in tensors] | |||
input_feats = [Tensor(tensor) for tensor in tensors_i] | |||
final_atom_positions, final_atom_mask, predicted_lddt_logits,\ | |||
prev_pos, prev_msa_first_row, prev_pair = fold_net(*input_feats, | |||
prev_pos, | |||
prev_msa_first_row, | |||
prev_pair) | |||
t3 = time.time() | |||
final_atom_positions = final_atom_positions.asnumpy()[:ori_res_length] | |||
final_atom_mask = final_atom_mask.asnumpy()[:ori_res_length] | |||
predicted_lddt_logits = predicted_lddt_logits.asnumpy()[:ori_res_length] | |||
confidence = compute_confidence(predicted_lddt_logits) | |||
unrelaxed_protein = from_prediction(final_atom_mask, aatype[0], final_atom_positions, residue_index[0]) | |||
pdb_file = to_pdb(unrelaxed_protein) | |||
seq_length = aatype.shape[-1] | |||
os.makedirs(f'./result/seq_{seq_name}_{seq_length}', exist_ok=True) | |||
with open(os.path.join(f'./result/seq_{seq_name}_{seq_length}/', f'unrelaxed_model_{seq_name}.pdb'), 'w') as f: | |||
f.write(pdb_file) | |||
t4 = time.time() | |||
timings = {"pre_process_time": round(t2 - t1, 2), | |||
"model_time": round(t3 - t2, 2), | |||
"pos_process_time": round(t4 - t3, 2), | |||
"all_time": round(t4 - t1, 2), | |||
"confidence": confidence} | |||
print(timings) | |||
with open(f'./result/seq_{seq_name}_{seq_length}/timings', 'w') as f: | |||
f.write(json.dumps(timings)) | |||
return True | |||
@register.register_method(output_names=["res"]) | |||
def folding(input_fasta_path): | |||
res = register.add_stage(fold_model, input_fasta_path, outputs_count=1) | |||
return res |
@@ -0,0 +1,32 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""client script for serving mode""" | |||
import time | |||
from mindspore_serving.client import Client | |||
from fold_service.config import config as serving_config | |||
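# The gRPC server (see serving_server.py) must already be running on
# serving_config.port before this client is launched.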
if __name__ == "__main__": | |||
client = Client("127.0.0.1:" + str(serving_config.port), "fold_service", "folding") | |||
instances = [{"input_fasta_path": serving_config.input_fasta_path}] | |||
print("inferring...") | |||
t1 = time.time() | |||
result = client.infer(instances) | |||
t2 = time.time() | |||
print("finish inferring! Time costed:", t2 - t1) | |||
print(result) |
@@ -0,0 +1,31 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""server script for serving mode""" | |||
import os | |||
from mindspore_serving import server | |||
from fold_service.config import config as serving_config | |||
def start(): | |||
servable_dir = os.path.dirname(os.path.realpath(__file__)) | |||
servable_config = server.ServableStartConfig(servable_directory=servable_dir, servable_name="fold_service", | |||
device_ids=serving_config.device_id) | |||
server.start_servables(servable_configs=servable_config) | |||
server.start_grpc_server(address="127.0.0.1:" + str(serving_config.port)) | |||
if __name__ == "__main__": | |||
start() |
@@ -0,0 +1,14 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ |
@@ -0,0 +1,15 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""init""" |
@@ -0,0 +1,85 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" test Activations """ | |||
import pytest | |||
import numpy as np | |||
from mindspore import nn | |||
from mindspore import Tensor | |||
from mindspore import context | |||
from mindelec.architecture import get_activation | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Net(nn.Cell): | |||
def __init__(self): | |||
super(Net, self).__init__() | |||
self.srelu = get_activation("srelu") | |||
def construct(self, x): | |||
return self.srelu(x) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_srelu(): | |||
"""test srelu activation""" | |||
net = Net() | |||
input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32)) | |||
output = net(input_tensor) | |||
print(input_tensor.asnumpy()) | |||
print(output.asnumpy()) | |||
class Net1(nn.Cell): | |||
"""net""" | |||
def __init__(self): | |||
super(Net1, self).__init__() | |||
self.sin = get_activation("sin") | |||
def construct(self, x): | |||
return self.sin(x) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_sin(): | |||
"""test sin activation""" | |||
net = Net1() | |||
input_tensor = Tensor(np.array([[1.2, 0.1], [0.2, 3.2]], dtype=np.float32)) | |||
output = net(input_tensor) | |||
print(input_tensor.asnumpy()) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_activation_type_error(): | |||
with pytest.raises(TypeError): | |||
get_activation(1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_get_activation(): | |||
activation = get_activation("softshrink") | |||
assert isinstance(activation, nn.Cell) |
@@ -0,0 +1,224 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" test block """ | |||
import pytest | |||
import numpy as np | |||
from mindspore import nn, context | |||
from mindspore import Tensor, Parameter | |||
from mindelec.architecture import LinearBlock, ResBlock | |||
from mindelec.architecture import InputScaleNet, FCSequential, MultiScaleFCCell | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class Net(nn.Cell): | |||
""" Net definition """ | |||
def __init__(self, | |||
input_channels, | |||
output_channels, | |||
weight='normal', | |||
bias='zeros', | |||
has_bias=True): | |||
super(Net, self).__init__() | |||
self.fc = LinearBlock(input_channels, output_channels, weight, bias, has_bias) | |||
def construct(self, input_x): | |||
return self.fc(input_x) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_linear(): | |||
"""test linear block""" | |||
weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32)) | |||
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32)) | |||
net = Net(64, 8, weight=weight, bias=bias) | |||
input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_linear_nobias(): | |||
"""test linear block with no bias""" | |||
weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32)) | |||
net = Net(64, 8, weight=weight, has_bias=False) | |||
input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
class Net1(nn.Cell): | |||
""" Net definition """ | |||
def __init__(self, | |||
input_channels, | |||
output_channels, | |||
weight='normal', | |||
bias='zeros', | |||
has_bias=True, | |||
activation=None): | |||
super(Net1, self).__init__() | |||
self.fc = ResBlock(input_channels, output_channels, weight, bias, has_bias, activation) | |||
def construct(self, input_x): | |||
return self.fc(input_x) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_res(): | |||
"""test res block""" | |||
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32)) | |||
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32)) | |||
net = Net1(8, 8, weight=weight, bias=bias) | |||
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_res_nobias(): | |||
"""test res block with no bias""" | |||
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32)) | |||
net = Net1(8, 8, weight=weight, has_bias=False) | |||
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_res_activation(): | |||
"""test res block with activation""" | |||
weight = Tensor(np.random.randint(0, 255, [8, 8]).astype(np.float32)) | |||
bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32)) | |||
net = Net1(8, 8, weight=weight, bias=bias, activation='sin') | |||
input_data = Tensor(np.random.randint(0, 255, [128, 8]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_res_channel_error(): | |||
with pytest.raises(ValueError): | |||
ResBlock(3, 6) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_input_scale(): | |||
"""test input scale cell""" | |||
inputs = np.random.uniform(size=(16, 3)) + 3.0 | |||
inputs = Tensor(inputs.astype(np.float32)) | |||
input_scale = [1.0, 2.0, 4.0] | |||
input_center = [3.5, 3.5, 3.5] | |||
net = InputScaleNet(input_scale, input_center) | |||
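    # InputScaleNet shifts by input_center and multiplies by input_scale, so
    # for inputs drawn from [3, 4) column i should land in
    # [-0.5 * scale_i, 0.5 * scale_i]; the asserts below check exactly that.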
output = net(inputs).asnumpy() | |||
assert np.all(output[:, 0] <= 0.5) and np.all(output[:, 0] >= -0.5)
assert np.all(output[:, 1] <= 1.0) and np.all(output[:, 1] >= -1.0)
assert np.all(output[:, 2] <= 2.0) and np.all(output[:, 2] >= -2.0)
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_fc_sequential(): | |||
"""test fc sequential cell""" | |||
inputs = np.ones((16, 3)) | |||
inputs = Tensor(inputs.astype(np.float32)) | |||
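# FCSequential(3, 3, 5, 32): presumably in_channel=3, out_channel=3, layers=5, neurons=32,
# with constant initializers so the output is deterministic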
net = FCSequential(3, 3, 5, 32, weight_init="ones", bias_init="zeros") | |||
output = net(inputs).asnumpy() | |||
target = np.ones((16, 3)) * -31.998459 | |||
assert np.allclose(output, target, rtol=5e-2) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_mulscale_without_latent(): | |||
"""test multi-scale net without latent vector""" | |||
inputs = np.ones((16, 3)) + 3.0 | |||
inputs = Tensor(inputs.astype(np.float32)) | |||
input_scale = [1.0, 2.0, 4.0] | |||
input_center = [3.5, 3.5, 3.5] | |||
net = MultiScaleFCCell(3, 3, 5, 32, | |||
weight_init="ones", bias_init="zeros", | |||
input_scale=input_scale, input_center=input_center) | |||
output = net(inputs).asnumpy() | |||
target = np.ones((16, 3)) * -61.669254 | |||
assert np.allclose(output, target, rtol=5e-2) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_mulscale_with_latent(): | |||
"""test multi-scale net with latent vector and input scale""" | |||
inputs = np.ones((64, 3)) + 3.0 | |||
inputs = Tensor(inputs.astype(np.float32)) | |||
num_scenarios = 4 | |||
latent_size = 16 | |||
latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32) | |||
latent_vector = Parameter(Tensor(latent_init), requires_grad=True) | |||
input_scale = [1.0, 2.0, 4.0] | |||
input_center = [3.5, 3.5, 3.5] | |||
net = MultiScaleFCCell(3, 3, 5, 32, | |||
weight_init="ones", bias_init="zeros", | |||
input_scale=input_scale, input_center=input_center, latent_vector=latent_vector) | |||
output = net(inputs).asnumpy() | |||
target = np.ones((64, 3)) * -57.8849 | |||
assert np.allclose(output, target, rtol=5e-2) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_mulscale_with_latent_noscale(): | |||
"""test multi-scale net with latent vector""" | |||
inputs = np.ones((64, 3)) | |||
inputs = Tensor(inputs.astype(np.float32)) | |||
num_scenarios = 4 | |||
latent_size = 16 | |||
latent_init = np.ones((num_scenarios, latent_size)).astype(np.float32) | |||
latent_vector = Parameter(Tensor(latent_init), requires_grad=True) | |||
net = MultiScaleFCCell(3, 3, 5, 32, | |||
weight_init="ones", bias_init="zeros", latent_vector=latent_vector) | |||
output = net(inputs).asnumpy() | |||
target = np.ones((64, 3)) * -105.62799 | |||
assert np.allclose(output, target, rtol=5e-2) |
@@ -0,0 +1,43 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" test block """ | |||
import pytest | |||
import numpy as np | |||
from mindspore import context | |||
from mindspore import Tensor | |||
from mindelec.architecture import MTLWeightedLossCell | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_mtl_weighted_loss():
"""test MTLWeightedLossCell forward with two losses"""
net = MTLWeightedLossCell(num_losses=2)
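# the cell is expected to fuse the two per-task loss values into one weighted scalar loss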
input_data = Tensor(np.array([1.0, 1.0]).astype(np.float32)) | |||
output = net(input_data) | |||
print(output.asnumpy()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_mtl_num_losses_error():
"""test TypeError when num_losses is not an integer"""
with pytest.raises(TypeError): | |||
MTLWeightedLossCell('a') |
@@ -0,0 +1,92 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" test metrics """ | |||
import pytest | |||
from mindspore import context | |||
from mindspore.common.tensor import Tensor | |||
from mindspore.common import dtype as mstype | |||
from mindelec.common import LearningRate, get_poly_lr | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_learning_rate(): | |||
"""test LearningRate""" | |||
learning_rate = LearningRate(0.1, 0.001, 0, 10, 0.5) | |||
res = learning_rate(Tensor(10000, mstype.int32)) | |||
assert res == 0.001 | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_learning_rate_power_value_error(): | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, 0.001, 0, 10, -0.5) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_learning_rate_warmup_steps_type_error(): | |||
"""test TypeError cases""" | |||
with pytest.raises(TypeError): | |||
LearningRate(0.1, 0.001, 0.1, 10, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.0, 0.001, 0, 10, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, -0.001, 0, 10, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, 0.001, 0, 0, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, 0.001, -1, 10, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, 0.001, 0, -10, 0.5) | |||
with pytest.raises(ValueError): | |||
LearningRate(0.1, 0.001, 1, -10, 0.5) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_get_poly_lr(): | |||
"""test get_poly_lr""" | |||
res = get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, 0.5) | |||
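# the schedule covers the remaining steps: total_steps (10000) - global_step (100) = 9900 entries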
assert res.shape == (9900,) | |||
with pytest.raises(ValueError): | |||
get_poly_lr(-1, 0.001, 0.1, 0.0001, 1000, 10000, 0.5) | |||
with pytest.raises(ValueError): | |||
get_poly_lr(100, 0.0, 0.1, 0.0001, 1000, 10000, 0.5) | |||
with pytest.raises(ValueError): | |||
get_poly_lr(100, 0.001, 0.1, 0.0, 1000, 10000, 0.5) | |||
with pytest.raises(ValueError): | |||
get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 0, 0.5) | |||
with pytest.raises(ValueError): | |||
get_poly_lr(100, 0.001, 0.1, 0.0001, 1000, 10000, -0.5) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_get_poly_lr1(): | |||
"""test get_poly_lr""" | |||
res = get_poly_lr(100, 0.001, 0.1, 0.0001, 0, 10000, 0.5) | |||
assert res.shape == (9900,) |
@@ -0,0 +1,39 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" test metrics """ | |||
import pytest | |||
import numpy as np | |||
import mindspore | |||
from mindspore import Tensor | |||
from mindspore import context | |||
from mindelec.common import L2 | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_l2(): | |||
"""test l2""" | |||
x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32) | |||
y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32) | |||
metric = L2() | |||
metric.clear() | |||
metric.update(x, y) | |||
result = metric.eval() | |||
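# for these vectors L2 matches the relative error ||x - y||_2 / ||y||_2 = sqrt(0.0125 / 1.3725) ≈ 0.095433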
assert np.isclose(result, 0.09543302997807275)
@@ -0,0 +1,82 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
sampling and dataset settings | |||
""" | |||
from easydict import EasyDict as edict | |||
ds_config = edict({ | |||
'train': edict({ | |||
'batch_size': 100, | |||
'shuffle': True, | |||
'drop_remainder': True, | |||
}), | |||
'eval': edict({ | |||
'batch_size': 100, | |||
'shuffle': False, | |||
'drop_remainder': False, | |||
}), | |||
}) | |||
src_sampling_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 400, | |||
'sampler': 'uniform' | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform', | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'sampler': 'uniform', | |||
}), | |||
}) | |||
no_src_sampling_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform' | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform', | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'sampler': 'uniform', | |||
}), | |||
}) | |||
bc_sampling_config = edict({ | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform', | |||
'with_normal': True | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'sampler': 'uniform', | |||
}), | |||
}) |
@@ -0,0 +1,110 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: geometry with time cases""" | |||
import copy | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain | |||
from mindelec.data import Boundary, BoundaryBC, BoundaryIC | |||
reset_geom_time_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
'with_normal': False, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
def check_rect_with_time_set_config(config): | |||
"""check_rect_with_time_set_config""" | |||
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5]) | |||
time = TimeDomain("time", 0.0, 1.0) | |||
rect_with_time = GeometryWithTime(rect, time) | |||
with pytest.raises(TypeError): | |||
Boundary(0) | |||
with pytest.raises(ValueError): | |||
Boundary(rect_with_time) | |||
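# dropping the 'BC' entry from the sampling config should make BoundaryBC creation fail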
config1 = copy.deepcopy(config) | |||
config1.pop('BC') | |||
rect_with_time.set_sampling_config(create_config_from_edict(config1)) | |||
with pytest.raises(ValueError): | |||
BoundaryBC(rect_with_time) | |||
config2 = copy.deepcopy(config) | |||
config2.pop('IC') | |||
rect_with_time.set_sampling_config(create_config_from_edict(config2)) | |||
with pytest.raises(ValueError): | |||
BoundaryIC(rect_with_time) | |||
rect_with_time.set_sampling_config(create_config_from_edict(config)) | |||
bc = BoundaryBC(rect_with_time) | |||
for i in range(20): | |||
print(bc[i]) | |||
config3 = copy.deepcopy(config) | |||
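# for mesh (non-random) sampling the IC size apparently needs one entry per spatial dimension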
if not config3.IC.random_sampling: | |||
config3.IC.size = [4, 4] | |||
rect_with_time.set_sampling_config(create_config_from_edict(config3)) | |||
ic = BoundaryIC(rect_with_time) | |||
for i in range(20): | |||
print(ic[i]) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_set_config(): | |||
"""test_check_rect_with_time_set_config""" | |||
check_rect_with_time_set_config(reset_geom_time_config) | |||
check_rect_with_time_set_config(reset_geom_time_config2) |
@@ -0,0 +1,195 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Test data_base.""" | |||
import pytest | |||
import numpy as np | |||
from mindelec.data import ExistedDataConfig | |||
from mindelec.data.data_base import Data | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_name_error(): | |||
with pytest.raises(TypeError): | |||
ExistedDataConfig(1, "/home/a.npy", "data") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_data_dir_error(): | |||
with pytest.raises(TypeError): | |||
ExistedDataConfig("a", 1, "data") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_column_list_error(): | |||
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, 1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_constraint_type_error(): | |||
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", contraint_type="a") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_data_format_typeerror(): | |||
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", 1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_data_format_valueerror(): | |||
with pytest.raises(ValueError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", "csv") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_data_constraint_type_typeerror():
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", constraint_type='1') | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_data_constraint_type_typeerror1():
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", constraint_type='test') | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_random_merge_error(): | |||
with pytest.raises(TypeError): | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
ExistedDataConfig("a", input_path, "data", random_merge=1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_name_type_error(): | |||
with pytest.raises(TypeError): | |||
Data(1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_columns_list_type_error(): | |||
with pytest.raises(TypeError): | |||
Data(columns_list=1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_constraint_type_type_error(): | |||
with pytest.raises(TypeError): | |||
Data(constraint_type=1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_constraint_type_type_error1(): | |||
with pytest.raises(TypeError): | |||
Data(constraint_type="labe") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_set_constraint_type_type_error(): | |||
with pytest.raises(TypeError): | |||
data = Data() | |||
data.set_constraint_type("test") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_create_dataset_nie_error(): | |||
with pytest.raises(NotImplementedError): | |||
data = Data() | |||
data.create_dataset() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_get_item_nie_error(): | |||
with pytest.raises(NotImplementedError): | |||
data = Data() | |||
print(data[0]) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_data_len_nie_error(): | |||
with pytest.raises(NotImplementedError): | |||
data = Data() | |||
len(data) |
@@ -0,0 +1,112 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test dataset module""" | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.data import Dataset | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime | |||
from mindelec.data import BoundaryBC, BoundaryIC | |||
from config import ds_config, src_sampling_config, no_src_sampling_config, bc_sampling_config | |||
ic_bc_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_dataset_allnone(): | |||
with pytest.raises(ValueError): | |||
Dataset() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_dataset(): | |||
"""test dataset""" | |||
disk = Disk("src", (0.0, 0.0), 0.2) | |||
rectangle = Rectangle("rect", (-1, -1), (1, 1)) | |||
diff = rectangle - disk | |||
time = TimeDomain("time", 0.0, 1.0) | |||
# check datalist | |||
rect_with_time = GeometryWithTime(rectangle, time) | |||
rect_with_time.set_sampling_config(create_config_from_edict(ic_bc_config)) | |||
bc = BoundaryBC(rect_with_time) | |||
ic = BoundaryIC(rect_with_time) | |||
dataset = Dataset(dataset_list=bc) | |||
dataset.set_constraint_type("Equation") | |||
c_type1 = {bc: "Equation", ic: "Equation"} | |||
with pytest.raises(ValueError): | |||
dataset.set_constraint_type(c_type1) | |||
no_src_region = GeometryWithTime(diff, time) | |||
no_src_region.set_name("no_src") | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config)) | |||
src_region = GeometryWithTime(disk, time) | |||
src_region.set_name("src") | |||
src_region.set_sampling_config(create_config_from_edict(src_sampling_config)) | |||
boundary = GeometryWithTime(rectangle, time) | |||
boundary.set_name("bc") | |||
boundary.set_sampling_config(create_config_from_edict(bc_sampling_config)) | |||
geom_dict = ['1', '2'] | |||
with pytest.raises(TypeError): | |||
Dataset(geom_dict) | |||
geom_dict = {src_region: ["test"]} | |||
with pytest.raises(KeyError): | |||
Dataset(geom_dict) | |||
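# map each geometry to the sample types it should contribute (domain / BC / IC)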
geom_dict = {src_region: ["domain", "IC"], | |||
no_src_region: ["domain", "IC"], | |||
boundary: ["BC"]} | |||
dataset = Dataset(geom_dict) | |||
with pytest.raises(ValueError): | |||
print(dataset[0]) | |||
with pytest.raises(ValueError): | |||
len(dataset) | |||
with pytest.raises(ValueError): | |||
dataset.get_columns_list() | |||
with pytest.raises(ValueError): | |||
dataset.create_dataset(batch_size=ds_config.train.batch_size, | |||
shuffle=ds_config.train.shuffle, | |||
prebatched_data=True, | |||
drop_remainder=False) |
@@ -0,0 +1,77 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: geometry with time cases""" | |||
import copy | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain | |||
from mindelec.data import Equation | |||
reset_geom_time_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [5, 2], | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
def check_rect_with_time_set_config(config): | |||
"""check_rect_with_time_set_config""" | |||
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5]) | |||
time = TimeDomain("time", 0.0, 1.0) | |||
rect_with_time = GeometryWithTime(rect, time) | |||
with pytest.raises(TypeError): | |||
Equation(0) | |||
with pytest.raises(ValueError): | |||
Equation(rect_with_time) | |||
config1 = copy.deepcopy(config) | |||
config1.pop('domain') | |||
rect_with_time.set_sampling_config(create_config_from_edict(config1)) | |||
with pytest.raises(KeyError): | |||
Equation(rect_with_time) | |||
rect_with_time.set_sampling_config(create_config_from_edict(config)) | |||
eq = Equation(rect_with_time) | |||
for i in range(20): | |||
print(eq[i]) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_set_config(): | |||
"""test_check_rect_with_time_set_config""" | |||
check_rect_with_time_set_config(reset_geom_time_config) | |||
check_rect_with_time_set_config(reset_geom_time_config2) |
@@ -0,0 +1,77 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Test existed_data.""" | |||
import pytest | |||
import numpy as np | |||
from mindelec.data import ExistedDataset, ExistedDataConfig | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_value_error(): | |||
with pytest.raises(ValueError): | |||
ExistedDataset() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_type_error():
with pytest.raises(TypeError): | |||
ExistedDataset(data_config=1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_config(): | |||
"""test various errors""" | |||
with pytest.raises(ValueError): | |||
input_path = "./input_data.npy" | |||
label_path = "./label.npy" | |||
input_data = np.random.randn(10, 3) | |||
output = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
np.save(label_path, output) | |||
config = ExistedDataConfig(name="existed_data", | |||
data_dir=[input_path, label_path], | |||
columns_list=["inputs", "label"], | |||
constraint_type="Equation", | |||
data_format="npz") | |||
dataset = ExistedDataset(data_config=config) | |||
input_path = "./input_data.npy" | |||
input_data = np.random.randn(10, 3) | |||
np.save(input_path, input_data) | |||
dataset = ExistedDataset(name="existed_data", | |||
data_dir=[input_path], | |||
columns_list=["inputs"], | |||
constraint_type="Equation", | |||
data_format="npy") | |||
for i in range(20): | |||
print(dataset[i]) | |||
dataset = ExistedDataset(name="existed_data", | |||
data_dir=[input_path], | |||
columns_list=["inputs"], | |||
constraint_type="Equation", | |||
data_format="npy", | |||
random_merge=False) | |||
for i in range(20): | |||
print(dataset[i]) |
@@ -0,0 +1,84 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test dataset module, call CSG ant GeometryWithTime""" | |||
import pytest | |||
import numpy as np | |||
from easydict import EasyDict as edict | |||
from mindelec.data import Dataset | |||
from mindelec.geometry import Rectangle, TimeDomain, GeometryWithTime, create_config_from_edict | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_dataset(): | |||
"""check dataset""" | |||
sampling_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 100, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10], | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 40, | |||
}), | |||
}) | |||
rect_space = Rectangle("rectangle", coord_min=[0, 0], coord_max=[10, 10]) | |||
time = TimeDomain("time", 0.0, 20) | |||
grid = GeometryWithTime(rect_space, time) | |||
grid.set_sampling_config(create_config_from_edict(sampling_config)) | |||
grid.set_name("grid") | |||
geom_dict = {grid: ["domain", "BC", "IC"]} | |||
dataset = Dataset(geom_dict) | |||
def preprocess_fn(*data): | |||
bc_data = data[1] | |||
bc_normal = np.ones(bc_data.shape) | |||
return data[0], data[1], bc_normal, data[2], data[3] | |||
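# preprocess_fn returns five tensors, so the map below expands "grid_BC_points" into two output columns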
columns_map = {"grid_domain_points": ["grid_domain_points"],
"grid_BC_points": ["grid_BC_points", "grid_BC_tangent"],
"grid_BC_normal": "grid_BC_normal",
"grid_IC_points": "grid_IC_points"}
train_data = dataset.create_dataset(batch_size=8192, shuffle=False, drop_remainder=False,
preprocess_fn=preprocess_fn,
input_output_columns_map=columns_map)
dataset.set_constraint_type({dataset.all_datasets[0]: "Equation", | |||
dataset.all_datasets[1]: "BC", | |||
dataset.all_datasets[2]: "IC"}) | |||
print("get merged data: {}".format(dataset[5])) | |||
for sub_data in dataset.all_datasets: | |||
print("get data: {}".format(sub_data[5])) | |||
dataset_iter = train_data.create_dict_iterator(num_epochs=1) | |||
np.set_printoptions(threshold=np.inf) | |||
for _ in range(1): | |||
for data in dataset_iter: | |||
for k, v in data.items(): | |||
print("key: ", k) | |||
print(v) | |||
break |
@@ -0,0 +1,133 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: 1d cases""" | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Interval, TimeDomain | |||
line_config_out = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'sampler': 'uniform', | |||
}), | |||
}) | |||
line_config_out2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
'sampler': 'uniform', | |||
}), | |||
}) | |||
def check_line_interval_case1(line_config): | |||
"""check_line_interval_case1""" | |||
try: | |||
Interval("line", 'test', 1.0, sampling_config=create_config_from_edict(line_config)) | |||
except ValueError:
print("get ValueError when creating Interval with non-numeric coordinates")
line = Interval("line", -1.0, 1.0, sampling_config=create_config_from_edict(line_config)) | |||
domain = line.sampling(geom_type="domain") | |||
bc = line.sampling(geom_type="BC") | |||
try: | |||
line.sampling(geom_type="other") | |||
except ValueError: | |||
print("get ValueError when sampling other data") | |||
# uniform sampling | |||
if "BC" in line_config.keys(): | |||
line_config.BC = None | |||
if "domain" in line_config.keys(): | |||
line_config.domain.random_sampling = False | |||
line.set_sampling_config(create_config_from_edict(line_config)) | |||
domain = line.sampling(geom_type="domain") | |||
try: | |||
line.sampling(geom_type="BC") | |||
except KeyError: | |||
print("get ValueError when sampling BC data") | |||
# lhs, halton, sobol | |||
for samplers in ["lhs", "halton", "sobol"]: | |||
if "domain" in line_config.keys(): | |||
line_config.domain.random_sampling = True | |||
line_config.domain.sampler = samplers | |||
line.set_sampling_config(create_config_from_edict(line_config)) | |||
domain = line.sampling(geom_type="domain") | |||
print(domain, bc) | |||
time_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'lhs' | |||
}) | |||
}) | |||
def check_time_interval(line_config): | |||
"""check_time_interval""" | |||
try: | |||
create_config_from_edict({"test": "test"}) | |||
except ValueError:
print("get ValueError from create_config_from_edict on an unknown key")
line = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(line_config)) | |||
domain = line.sampling(geom_type="domain") | |||
try: | |||
line.sampling(geom_type="BC") | |||
except KeyError: | |||
print("get ValueError when sampling BC data") | |||
print(domain) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_time_interval(): | |||
"""test_check_time_interval""" | |||
check_time_interval(time_config) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_line_interval_case1(): | |||
"""test_check_time_interval""" | |||
check_line_interval_case1(line_config_out) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_line_interval_case2(): | |||
"""test_check_time_interval""" | |||
check_line_interval_case1(line_config_out2) |
@@ -0,0 +1,247 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: 2d cases""" | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Rectangle, Disk | |||
disk_random = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False, | |||
}), | |||
}) | |||
disk_mesh = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [100, 180], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': [20, 10], | |||
'with_normal': False, | |||
}), | |||
}) | |||
disk_mesh_wrong_meshsize = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [100, 180, 200], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 200, | |||
'with_normal': False, | |||
}), | |||
}) | |||
disk_mesh_nodomain = edict({ | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 200, | |||
'with_normal': False, | |||
}), | |||
}) | |||
disk_mesh_nobc = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
}) | |||
def check_disk_random(disk_config): | |||
"""check_disk_random""" | |||
with pytest.raises(ValueError): | |||
disk = Disk("disk", (-1.0, 0), -2.0) | |||
with pytest.raises(ValueError): | |||
disk = Disk("disk", (-1.0, 0, 3), 2.0) | |||
disk = Disk("disk", (-1.0, 0), 2.0) | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in disk_config.keys(): | |||
disk_config.domain.sampler = samplers | |||
if "BC" in disk_config.keys(): | |||
disk_config.BC.sampler = samplers | |||
try: | |||
disk.sampling(geom_type="domain") | |||
except ValueError:
print("get ValueError when sampling before a config is set")
disk.set_sampling_config(create_config_from_edict(disk_config)) | |||
with pytest.raises(ValueError): | |||
disk.sampling(geom_type="test") | |||
domain = disk.sampling(geom_type="domain") | |||
bc = disk.sampling(geom_type="BC") | |||
disk.sampling_config.bc.with_normal = True | |||
bc, bc_normal = disk.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
def check_disk_mesh(disk_config): | |||
"""check_disk_mesh""" | |||
disk = Disk("disk", (-1.0, 0), 2.0, sampling_config=create_config_from_edict(disk_config)) | |||
domain = disk.sampling(geom_type="domain") | |||
bc = disk.sampling(geom_type="BC") | |||
disk.sampling_config.bc.with_normal = True | |||
bc, bc_normal = disk.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
rectangle_random = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': True, | |||
}), | |||
}) | |||
rectangle_mesh = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [50, 25], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 300, | |||
'with_normal': True, | |||
}), | |||
}) | |||
def check_rectangle_random(config): | |||
"""check_rectangle_random""" | |||
rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2)) | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in config.keys(): | |||
config.domain.sampler = samplers | |||
if "BC" in config.keys(): | |||
config.BC.sampler = samplers | |||
config.BC.with_normal = True | |||
rectangle.set_sampling_config(create_config_from_edict(config)) | |||
domain = rectangle.sampling(geom_type="domain") | |||
bc, bc_normal = rectangle.sampling(geom_type="BC") | |||
if "BC" in config.keys(): | |||
config.BC.with_normal = False | |||
rectangle.set_sampling_config(create_config_from_edict(config)) | |||
bc = rectangle.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
def check_rectangle_mesh(config): | |||
"""check_rectangle_mesh""" | |||
rectangle = Rectangle("rectangle", (-3.0, 1), (1, 2)) | |||
if "BC" in config.keys(): | |||
config.BC.with_normal = True | |||
rectangle.set_sampling_config(create_config_from_edict(config)) | |||
domain = rectangle.sampling(geom_type="domain") | |||
bc, bc_normal = rectangle.sampling(geom_type="BC") | |||
if "BC" in config.keys(): | |||
config.BC.with_normal = False | |||
rectangle.set_sampling_config(create_config_from_edict(config)) | |||
bc = rectangle.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_disk_random(): | |||
"""test_check_disk_random""" | |||
check_disk_random(disk_random) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_disk_mesh(): | |||
"""test_check_disk_mesh""" | |||
check_disk_mesh(disk_mesh) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_disk_mesh_wrong_meshsize_error(): | |||
"""test_check_disk_mesh""" | |||
with pytest.raises(ValueError): | |||
check_disk_mesh(disk_mesh_wrong_meshsize) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_disk_mesh_nodomain_error(): | |||
"""test_check_disk_mesh""" | |||
with pytest.raises(KeyError): | |||
check_disk_mesh(disk_mesh_nodomain) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_disk_mesh_nobc_error(): | |||
"""test_check_disk_mesh""" | |||
with pytest.raises(KeyError): | |||
check_disk_mesh(disk_mesh_nobc) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rectangle_random(): | |||
"""test_check_rectangle_random""" | |||
check_rectangle_random(rectangle_random) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rectangle_mesh(): | |||
"""test_check_rectangle_mesh""" | |||
check_rectangle_mesh(rectangle_mesh) |
@@ -0,0 +1,264 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: base classes""" | |||
import pytest | |||
import numpy as np | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import Geometry, PartSamplingConfig, SamplingConfig, create_config_from_edict | |||
def check_create_config_from_edict():
"""check create_config_from_edict with a non-dict input"""
try: | |||
sampling_config = ["geom", "IC", "BC"] | |||
config = create_config_from_edict(sampling_config) | |||
print("check config: {}".format(config)) | |||
except TypeError: | |||
print("get sampling config type error") | |||
def check_part_sampling_config(size, random_sampling, sampler, random_merge, with_normal):
"""check PartSamplingConfig construction with the given arguments"""
try: | |||
config = PartSamplingConfig(size, random_sampling, sampler, random_merge, with_normal) | |||
print("check config: {}".format(config.__dict__)) | |||
except TypeError: | |||
print("get TypeError") | |||
temp_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
'sampler': 'uniform', | |||
'random_merge': True, | |||
}), | |||
}) | |||
def check_sampling_config_case1(): | |||
"""check_sampling_config_case1""" | |||
with pytest.raises(TypeError): | |||
SamplingConfig("test") | |||
with pytest.raises(KeyError): | |||
SamplingConfig({"test": "test"}) | |||
with pytest.raises(TypeError): | |||
SamplingConfig({"domain": "test"}) | |||
with pytest.raises(TypeError): | |||
part_sampling_config_dict = {"domain": PartSamplingConfig("test", False, True)} | |||
SamplingConfig(part_sampling_config_dict) | |||
part_sampling_config_dict = {"domain": PartSamplingConfig([100, 100], False, "uniform", True, True), | |||
"BC": PartSamplingConfig(100, True, "uniform", True, True)} | |||
sampling_config_tmp = SamplingConfig(part_sampling_config_dict) | |||
for attr, config in sampling_config_tmp.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config.__dict__)) | |||
def check_sampling_config_case2(config_in): | |||
"""check_sampling_config_case2""" | |||
sampling_config_tmp = create_config_from_edict(config_in) | |||
for attr, config in sampling_config_tmp.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config.__dict__)) | |||
def check_sampling_config_case3(config_in): | |||
"""check_sampling_config_case3""" | |||
config_in.ttime = config_in.domain | |||
sampling_config_tmp = create_config_from_edict(config_in) | |||
for attr, config in sampling_config_tmp.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config.__dict__)) | |||
def check_geometry_case1(): | |||
"""check_geometry_case1""" | |||
with pytest.raises(ValueError): | |||
Geometry("geom", 2, 0.0, 1.0) | |||
with pytest.raises(TypeError): | |||
Geometry("geom", 2, [0, 0], [1, 1], sampling_config="test") | |||
geom = Geometry("geom", 1, 0.0, 1.0) | |||
with pytest.raises(TypeError): | |||
geom.set_sampling_config('test') | |||
with pytest.raises(NotImplementedError): | |||
geom.sampling() | |||
for attr, config in geom.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config)) | |||
try: | |||
geom = Geometry("geom", 1, 1.0, 0.0) | |||
for attr, config in geom.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config)) | |||
except ValueError: | |||
print("get ValueError") | |||
sampling_config2 = create_config_from_edict(temp_config) | |||
def check_geometry_case2(sampling_config_tmp): | |||
"""check_geometry_case2""" | |||
geom = Geometry("geom", 1, 0.0, 1.0, sampling_config=sampling_config_tmp) | |||
geom.set_name("geom_name") | |||
for attr, config in geom.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config)) | |||
def check_geometry_case3(config_in): | |||
"""check_geometry_case3""" | |||
geom = Geometry("geom", 1, 0.0, 1.0) | |||
for attr, configs in geom.__dict__.items(): | |||
if configs is not None: | |||
print("check sampling config: {}: {}".format(attr, configs)) | |||
geom.set_name("geom_name") | |||
geom.set_sampling_config(create_config_from_edict(config_in)) | |||
for attr, config in geom.__dict__.items(): | |||
if attr == "sampling_config" and config is not None: | |||
print("check sampling config after set: {}: {}".format(attr, config.__dict__)) | |||
for attrs, configs in config.__dict__.items(): | |||
if configs is not None: | |||
print("check sampling config: {}: {}".format(attrs, configs.__dict__)) | |||
else: | |||
if config is not None: | |||
print("check sampling config after set: {}: {}".format(attr, config)) | |||
def check_geometry_case4(): | |||
"""check_geometry_case4""" | |||
try: | |||
geom = Geometry(10, 1, 1.0, 0.0) | |||
for attr, config in geom.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config)) | |||
except TypeError: | |||
print("get geom name type error") | |||
geom = Geometry("geom", 1, 0.0, 1.0) | |||
try: | |||
geom.set_name("geom_name") | |||
except TypeError: | |||
print("get set geom name type error") | |||
geom.set_name("geom_name") | |||
try: | |||
geom = Geometry("geom", 1.0, 1.0, 2.0) | |||
for attr, config in geom.__dict__.items(): | |||
if config is not None: | |||
print("check sampling config: {}: {}".format(attr, config)) | |||
except TypeError: | |||
print("get geom dim type error") | |||
try: | |||
geom = Geometry("geom", 1, 1.0, 0.0) | |||
except ValueError: | |||
print("get geom coord value error") | |||
try: | |||
geom = Geometry("geom", 1, {"min": 0.0}, {"max": 1.0}) | |||
except TypeError: | |||
print("get geom coord type error") | |||
try: | |||
geom = Geometry("geom", 1, 0.0, 1.0, dtype=np.uint32) | |||
except TypeError: | |||
print("get geom data type error") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_part_sampling_config(): | |||
"""test_check_part_sampling_config""" | |||
check_part_sampling_config(100, True, "uniform", True, True) | |||
check_part_sampling_config(100, False, "sobol", True, True) | |||
check_part_sampling_config([100, 100], False, "sobol", True, True) | |||
check_part_sampling_config([100, 100], True, "uniform", True, True) | |||
check_part_sampling_config(100, False, "lhs", True, True) | |||
check_part_sampling_config(100, False, "halton", True, True) | |||
check_create_config_from_edict() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_sampling_config_case1(): | |||
"""test_check_sampling_config_case1""" | |||
check_sampling_config_case1() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_sampling_config_case2(): | |||
"""test_check_sampling_config_case2""" | |||
check_sampling_config_case2(temp_config) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_sampling_config_case3(): | |||
"""test_check_sampling_config_case3""" | |||
check_sampling_config_case3(temp_config) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_geometry_case1(): | |||
"""test_check_geometry_case1""" | |||
check_geometry_case1() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_geometry_case2(): | |||
"""test_check_geometry_case2""" | |||
check_geometry_case2(sampling_config2) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_geometry_case3(): | |||
"""test_check_geometry_case3""" | |||
check_geometry_case3(temp_config) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_geometry_case4(): | |||
"""test_check_geometry_case4""" | |||
check_geometry_case4() |
@@ -0,0 +1,240 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: CSG classes""" | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Rectangle, Disk, Interval | |||
from mindelec.geometry import CSGIntersection, CSGDifference, CSGUnion, CSGXOR, CSG | |||
sampling_config_csg = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': True, | |||
}), | |||
}) | |||
sampling_config_csg2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False, | |||
}), | |||
}) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_csg_union(): | |||
"""test check union""" | |||
disk = Disk("disk", (1.2, 0.5), 0.8) | |||
rect = Rectangle("rect", (-1.0, 0), (1, 1)) | |||
union = CSGUnion(rect, disk, create_config_from_edict(sampling_config_csg)) | |||
union = rect | disk | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in sampling_config_csg.keys(): | |||
sampling_config_csg.domain.sampler = samplers | |||
if "BC" in sampling_config_csg.keys(): | |||
sampling_config_csg.BC.sampler = samplers | |||
union.set_sampling_config(create_config_from_edict(sampling_config_csg2)) | |||
bc = union.sampling(geom_type="BC") | |||
union.set_sampling_config(create_config_from_edict(sampling_config_csg)) | |||
domain = union.sampling(geom_type="domain") | |||
bc, bc_normal = union.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
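# Note: the CSG classes overload Python operators, as exercised above and in the
# tests below:
#   rect | disk  -> CSGUnion(rect, disk)
#   rect - disk  -> CSGDifference(rect, disk)
#   rect & disk  -> CSGIntersection(rect, disk)
#   rect ^ disk  -> CSGXOR(rect, disk)
# Each loop also samples BC twice: without normals (sampling_config_csg2) and
# with normals (sampling_config_csg).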
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_csg_difference(): | |||
"""test_check_csg_difference""" | |||
disk = Disk("disk", (1.2, 0.5), 0.8) | |||
rect = Rectangle("rect", (-1.0, 0), (1, 1)) | |||
difference = CSGDifference(rect, disk, create_config_from_edict(sampling_config_csg)) | |||
difference = rect - disk | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in sampling_config_csg.keys(): | |||
sampling_config_csg.domain.sampler = samplers | |||
if "BC" in sampling_config_csg.keys(): | |||
sampling_config_csg.BC.sampler = samplers | |||
difference.set_sampling_config(create_config_from_edict(sampling_config_csg2)) | |||
bc = difference.sampling(geom_type="BC") | |||
difference.set_sampling_config(create_config_from_edict(sampling_config_csg)) | |||
domain = difference.sampling(geom_type="domain") | |||
bc, bc_normal = difference.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_csg_intersection(): | |||
"""test_check_csg_intersection""" | |||
disk = Disk("disk", (1.2, 0.5), 0.8) | |||
rect = Rectangle("rect", (-1.0, 0), (1, 1)) | |||
intersec = CSGIntersection(rect, disk, create_config_from_edict(sampling_config_csg)) | |||
intersec = rect & disk | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in sampling_config_csg.keys(): | |||
sampling_config_csg.domain.sampler = samplers | |||
if "BC" in sampling_config_csg.keys(): | |||
sampling_config_csg.BC.sampler = samplers | |||
intersec.set_sampling_config(create_config_from_edict(sampling_config_csg2)) | |||
bc = intersec.sampling(geom_type="BC") | |||
intersec.set_sampling_config(create_config_from_edict(sampling_config_csg)) | |||
domain = intersec.sampling(geom_type="domain") | |||
bc, bc_normal = intersec.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_csg_xor(): | |||
"""test_check_csg_xor""" | |||
disk = Disk("disk", (1.2, 0.5), 0.8) | |||
rect = Rectangle("rect", (-1.0, 0), (1, 1)) | |||
xor = CSGXOR(rect, disk, create_config_from_edict(sampling_config_csg)) | |||
xor = rect ^ disk | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in sampling_config_csg.keys(): | |||
sampling_config_csg.domain.sampler = samplers | |||
if "BC" in sampling_config_csg.keys(): | |||
sampling_config_csg.BC.sampler = samplers | |||
xor.set_sampling_config(create_config_from_edict(sampling_config_csg2)) | |||
bc = xor.sampling(geom_type="BC") | |||
xor.set_sampling_config(create_config_from_edict(sampling_config_csg)) | |||
domain = xor.sampling(geom_type="domain") | |||
bc, bc_normal = xor.sampling(geom_type="BC") | |||
print(bc, bc_normal, domain) | |||
no_src_sampling_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform' | |||
}), | |||
}) | |||
no_src_sampling_config1 = edict({ | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False | |||
}), | |||
}) | |||
no_src_sampling_config2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10] | |||
}), | |||
}) | |||
no_src_sampling_config3 = edict({ | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10] | |||
}), | |||
}) | |||
no_src_sampling_config4 = edict({ | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10] | |||
}), | |||
}) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_point_src_csg(): | |||
"""test_check_point_src_csg""" | |||
src_region = Disk("src", (0.0, 0.0), 0.2) | |||
rectangle = Rectangle("rect", (-1, -1), (1, 1)) | |||
line = Interval("line", -1, 1) | |||
with pytest.raises(TypeError): | |||
no_src_region = CSG("test", rectangle, 2, [0, 0], [1, 1]) | |||
with pytest.raises(TypeError): | |||
no_src_region = CSG("test", 2, rectangle, [0, 0], [1, 1]) | |||
with pytest.raises(ValueError): | |||
no_src_region = CSG("test", rectangle, line, [0, 0], [1, 1]) | |||
no_src_region = rectangle - src_region | |||
no_src_region.set_name("no_src") | |||
with pytest.raises(ValueError): | |||
no_src_region.set_sampling_config(None) | |||
with pytest.raises(TypeError): | |||
no_src_region.set_sampling_config("test") | |||
with pytest.raises(ValueError): | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config2)) | |||
with pytest.raises(ValueError): | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config3)) | |||
with pytest.raises(ValueError): | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config4)) | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config)) | |||
with pytest.raises(KeyError): | |||
no_src_region.sampling(geom_type="BC") | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config1)) | |||
with pytest.raises(KeyError): | |||
no_src_region.sampling(geom_type="domain") | |||
with pytest.raises(ValueError): | |||
no_src_region.sampling(geom_type="test") | |||
no_src_region.sampling(geom_type="BC") |
@@ -0,0 +1,142 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: nd cases""" | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Cuboid | |||
cuboid_random = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False, | |||
}), | |||
}) | |||
cuboid_random2 = edict({ | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False, | |||
}), | |||
}) | |||
cuboid_mesh = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [20, 30, 10], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 900, | |||
'with_normal': True, | |||
}), | |||
}) | |||
cuboid_mesh2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [20, 30], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 900, | |||
'with_normal': True, | |||
}), | |||
}) | |||
def check_cuboid_random(cuboid_config): | |||
"""check_cuboid_random""" | |||
cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1]) | |||
for samplers in ["uniform", "lhs", "halton", "sobol"]: | |||
print("check random sampler: {}".format(samplers)) | |||
if "domain" in cuboid_config.keys(): | |||
cuboid_config.domain.sampler = samplers | |||
if "BC" in cuboid_config.keys(): | |||
cuboid_config.BC.sampler = samplers | |||
try: | |||
cuboid.set_sampling_config("test") | |||
except TypeError: | |||
print("set_sampling_config TypeError") | |||
try: | |||
cuboid.sampling(geom_type="domain") | |||
except ValueError: | |||
print("sampling ValueError") | |||
cuboid.set_sampling_config(create_config_from_edict(cuboid_config)) | |||
domain = cuboid.sampling(geom_type="domain") | |||
bc = cuboid.sampling(geom_type="BC") | |||
assert domain.shape == (1000, 3) | |||
assert bc.shape == (199, 3) | |||
def check_cuboid_mesh(cuboid_config): | |||
"""check_cuboid_mesh""" | |||
cuboid = Cuboid("Cuboid", (-3, -1, 0), [-1, 2, 1], sampling_config=create_config_from_edict(cuboid_config)) | |||
domain = cuboid.sampling(geom_type="domain") | |||
bc, bc_normal = cuboid.sampling(geom_type="BC") | |||
assert domain.shape == (6000, 3) | |||
assert bc.shape == (556, 3) | |||
assert bc_normal.shape == (556, 3) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_cuboid_random(): | |||
"""test_check_cuboid_random""" | |||
check_cuboid_random(cuboid_random) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_cuboid_random_nodomain_error(): | |||
"""test_check_cuboid_random""" | |||
with pytest.raises(KeyError): | |||
check_cuboid_random(cuboid_random2) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_cuboid_mesh(): | |||
"""test_check_cuboid_mesh""" | |||
check_cuboid_mesh(cuboid_mesh) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_cuboid_mesh_meshsize_error(): | |||
"""test_check_cuboid_mesh""" | |||
with pytest.raises(ValueError): | |||
check_cuboid_mesh(cuboid_mesh2) |
@@ -0,0 +1,293 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""test geometry module: geometry with time cases""" | |||
import copy | |||
import pytest | |||
from easydict import EasyDict as edict | |||
from mindelec.geometry import create_config_from_edict | |||
from mindelec.geometry import Rectangle, GeometryWithTime, TimeDomain | |||
rectangle_random = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 1000, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform', | |||
'with_normal': False, | |||
}), | |||
}) | |||
rectangle_mesh = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 50, | |||
'with_normal': False, | |||
}), | |||
}) | |||
time_random = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
'sampler': 'lhs' | |||
}) | |||
}) | |||
time_mesh = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
time_mesh2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': 10000, | |||
}) | |||
}) | |||
time_mesh3 = edict({ | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': 10000, | |||
}) | |||
}) | |||
reset_geom_time_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 50, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10], | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config2 = edict({ | |||
'domain': edict({ | |||
'random_sampling': False, | |||
'size': [10, 20], | |||
}), | |||
'BC': edict({ | |||
'random_sampling': False, | |||
'size': 50, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10], | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config3 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 50, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': False, | |||
'size': [10, 10], | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config4 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 50, | |||
'with_normal': True, | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 100, | |||
}), | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
reset_geom_time_config5 = edict({ | |||
'time': edict({ | |||
'random_sampling': False, | |||
'size': 10, | |||
}) | |||
}) | |||
def check_rect_with_time_init_config(rect_config, time_config): | |||
"""check_rect_with_time_init_config""" | |||
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5], sampling_config=create_config_from_edict(rect_config)) | |||
time = TimeDomain("time", 0.0, 1.0, sampling_config=create_config_from_edict(time_config)) | |||
rect_with_time = GeometryWithTime(rect, time) | |||
# check info | |||
print("check rect_with_time initial config: {}".format(rect_with_time.__dict__)) | |||
if rect_with_time.sampling_config is not None: | |||
for key, value in rect_with_time.sampling_config.__dict__.items(): | |||
if value is not None: | |||
print(" get attr: {}, value: {}".format(key, value.__dict__)) | |||
# sampling | |||
config = rect_with_time.sampling_config | |||
if config is None: | |||
raise ValueError | |||
if config.domain is not None: | |||
domain = rect_with_time.sampling(geom_type="domain") | |||
print("check domain points: {}".format(domain.shape)) | |||
if config.bc is not None: | |||
if config.bc.with_normal: | |||
bc, bc_normal = rect_with_time.sampling(geom_type="BC") | |||
print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape)) | |||
else: | |||
bc = rect_with_time.sampling(geom_type="BC") | |||
print("check bc points: {}".format(bc.shape)) | |||
if config.ic is not None: | |||
ic = rect_with_time.sampling(geom_type="IC") | |||
print("check ic points: {}".format(ic.shape)) | |||
def check_rect_with_time_set_config(config): | |||
"""check_rect_with_time_set_config""" | |||
rect = Rectangle("rect", [-1.0, -0.5], [1.0, 0.5]) | |||
time = TimeDomain("time", 0.0, 1.0) | |||
try: | |||
GeometryWithTime(rect, time, create_config_from_edict(config)) | |||
except ValueError: | |||
print("create_config_from_edict ValueError") | |||
rect_with_time = GeometryWithTime(rect, time) | |||
try: | |||
rect_with_time.sampling(geom_type="domain") | |||
except ValueError: | |||
print("sampling ValueError") | |||
rect_with_time.set_sampling_config(create_config_from_edict(config)) | |||
try: | |||
rect_with_time.sampling(geom_type="test") | |||
except ValueError: | |||
print("sampling ValueError") | |||
# sampling | |||
config = rect_with_time.sampling_config | |||
if config.domain is not None: | |||
domain = rect_with_time.sampling(geom_type="domain") | |||
print("check domain points: {}".format(domain.shape)) | |||
else: | |||
try: | |||
rect_with_time.sampling(geom_type="domain") | |||
except ValueError: | |||
print("sampling KeyError") | |||
if config.bc is not None: | |||
if config.bc.with_normal: | |||
bc, bc_normal = rect_with_time.sampling(geom_type="BC") | |||
print("check bc points: {}, bc_normal: {}".format(bc.shape, bc_normal.shape)) | |||
normal = copy.deepcopy(bc) | |||
normal[:, :2] = bc[:, :2] + bc_normal[:, :] | |||
else: | |||
bc = rect_with_time.sampling(geom_type="BC") | |||
print("check bc points: {}".format(bc.shape)) | |||
else: | |||
try: | |||
rect_with_time.sampling(geom_type="BC") | |||
except ValueError: | |||
print("sampling KeyError") | |||
if config.ic is not None: | |||
ic = rect_with_time.sampling(geom_type="IC") | |||
print("check ic points: {}".format(ic.shape)) | |||
else: | |||
try: | |||
rect_with_time.sampling(geom_type="IC") | |||
except ValueError: | |||
print("sampling KeyError") | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_init_config(): | |||
"""test_check_rect_with_time_init_config""" | |||
check_rect_with_time_init_config(rectangle_random, time_random) | |||
check_rect_with_time_init_config(rectangle_mesh, time_random) | |||
check_rect_with_time_init_config(rectangle_mesh, time_mesh2) | |||
check_rect_with_time_init_config(rectangle_random, time_mesh) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_init_config_error(): | |||
"""test_check_rect_with_time_init_config""" | |||
with pytest.raises(ValueError): | |||
check_rect_with_time_init_config(rectangle_mesh, time_mesh3) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_set_config(): | |||
"""test_check_rect_with_time_set_config""" | |||
check_rect_with_time_set_config(reset_geom_time_config) | |||
check_rect_with_time_set_config(reset_geom_time_config2) | |||
check_rect_with_time_set_config(reset_geom_time_config3) | |||
check_rect_with_time_set_config(reset_geom_time_config4) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_check_rect_with_time_set_config2(): | |||
"""test_check_rect_with_time_set_config""" | |||
with pytest.raises(ValueError): | |||
check_rect_with_time_set_config(rectangle_random) |
@@ -0,0 +1,39 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""Test constraints""" | |||
import pytest | |||
from mindelec.data import Dataset | |||
from mindelec.loss import Constraints | |||
from mindelec.geometry import Rectangle | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_constraints_dataset_type_error(): | |||
with pytest.raises(TypeError): | |||
Constraints(1, 1) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_existed_data_config_type_error(): | |||
rectangle = Rectangle("rect", (-1, -1), (1, 1)) | |||
geom_dict = {rectangle: ["domain"]} | |||
dataset = Dataset(geom_dict) | |||
with pytest.raises(TypeError): | |||
Constraints(dataset, 1) |
@@ -0,0 +1,217 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
test net_with_loss | |||
""" | |||
import os | |||
from easydict import EasyDict as edict | |||
import numpy as np | |||
import pytest | |||
import mindspore.nn as nn | |||
import mindspore.ops as ops | |||
from mindspore import Tensor | |||
from mindspore import context, ms_function | |||
from mindspore.common import set_seed | |||
from mindelec.geometry import Rectangle, create_config_from_edict | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
from mindelec.loss import Constraints, NetWithLoss, NetWithEval | |||
from mindelec.architecture import ResBlock, LinearBlock | |||
from mindelec.solver import Problem | |||
from mindelec.common import L2 | |||
from mindelec.operators import Grad | |||
set_seed(1) | |||
np.random.seed(1) | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
path = os.getcwd() | |||
data_config = edict({ | |||
    'data_dir': [path+'/inputs.npy', path+'/label.npy'],  # absolute paths
'columns_list': ['input_data', 'label'], | |||
'data_format': 'npy', | |||
'constraint_type': 'Equation', | |||
'name': 'exist' | |||
}) | |||
if not os.path.exists(data_config['data_dir'][0]): | |||
data_in = np.ones((32, 2), dtype=np.float32) | |||
np.save(path+"/inputs.npy", data_in) | |||
if not os.path.exists(data_config['data_dir'][1]): | |||
data_label = np.ones((32, 1), dtype=np.float32) | |||
np.save(path+"/label.npy", data_label) | |||
rectangle_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 2500, | |||
'sampler': 'uniform' | |||
}), | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 200, | |||
'sampler': 'uniform' | |||
}) | |||
}) | |||
rectangle_config1 = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 2500, | |||
'sampler': 'uniform' | |||
}), | |||
}) | |||
class Net(nn.Cell): | |||
"""net definition""" | |||
def __init__(self, input_dim, output_dim, hidden_layer=128, activation="sin"): | |||
super(Net, self).__init__() | |||
self.resblock = ResBlock(hidden_layer, hidden_layer, activation=activation) | |||
self.fc1 = LinearBlock(input_dim, hidden_layer) | |||
self.fc2 = LinearBlock(hidden_layer, output_dim) | |||
def construct(self, *inputs): | |||
x = inputs[0] | |||
out = self.fc1(x) | |||
out = self.resblock(out) | |||
out = self.fc2(out) | |||
return out | |||
class RectPde(Problem): | |||
"""rectangle pde problem""" | |||
def __init__(self, domain_name=None, bc_name=None, label_name=None, net=None): | |||
super(RectPde, self).__init__() | |||
self.domain_name = domain_name | |||
self.bc_name = bc_name | |||
self.label_name = label_name | |||
self.type = "Equation" | |||
self.jacobian = Grad(net) | |||
@ms_function | |||
def governing_equation(self, *output, **kwargs): | |||
u = output[0] | |||
data = kwargs[self.domain_name] | |||
u_x = self.jacobian(data, 0, 0, u) | |||
return u_x | |||
@ms_function | |||
def boundary_condition(self, *output, **kwargs): | |||
u = output[0] | |||
x = kwargs[self.bc_name][:, 0] | |||
y = kwargs[self.bc_name][:, 1] | |||
return u - ops.sin(x) * ops.cos(y) | |||
@ms_function | |||
def constraint_function(self, *output, **kwargs): | |||
u = output[0] | |||
label = kwargs[self.label_name] | |||
return u - label | |||
class RectPde1(Problem): | |||
"""rectangle pde problem with no boundary condition""" | |||
def __init__(self, domain_name, net): | |||
super(RectPde1, self).__init__() | |||
self.domain_name = domain_name | |||
self.type = "Equation" | |||
self.jacobian = Grad(net) | |||
@ms_function | |||
def governing_equation(self, *output, **kwargs): | |||
u = output[0] | |||
data = kwargs[self.domain_name] | |||
u_x = self.jacobian(data, 0, 0, u) | |||
return u_x | |||
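# Each Problem subclass above wires residual terms into the loss through
# @ms_function methods: governing_equation (the PDE residual, here u_x),
# boundary_condition (u - sin(x)*cos(y) on the boundary) and constraint_function
# (u - label for labeled data). The Constraints objects below map these onto
# dataset columns via prob_dict.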
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_netwithloss(): | |||
"""test netwithloss function""" | |||
model = Net(2, 1, 128, "sin") | |||
rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0], | |||
sampling_config=create_config_from_edict(rectangle_config)) | |||
geom_dict = {rect_space: ["domain", "BC"]} | |||
dataset = Dataset(geom_dict) | |||
dataset.create_dataset(batch_size=4, shuffle=True) | |||
prob_dict = {rect_space.name: RectPde(domain_name="rectangle_domain_points", bc_name="rectangle_BC_points", | |||
net=model)} | |||
train_constraints = Constraints(dataset, prob_dict) | |||
metrics = {'l2': L2(), 'distance': nn.MAE()} | |||
train_input_map = {'rectangle_domain': ['rectangle_domain_points'], 'rectangle_BC': ['rectangle_BC_points']} | |||
loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map) | |||
domain_points = Tensor(np.ones([32, 2]).astype(np.float32)) | |||
bc_points = Tensor(np.ones([32, 2]).astype(np.float32)) | |||
bc_normal = Tensor(np.ones([32, 2]).astype(np.float32)) | |||
out = loss_network(domain_points, bc_points, bc_normal) | |||
print(out) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_netwithloss1(): | |||
"""test netwithloss function1""" | |||
model = Net(2, 1, 128, "sin") | |||
rect_space = Rectangle("rectangle", coord_min=[-1.0, -1.0], coord_max=[1.0, 1.0], | |||
sampling_config=create_config_from_edict(rectangle_config1)) | |||
geom_dict = {rect_space: ["domain"]} | |||
dataset = Dataset(geom_dict) | |||
dataset.create_dataset(batch_size=4, shuffle=True) | |||
prob_dict = {rect_space.name: RectPde1(domain_name="rectangle_domain_points", net=model)} | |||
train_constraints = Constraints(dataset, prob_dict) | |||
metrics = {'l2': L2(), 'distance': nn.MAE()} | |||
train_input_map = {'rectangle_domain': ['rectangle_domain_points']} | |||
loss_network = NetWithLoss(model, train_constraints, metrics, train_input_map) | |||
domain_points = Tensor(np.ones([32, 2]).astype(np.float32)) | |||
out = loss_network(domain_points) | |||
print(out) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_netwitheval(): | |||
"""test netwitheval function""" | |||
model = Net(2, 1, 128, "sin") | |||
src_domain = ExistedDataConfig(name="src_domain", | |||
data_dir=[path+"/inputs.npy", path+"/label.npy"], | |||
columns_list=["inputs", "label"], | |||
data_format="npy", | |||
constraint_type="Label", | |||
random_merge=False) | |||
test_prob_dict = {src_domain.name: RectPde(domain_name=src_domain.name+"_inputs", | |||
label_name=src_domain.name+"_label", net=model)} | |||
test_dataset = Dataset(existed_data_list=[src_domain]) | |||
test_dataset.create_dataset(batch_size=4, shuffle=False) | |||
test_constraints = Constraints(test_dataset, test_prob_dict) | |||
metrics = {'l2': L2(), 'distance': nn.MAE()} | |||
loss_network = NetWithEval(model, test_constraints, metrics) | |||
data = Tensor(np.ones([32, 2]).astype(np.float32)) | |||
label = Tensor(np.ones([32, 1]).astype(np.float32)) | |||
out = loss_network(data, label) | |||
print(out) |
@@ -0,0 +1,30 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
network config settings, used in train.py and eval.py
""" | |||
# config | |||
config = { | |||
'base_channels': 8, | |||
'input_channels': 4, | |||
'epochs': 2000, | |||
'batch_size': 8, | |||
'save_epoch': 100, | |||
'lr': 0.01, | |||
'lr_decay_milestones': 5, | |||
'eval_interval': 20, | |||
'patch_shape': [25, 50, 25], | |||
} |
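# A minimal consumption sketch (mirrors how the training script below builds the
# model from these keys via `from src.config import config`):
#   model = EncoderDecoder(config["input_channels"], config["patch_shape"],
#                          config["base_channels"], decoding=True)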
@@ -0,0 +1,82 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""dataset generation and loading""" | |||
import numpy as np | |||
from mindspore.common import set_seed | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
np.random.seed(0) | |||
set_seed(0) | |||
PATCH_DIM = [25, 50, 25] | |||
NUM_SAMPLE = 10000 | |||
INPUT_PATH = "" | |||
DATA_CONFIG_PATH = "./data_config.npy" | |||
SAVE_DATA_PATH = "./" | |||
def generate_data(input_path): | |||
"""generate training data and data configuration""" | |||
space_temp = np.load(input_path) | |||
print("data load finish") | |||
print("random cropping...") | |||
space_data = np.ones((NUM_SAMPLE, | |||
PATCH_DIM[0], | |||
PATCH_DIM[1], | |||
PATCH_DIM[2], | |||
space_temp.shape[-1])).astype(np.float32) | |||
rand_pos = np.random.randint(low=0, high=(space_temp.shape[0] - PATCH_DIM[0])*\ | |||
(space_temp.shape[1] - PATCH_DIM[1])*\ | |||
(space_temp.shape[2] - PATCH_DIM[2]), size=NUM_SAMPLE) | |||
for i, pos in enumerate(rand_pos): | |||
z = pos % (space_temp.shape[2] - PATCH_DIM[2]) | |||
y = (pos // (space_temp.shape[2] - PATCH_DIM[2])) % (space_temp.shape[1] - PATCH_DIM[1]) | |||
x = (pos // (space_temp.shape[2] - PATCH_DIM[2])) // (space_temp.shape[1] - PATCH_DIM[1]) | |||
space_data[i] = space_temp[x : x+PATCH_DIM[0], y : y+PATCH_DIM[1], z : z+PATCH_DIM[2], :] | |||
print("random crop finished") | |||
space_data[:, :, :, :, 2] = np.log10(space_data[:, :, :, :, 2] + 1.0) | |||
data_config = np.ones(4) | |||
for i in range(4): | |||
data_config[i] = np.max(np.abs(space_data[:, :, :, :, i])) | |||
space_data[:, :, :, :, i] = space_data[:, :, :, :, i] / data_config[i] | |||
    # transpose to channel-first (NCDHW) before splitting, so the saved arrays
    # carry the layout expected by the Conv3d-based EncoderDecoder
    space_data = space_data.transpose((0, 4, 1, 2, 3))
    length = space_data.shape[0] // 10
    test_data = space_data[:length]
    train_data = space_data[length:]
np.save(DATA_CONFIG_PATH, data_config) | |||
np.save(SAVE_DATA_PATH+"/train_data.npy", train_data) | |||
np.save(SAVE_DATA_PATH+"/test_data.npy", test_data) | |||
print("AE train and test data and data config is saved") | |||
def create_dataset(input_path, label_path, batch_size=8, shuffle=True): | |||
electromagnetic = ExistedDataConfig(name="electromagnetic", | |||
data_dir=[input_path, label_path], | |||
columns_list=["inputs", "label"], | |||
data_format=input_path.split('.')[-1]) | |||
dataset = Dataset(existed_data_list=[electromagnetic]) | |||
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle) | |||
return data_loader | |||
if __name__ == "__main__": | |||
generate_data(INPUT_PATH) |
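# Usage sketch (assumes the .npy files written by generate_data above; iterating
# with create_dict_iterator is assumed to be available on the returned loader):
#   data_loader = create_dataset("./train_data.npy", "./train_data.npy", batch_size=8)
#   for batch in data_loader.create_dict_iterator():
#       inputs, label = batch["inputs"], batch["label"]  # label == inputs for an autoencoder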
@@ -0,0 +1,30 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""learning rate generator""" | |||
def step_lr_generator(step_size, epochs, lr, lr_decay_milestones): | |||
"""generate step decayed learnig rate""" | |||
total_steps = epochs * step_size | |||
milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)] | |||
milestones.append(total_steps) | |||
learning_rates = [lr*0.5**i for i in range(0, lr_decay_milestones - 1)] | |||
learning_rates.append(lr*0.5**(lr_decay_milestones - 1)) | |||
print("total_steps: %s, milestones: %s, learning_rates: %s " %(total_steps, milestones, learning_rates)) | |||
return milestones, learning_rates |
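# Worked example (hypothetical values): with step_size=100, epochs=10, lr=0.01 and
# lr_decay_milestones=5, total_steps is 1000 and the rate halves at each milestone:
#   milestones     = [200, 400, 600, 800, 1000]
#   learning_rates = [0.01, 0.005, 0.0025, 0.00125, 0.000625]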
@@ -0,0 +1,49 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""metrics""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.ops import functional as F | |||
class MyMSELoss(nn.LossBase): | |||
"""mse loss function""" | |||
def construct(self, base, target): | |||
bs, _, _, _, _ = F.shape(target) | |||
x = F.square(base - target) | |||
return 2*bs*self.get_loss(x) | |||
class EvalMetric(nn.Metric): | |||
"""eval metric""" | |||
def __init__(self, length): | |||
super(EvalMetric, self).__init__() | |||
self.clear() | |||
self.length = length | |||
def clear(self): | |||
self.error_sum_l2_error = 0 | |||
self.error_mean_l1_error = 0 | |||
def update(self, *inputs): | |||
test_predict = self._convert_data(inputs[0]) | |||
test_label = self._convert_data(inputs[1]) | |||
for i in range(len(test_label)): | |||
self.error_mean_l1_error += np.mean(np.abs(test_label[i] - test_predict[i])) | |||
def eval(self): | |||
return {'mean_l1_error': self.error_mean_l1_error / self.length} |
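# Notes on the classes above: MyMSELoss scales the reduced squared error by
# 2 * batch_size, while EvalMetric accumulates a per-sample mean L1 error in
# update() and averages over `length` in eval(). A minimal sketch (hypothetical
# numpy arrays of matching shape):
#   metric = EvalMetric(length=2)
#   metric.update(pred, label)  # pred/label: (2, C, D, H, W) numpy arrays
#   metric.eval()  # -> {'mean_l1_error': ...}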
@@ -0,0 +1,298 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""EncoderDecoder for MindElec.""" | |||
import mindspore.nn as nn | |||
import mindspore.ops as ops | |||
class EncoderDecoder(nn.Cell): | |||
""" | |||
EncoderDecoder architecture for MindElec. | |||
Args: | |||
input_dim (int): input channel. | |||
target_shape (list or tuple): Output DWH shape. | |||
        base_channels (int): base channel; all intermediate layers' channels are multiples of this value.
        decoding (bool): Enable the Decoder; set True if the reconstructed input is needed,
            for example, during training. Default: False.
    Returns:
        Tensor, output tensor: compressed encodings (decoding=False) or reconstructed input (decoding=True).
Examples: | |||
        >>> # training (reconstruct the input):
        >>> EncoderDecoder(input_dim=4, target_shape=[25, 50, 25], base_channels=8, decoding=True)
        >>> # applying the encoder (data compression):
        >>> EncoderDecoder(input_dim=4,
        ...                target_shape=[25, 50, 25],
        ...                base_channels=8,
        ...                decoding=False)
""" | |||
def __init__(self, input_dim, target_shape, base_channels=8, decoding=False): | |||
super(EncoderDecoder, self).__init__() | |||
self.decoding = decoding | |||
self.encoder = Encoder(input_dim, base_channels) | |||
if self.decoding: | |||
self.decoder = Decoder(input_dim, target_shape, base_channels) | |||
def construct(self, x): | |||
encoding = self.encoder(x) | |||
if self.decoding: | |||
output = self.decoder(encoding) | |||
else: | |||
output = encoding | |||
return output | |||
class Encoder(nn.Cell): | |||
""" | |||
Encoder architecture. | |||
Args: | |||
input_dim (int): input channel. | |||
        base_channels (int): base channel; all intermediate layers' channels are multiples of this value.
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
>>> Encoder(input_dim=4, base_channels=8) | |||
""" | |||
def __init__(self, input_dim, base_channels): | |||
super(Encoder, self).__init__() | |||
print("BASE_CHANNELS: %d" %base_channels) | |||
self.input_dim = input_dim | |||
self.channels = base_channels | |||
self.conv0 = nn.Conv3d(self.input_dim, self.channels*2, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv0_1 = nn.Conv3d(self.channels*2, self.channels*2, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv1 = nn.Conv3d(self.channels*2, self.channels*4, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv1_1 = nn.Conv3d(self.channels*4, self.channels*4, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv2 = nn.Conv3d(self.channels*4, self.channels*8, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv2_1 = nn.Conv3d(self.channels*8, self.channels*8, kernel_size=3, pad_mode='pad', padding=1) | |||
self.conv3 = nn.Conv3d(self.channels*8, self.channels*16, kernel_size=(2, 3, 2)) | |||
self.conv4 = nn.Conv2d(self.channels*16, self.channels*32, kernel_size=(1, 3), pad_mode='pad', padding=0) | |||
self.bn0 = nn.BatchNorm3d(self.channels*2) | |||
self.bn1 = nn.BatchNorm3d(self.channels*4) | |||
self.bn2 = nn.BatchNorm3d(self.channels*8) | |||
self.bn3 = nn.BatchNorm3d(self.channels*16) | |||
self.bn4 = nn.BatchNorm2d(self.channels*32) | |||
self.down1 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2)) | |||
self.down2 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2)) | |||
self.down3 = ops.MaxPool3D(kernel_size=(2, 2, 2), strides=(2, 2, 2)) | |||
self.down4 = nn.MaxPool2d(kernel_size=(2, 6), stride=(2, 6)) | |||
self.down_1_1 = ops.MaxPool3D(kernel_size=(4, 5, 4), strides=(4, 5, 4)) | |||
self.down_1 = nn.MaxPool2d(kernel_size=(3, 5*3)) | |||
self.down_2_1 = ops.MaxPool3D(kernel_size=(3, 4, 3), strides=(3, 4, 3)) | |||
self.down_2 = nn.MaxPool2d(kernel_size=(2, 3*2)) | |||
self.down_3 = nn.MaxPool2d(kernel_size=(3, 18)) | |||
self.act = nn.Sigmoid() | |||
self.concat = ops.Concat(axis=1) | |||
self.expand_dims = ops.ExpandDims() | |||
def construct(self, x): | |||
"""forward""" | |||
bs = x.shape[0] | |||
x = self.conv0(x) | |||
x = self.conv0_1(x) | |||
x = self.bn0(x) | |||
x = self.act(x) | |||
x = self.down1(x) | |||
x_1 = self.down_1_1(x) | |||
x_1 = self.down_1(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1)) | |||
x = self.conv1(x) | |||
x = self.conv1_1(x) | |||
x = self.bn1(x) | |||
x = self.act(x) | |||
x = self.down2(x) | |||
x_2 = self.down_2_1(x) | |||
x_2 = self.down_2(x_2.view(bs, x_2.shape[1], x_2.shape[2], -1)) | |||
x = self.conv2(x) | |||
x = self.conv2_1(x) | |||
x = self.bn2(x) | |||
x = self.act(x) | |||
x = self.down3(x) | |||
x_3 = self.down_3(x.view(bs, x.shape[1], x.shape[2], -1)) | |||
x = self.act(self.bn3(self.conv3(x))) | |||
x = x.view((bs, x.shape[1], x.shape[2], -1)) | |||
x = self.down4(x) | |||
x = self.act(self.bn4(self.conv4(x))) | |||
x = self.concat((x, x_1, x_2, x_3)) | |||
x = self.expand_dims(x, 3) | |||
return x | |||
class Decoder(nn.Cell): | |||
""" | |||
Decoder architecture. | |||
Args: | |||
output_dim (int): output channel. | |||
target_shape (list or tuple): Output DWH shape. | |||
        base_channels (int): base channel; all intermediate layers' channels are multiples of this value.
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
>>> Decoder(output_dim=4, target_shape=[25, 50, 25], base_channels=8) | |||
""" | |||
def __init__(self, output_dim, target_shape, base_channels): | |||
super(Decoder, self).__init__() | |||
self.output_dim = output_dim | |||
self.base_channels = base_channels | |||
self.up0 = Up((32 + 8 + 4 + 2) * self.base_channels, | |||
self.base_channels * 32, | |||
[1, 1, 1], | |||
[x // 8 for x in target_shape], | |||
pad=True) | |||
self.up1 = Up(self.base_channels * 32, | |||
self.base_channels * 16, | |||
[x // 8 for x in target_shape], | |||
[x // 4 for x in target_shape], | |||
pad=False) | |||
self.up2 = Up(self.base_channels * 16, | |||
self.base_channels* 4, | |||
[x // 4 for x in target_shape], | |||
[x // 2 for x in target_shape], | |||
pad=False) | |||
self.up3 = Up(self.base_channels * 4, | |||
self.output_dim, | |||
[x // 2 for x in target_shape], | |||
target_shape, | |||
pad=False) | |||
def construct(self, x): | |||
x = self.up0(x) | |||
x = self.up1(x) | |||
x = self.up2(x) | |||
x = self.up3(x) | |||
return x | |||
class DoubleConvTranspose(nn.Cell): | |||
""" | |||
DoubleConvTranspose architecture | |||
Args: | |||
input_dim (int): Input channel. | |||
        out_channels (int): Output channel.
        mid_channels (int): Intermediate channels; defaults to out_channels when None. Default: None.
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
>>> DoubleConvTranspose(input_dim=4, out_channels=8) | |||
""" | |||
def __init__(self, input_dim, out_channels, mid_channels=None): | |||
super(DoubleConvTranspose, self).__init__() | |||
if not mid_channels: | |||
mid_channels = out_channels | |||
        self.conv = nn.Conv3dTranspose(input_dim, mid_channels, kernel_size=3)
        self.conv1 = nn.Conv3dTranspose(mid_channels, out_channels, kernel_size=3)
self.bn = nn.BatchNorm3d(out_channels) | |||
self.relu = nn.ReLU() | |||
self.act = nn.Sigmoid() | |||
def construct(self, x): | |||
x = self.conv(x) | |||
x = self.conv1(x) | |||
x = self.bn(x) | |||
x = self.act(x) | |||
return x | |||
def calculate_difference(input_shape, target_shape): | |||
"""calculate difference of the target shape and the output shape of previous Conv3dTranspose | |||
and the according padding sizes""" | |||
target_shape = target_shape | |||
# calculating output shape of Conv3dTranspose | |||
input_shape = [(dim - 1)*2 - 2*0 + 1*(2 - 1) + 0 + 1 for dim in input_shape] | |||
diff = [x - y for x, y in zip(target_shape, input_shape)] | |||
paddings = [(0, 0), (0, 0)] | |||
for i in range(3): | |||
paddings.append((diff[i], 0)) | |||
return tuple(diff), tuple(paddings) | |||
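# Worked example: for input_shape=[12, 25, 12] (i.e. target_shape // 2 for [25, 50, 25]),
# the stride-2, kernel-2 Conv3dTranspose yields [24, 50, 24], so
# calculate_difference([12, 25, 12], [25, 50, 25]) returns
#   ((1, 0, 1), ((0, 0), (0, 0), (1, 0), (0, 0), (1, 0)))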
class Up(nn.Cell): | |||
""" | |||
Upscaling then apply Conv3dTranspose to decode | |||
Args: | |||
        input_dim (int): Input channel.
        out_channels (int): Output channel.
input_shape (list or tuple): Input DWH shape. Default: None. | |||
target_shape (list or tuple): Output DWH shape. Default: None. | |||
pad (bool): Enable manual padding, only needed in the first layer of the Decoder. Default: True. | |||
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
        >>> Up(8, 4, input_shape=[10, 10, 10], target_shape=[20, 20, 20], pad=False)
""" | |||
def __init__(self, input_dim, out_channels, input_shape=None, target_shape=None, pad=True): | |||
super(Up, self).__init__() | |||
        self.use_pad = pad
        diff, paddings = calculate_difference(input_shape, target_shape)
        if self.use_pad:
            self.pad = ops.Pad(paddings)
self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=2, stride=2, pad_mode='pad', padding=0) | |||
else: | |||
self.up = nn.Conv3dTranspose(input_dim, | |||
input_dim // 2, | |||
kernel_size=2, | |||
stride=2, | |||
pad_mode='pad', | |||
padding=0, | |||
output_padding=diff) | |||
self.conv = DoubleConvTranspose(input_dim // 2, out_channels) | |||
def construct(self, x): | |||
x = self.up(x) | |||
        if self.use_pad:
x = self.pad(x) | |||
x = self.conv(x) | |||
return x |
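# Shape sketch (inferred from the code above, for input_dim=4, target_shape=[25, 50, 25],
# base_channels=8): the Encoder concatenates four branches with
# (32 + 2 + 4 + 8) * base_channels = 368 channels, matching the
# (32 + 8 + 4 + 2) * base_channels input of the Decoder's up0, whose input_shape
# [1, 1, 1] implies a (N, 368, 1, 1, 1) compressed code between Encoder and Decoder.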
@@ -0,0 +1,137 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""ae_train""" | |||
import os | |||
import time | |||
import pytest | |||
import mindspore.nn as nn | |||
import mindspore.common.initializer as weight_init | |||
from mindspore import context | |||
from mindspore.common import set_seed | |||
from mindspore.train.callback import LossMonitor, Callback | |||
from mindelec.solver import Solver | |||
from src.dataset import create_dataset | |||
from src.model import EncoderDecoder | |||
from src.lr_generator import step_lr_generator | |||
from src.metric import MyMSELoss, EvalMetric | |||
from src.config import config | |||
train_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/train_data.npy" | |||
test_data_path = "/home/workspace/mindspore_dataset/mindelec_data/ae_data/test_data.npy" | |||
set_seed(0) | |||
print("pid:", os.getpid()) | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class TimeMonitor(Callback): | |||
""" | |||
Monitor the time in training. | |||
""" | |||
def __init__(self, data_size=None): | |||
super(TimeMonitor, self).__init__() | |||
self.data_size = data_size | |||
self.epoch_time = time.time() | |||
self.per_step_time = 0 | |||
self._tmp = None | |||
def epoch_begin(self, run_context): | |||
""" | |||
        Record time at the beginning of each epoch.
""" | |||
self.epoch_time = time.time() | |||
self._tmp = run_context | |||
def epoch_end(self, run_context): | |||
""" | |||
        Print the elapsed time at the end of each epoch.
""" | |||
epoch_seconds = (time.time() - self.epoch_time) * 1000 | |||
step_size = self.data_size | |||
cb_params = run_context.original_args() | |||
if hasattr(cb_params, "batch_num"): | |||
batch_num = cb_params.batch_num | |||
if isinstance(batch_num, int) and batch_num > 0: | |||
step_size = cb_params.batch_num | |||
self.per_step_time = epoch_seconds / step_size | |||
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True) | |||
def get_step_time(self,): | |||
return self.per_step_time | |||
def init_weight(net): | |||
for _, cell in net.cells_and_names(): | |||
if isinstance(cell, (nn.Conv3d, nn.Dense)): | |||
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(), | |||
cell.weight.shape, | |||
cell.weight.dtype)) | |||
@pytest.mark.level1 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_auto_encoder(): | |||
"""training""" | |||
model_net = EncoderDecoder(config["input_channels"], config["patch_shape"], config["base_channels"], decoding=True) | |||
init_weight(net=model_net) | |||
train_dataset = create_dataset(input_path=train_data_path, | |||
label_path=train_data_path, | |||
batch_size=config["batch_size"], | |||
shuffle=True) | |||
eval_dataset = create_dataset(input_path=test_data_path, | |||
label_path=test_data_path, | |||
batch_size=config["batch_size"], | |||
shuffle=False) | |||
step_size = train_dataset.get_dataset_size() | |||
milestones, learning_rates = step_lr_generator(step_size, | |||
config["epochs"], | |||
config["lr"], | |||
config["lr_decay_milestones"]) | |||
optimizer = nn.Adam(model_net.trainable_params(), | |||
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates)) | |||
loss_net = MyMSELoss() | |||
eval_step_size = eval_dataset.get_dataset_size() * config["batch_size"] | |||
evl_error_mrc = EvalMetric(eval_step_size) | |||
solver = Solver(model_net, | |||
train_input_map={'train': ['train_input_data']}, | |||
test_input_map={'test': ['test_input_data']}, | |||
optimizer=optimizer, | |||
metrics={'evl_mrc': evl_error_mrc,}, | |||
amp_level="O2", | |||
loss_fn=loss_net) | |||
time_cb = TimeMonitor() | |||
solver.model.train(20, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False) | |||
res = solver.model.eval(eval_dataset, dataset_sink_mode=False) | |||
per_step_time = time_cb.get_step_time() | |||
l1_error = res['evl_mrc']['mean_l1_error'] | |||
print('test_res:', f'l1_error: {l1_error:.10f} ') | |||
print(f'per step time: {per_step_time:.10f} ') | |||
assert l1_error <= 0.03 | |||
assert per_step_time <= 30 |
@@ -0,0 +1,124 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
callback functions
""" | |||
import time | |||
import numpy as np | |||
from mindspore.train.callback import Callback | |||
from mindspore import Tensor | |||
import mindspore.common.dtype as mstype | |||
class PredictCallback(Callback): | |||
""" | |||
Evaluate the model during training. | |||
Args: | |||
        model (Cell): A testing network.
        predict_interval (int): Specifies how many epochs to train between evaluations.
        input_data (Array): Input test dataset.
        label (Array): Label data.
        batch_size (int): Batch size for prediction. Default: 8192.
""" | |||
def __init__(self, model, predict_interval, input_data, label, batch_size=8192): | |||
super(PredictCallback, self).__init__() | |||
self.model = model | |||
self.input_data = input_data | |||
self.label = label | |||
self.predict_interval = predict_interval | |||
self.batch_size = min(batch_size, len(input_data)) | |||
self.l2_error = 1.0 | |||
print("check test dataset shape: {}, {}".format(self.input_data.shape, self.label.shape)) | |||
def epoch_end(self, run_context): | |||
""" | |||
Evaluate the model at the end of epoch. | |||
Args: | |||
run_context (RunContext): Context of the train running. | |||
""" | |||
cb_params = run_context.original_args() | |||
if cb_params.cur_epoch_num % self.predict_interval == 0: | |||
print("================================Start Evaluation================================") | |||
test_input_data = self.input_data.reshape(-1, 2) | |||
label = self.label.reshape(-1, 1) | |||
index = 0 | |||
prediction = np.zeros(label.shape) | |||
time_beg = time.time() | |||
while index < len(test_input_data): | |||
index_end = min(index + self.batch_size, len(test_input_data)) | |||
test_batch = Tensor(test_input_data[index: index_end, :], dtype=mstype.float32) | |||
predict = self.model(test_batch) | |||
predict = predict.asnumpy() | |||
prediction[index: index_end, :] = predict[:, :] | |||
index = index_end | |||
print("Total prediction time: {} s".format(time.time() - time_beg)) | |||
error = label - prediction | |||
l2_error = np.sqrt(np.sum(np.square(error[:, 0]))) / np.sqrt(np.sum(np.square(label[:, 0]))) | |||
print("l2_error: ", l2_error) | |||
print("=================================End Evaluation=================================") | |||
self.l2_error = l2_error | |||
def get_l2_error(self): | |||
return self.l2_error | |||
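# The callback above reports the relative L2 error over the full test set:
#   l2_error = ||label - prediction||_2 / ||label||_2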
class TimeMonitor(Callback): | |||
""" | |||
Monitor the time in training. | |||
Args: | |||
data_size (int): Iteration steps to run one epoch of the whole dataset. | |||
""" | |||
def __init__(self, data_size=None): | |||
super(TimeMonitor, self).__init__() | |||
self.data_size = data_size | |||
self.epoch_time = time.time() | |||
self.per_step_time = 0 | |||
def epoch_begin(self, run_context): | |||
""" | |||
Set begin time at the beginning of epoch. | |||
Args: | |||
run_context (RunContext): Context of the train running. | |||
""" | |||
run_context.original_args() | |||
self.epoch_time = time.time() | |||
def epoch_end(self, run_context): | |||
""" | |||
        Print the elapsed time at the end of each epoch.
""" | |||
epoch_seconds = (time.time() - self.epoch_time) * 1000 | |||
step_size = self.data_size | |||
cb_params = run_context.original_args() | |||
if hasattr(cb_params, "batch_num"): | |||
batch_num = cb_params.batch_num | |||
if isinstance(batch_num, int) and batch_num > 0: | |||
step_size = cb_params.batch_num | |||
self.per_step_time = epoch_seconds / step_size | |||
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True) | |||
def get_step_time(self): | |||
return self.per_step_time |
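
A minimal usage sketch (an assumption, not part of the source tree): wiring the two callbacks above into MindSpore's high-level training loop. The names ms_model, net, train_ds, test_x and test_y are hypothetical; test_x is assumed to hold 2-D coordinates, matching the reshape to (-1, 2) in PredictCallback.

time_cb = TimeMonitor(data_size=train_ds.get_dataset_size())
eval_cb = PredictCallback(model=net, predict_interval=10, input_data=test_x, label=test_y)
ms_model.train(epoch=100, train_dataset=train_ds, callbacks=[time_cb, eval_cb])
print("l2 error:", eval_cb.get_l2_error(), "ms/step:", time_cb.get_step_time())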
@@ -0,0 +1,44 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
network config setting, will be used in train.py and eval.py | |||
""" | |||
from easydict import EasyDict as ed | |||
# config | |||
rectangle_sampling_config = ed({ | |||
'domain': ed({ | |||
'random_sampling': False, | |||
'size': [100, 100], | |||
}), | |||
'BC': ed({ | |||
'random_sampling': True, | |||
'size': 128, | |||
'with_normal': False, | |||
}) | |||
}) | |||
# config | |||
helmholtz_2d_config = ed({ | |||
"name": "Helmholtz2D", | |||
"columns_list": ["input", "label"], | |||
"epochs": 10, | |||
"batch_size": 128, | |||
"lr": 0.001, | |||
"coord_min": [0.0, 0.0], | |||
"coord_max": [1.0, 1.0], | |||
"axis_size": 101, | |||
"wave_number": 2 | |||
}) |
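
For orientation, a hedged sketch of how these two dicts are consumed downstream (mirroring the usage in the Helmholtz test later in this dump):

from mindelec.geometry import Rectangle, create_config_from_edict

rect = Rectangle("rectangle",
                 coord_min=helmholtz_2d_config["coord_min"],
                 coord_max=helmholtz_2d_config["coord_max"],
                 sampling_config=create_config_from_edict(rectangle_sampling_config))
# 'domain': a regular 100 x 100 grid (random_sampling=False)
# 'BC': 128 randomly sampled boundary points, without normals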
@@ -0,0 +1,42 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
dataset | |||
""" | |||
import numpy as np | |||
# prepare test input and label | |||
def test_data_prepare(config): | |||
"""create test dataset""" | |||
coord_min = config["coord_min"] | |||
coord_max = config["coord_max"] | |||
axis_size = config["axis_size"] | |||
wave_number = config.get("wave_number", 2.0) | |||
# input | |||
axis_x = np.linspace(coord_min[0], coord_max[0], num=axis_size, endpoint=True) | |||
axis_y = np.linspace(coord_min[1], coord_max[1], num=axis_size, endpoint=True) | |||
mesh_x, mesh_y = np.meshgrid(axis_y, axis_x) | |||
input_data = np.hstack((mesh_x.flatten()[:, None], mesh_y.flatten()[:, None])).astype(np.float32) | |||
    # label: analytic solution u = sin(k * x); using axis_x[j] is valid because the
    # domain is square, so axis_x == axis_y matches the meshgrid flattening order
    label = np.zeros((axis_size, axis_size, 1))
    for i in range(axis_size):
        for j in range(axis_size):
            label[i, j, 0] = np.sin(wave_number * axis_x[j])
    label = label.reshape(-1, 1).astype(np.float32)
return input_data, label |
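
A quick, hedged sanity check of the shapes produced above (assuming the Helmholtz config from src/config.py is importable):

from src.config import helmholtz_2d_config

inputs, labels = test_data_prepare(helmholtz_2d_config)
# axis_size = 101, so the grid flattens to 101 * 101 = 10201 points
print(inputs.shape, labels.shape)  # (10201, 2) (10201, 1)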
@@ -0,0 +1,55 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
feedforward neural network | |||
""" | |||
import mindspore.nn as nn | |||
from mindelec.architecture import get_activation, LinearBlock | |||
class FFNN(nn.Cell): | |||
""" | |||
Full-connect networks. | |||
Args: | |||
input_dim (int): the input dimensions. | |||
output_dim (int): the output dimensions. | |||
hidden_layer (int): number of hidden layers. | |||
activation (str or Cell): activation functions. | |||
""" | |||
def __init__(self, input_dim, output_dim, hidden_layer=64, activation="sin"): | |||
super(FFNN, self).__init__() | |||
self.activation = get_activation(activation) | |||
self.fc1 = LinearBlock(input_dim, hidden_layer) | |||
self.fc2 = LinearBlock(hidden_layer, hidden_layer) | |||
self.fc3 = LinearBlock(hidden_layer, hidden_layer) | |||
self.fc4 = LinearBlock(hidden_layer, hidden_layer) | |||
self.fc5 = LinearBlock(hidden_layer, output_dim) | |||
def construct(self, *inputs): | |||
"""fc network""" | |||
x = inputs[0] | |||
out = self.fc1(x) | |||
out = self.activation(out) | |||
out = self.fc2(out) | |||
out = self.activation(out) | |||
out = self.fc3(out) | |||
out = self.activation(out) | |||
out = self.fc4(out) | |||
out = self.activation(out) | |||
out = self.fc5(out) | |||
return out |
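
A minimal smoke-test sketch for the network above (assumes a working MindSpore + mindelec installation):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

net = FFNN(input_dim=2, output_dim=1, hidden_layer=64)
x = Tensor(np.random.rand(16, 2).astype(np.float32), mstype.float32)
print(net(x).shape)  # expected: (16, 1)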
@@ -0,0 +1,144 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
train | |||
""" | |||
import os | |||
import pytest | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.ops as ops | |||
from mindspore import context, ms_function | |||
from mindspore.common import set_seed | |||
from mindspore.train.callback import LossMonitor | |||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
from mindelec.solver import Solver, Problem | |||
from mindelec.geometry import Rectangle, create_config_from_edict | |||
from mindelec.common import L2 | |||
from mindelec.data import Dataset | |||
from mindelec.operators import SecondOrderGrad as Hessian | |||
from mindelec.loss import Constraints | |||
from src.config import rectangle_sampling_config, helmholtz_2d_config | |||
from src.model import FFNN | |||
from src.dataset import test_data_prepare | |||
from src.callback import PredictCallback, TimeMonitor | |||
set_seed(0) | |||
np.random.seed(0) | |||
print("pid:", os.getpid()) | |||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend") | |||
# define problem | |||
class Helmholtz2D(Problem): | |||
"""2D Helmholtz equation""" | |||
def __init__(self, domain_name, bc_name, net, wavenumber=2): | |||
super(Helmholtz2D, self).__init__() | |||
self.domain_name = domain_name | |||
self.bc_name = bc_name | |||
self.type = "Equation" | |||
self.wave_number = wavenumber | |||
self.grad_xx = Hessian(net, input_idx1=0, input_idx2=0, output_idx=0) | |||
self.grad_yy = Hessian(net, input_idx1=1, input_idx2=1, output_idx=0) | |||
self.reshape = ops.Reshape() | |||
@ms_function | |||
    def governing_equation(self, *output, **kwargs):
        """governing equation: u_xx + u_yy + k^2 * u = 0"""
        u = output[0]
        domain_data = kwargs[self.domain_name]
        u_xx = self.grad_xx(domain_data)
        u_yy = self.grad_yy(domain_data)
        return u_xx + u_yy + self.wave_number**2 * u
@ms_function | |||
    def boundary_condition(self, *output, **kwargs):
        """boundary condition: u = sin(k * x) on the rectangle boundary"""
        u = output[0]
        x = kwargs[self.bc_name][:, 0]
        x = self.reshape(x, (-1, 1))
        test_label = ops.sin(self.wave_number * x)
        return 100 * (u - test_label)
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_frequency_domain_maxwell(): | |||
"""train process""" | |||
net = FFNN(input_dim=2, output_dim=1, hidden_layer=64) | |||
# define geometry | |||
geom_name = "rectangle" | |||
rect_space = Rectangle(geom_name, | |||
coord_min=helmholtz_2d_config["coord_min"], | |||
coord_max=helmholtz_2d_config["coord_max"], | |||
sampling_config=create_config_from_edict(rectangle_sampling_config)) | |||
geom_dict = {rect_space: ["domain", "BC"]} | |||
# create dataset for train and test | |||
train_dataset = Dataset(geom_dict) | |||
train_data = train_dataset.create_dataset(batch_size=helmholtz_2d_config.get("batch_size", 128), | |||
shuffle=True, drop_remainder=False) | |||
test_input, test_label = test_data_prepare(helmholtz_2d_config) | |||
# define problem and constraints | |||
    train_prob_dict = {geom_name: Helmholtz2D(domain_name=geom_name + "_domain_points",
                                              bc_name=geom_name + "_BC_points",
                                              net=net,
                                              wavenumber=helmholtz_2d_config.get("wave_number", 2)),
                       }
train_constraints = Constraints(train_dataset, train_prob_dict) | |||
# optimizer | |||
optim = nn.Adam(net.trainable_params(), learning_rate=helmholtz_2d_config.get("lr", 1e-4)) | |||
# solver | |||
solver = Solver(net, | |||
optimizer=optim, | |||
mode="PINNs", | |||
train_constraints=train_constraints, | |||
test_constraints=None, | |||
amp_level="O2", | |||
metrics={'l2': L2(), 'distance': nn.MAE()}, | |||
loss_scale_manager=DynamicLossScaleManager() | |||
) | |||
# train | |||
time_cb = TimeMonitor() | |||
loss_cb = PredictCallback(model=net, predict_interval=10, input_data=test_input, label=test_label) | |||
solver.train(epoch=helmholtz_2d_config.get("epochs", 10), | |||
train_dataset=train_data, | |||
callbacks=[time_cb, LossMonitor(), loss_cb]) | |||
per_step_time = time_cb.get_step_time() | |||
l2_error = loss_cb.get_l2_error() | |||
print(f'l2 error: {l2_error:.10f}') | |||
print(f'per step time: {per_step_time:.10f}') | |||
assert l2_error <= 0.05 | |||
assert per_step_time <= 10.0 |
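
As a worked check of the residual this test drives to zero: the target solution u(x, y) = sin(k*x) satisfies u_xx + u_yy + k^2 * u = 0 identically, since u_xx = -k^2 * sin(k*x) and u_yy = 0. A plain-numpy sketch, independent of MindSpore:

import numpy as np

k = 2.0  # wave_number from helmholtz_2d_config
x = np.linspace(0.0, 1.0, 101)
u = np.sin(k * x)
u_xx = -k**2 * np.sin(k * x)  # analytic second derivative; u_yy = 0
assert np.allclose(u_xx + 0.0 + k**2 * u, 0.0)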
@@ -0,0 +1,31 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
network config setting, will be used in train.py and eval.py | |||
""" | |||
from easydict import EasyDict as ed | |||
# config
config = ed({
    "epochs": 500,
    "batch_size": 8,
    "lr": 0.0001,
    "t_solution": 162,   # grid resolution along t
    "x_solution": 50,    # grid resolution along x
    "y_solution": 50,    # grid resolution along y
    "z_solution": 8,     # grid resolution along z
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 20
})
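
The *_solution entries are grid resolutions; a hedged cross-check against the fake-data shapes generated in src/sample.py later in this dump:

from src.config import config

# generate_data produces inputs of shape (162, 50, 50, 8, 37) and labels of shape (162, 50, 50, 8, 6)
expected = (config.t_solution, config.x_solution, config.y_solution, config.z_solution)
assert expected == (162, 50, 50, 8)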
@@ -0,0 +1,35 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
dataset | |||
""" | |||
import numpy as np | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
def create_dataset(data_path, batch_size=8, shuffle=True, drop_remainder=True, is_train=True): | |||
"""create dataset""" | |||
input_path = data_path + "inputs.npy" | |||
label_path = data_path + "label.npy" | |||
electromagnetic = ExistedDataConfig(name="electromagnetic", | |||
data_dir=[input_path, label_path], | |||
columns_list=["inputs", "label"], | |||
data_format="npy") | |||
dataset = Dataset(existed_data_list=[electromagnetic]) | |||
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle, drop_remainder=drop_remainder) | |||
scale = None | |||
if not is_train: | |||
scale = np.load(data_path+"scale.npy") | |||
return data_loader, scale |
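
Hedged usage sketch (data_path is assumed to end with a slash, as in the test driver later in this dump):

train_loader, _ = create_dataset("./train_data_em/", batch_size=8, shuffle=True)
test_loader, scale = create_dataset("./test_data_em/", batch_size=8,
                                    shuffle=False, drop_remainder=False, is_train=False)
# scale has shape (6, 162): one rescaling factor per field component and time step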
@@ -0,0 +1,80 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
loss | |||
""" | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.ops import functional as F | |||
from src.config import config | |||
class MyMSELoss(nn.LossBase):
    """MSE loss scaled by 2 * batch_size"""
    def construct(self, base, target):
        bs, _, _, _, _ = F.shape(target)
        x = F.square(base - target)
        return 2 * bs * self.get_loss(x)
class EvaLMetric(nn.Metric):
    """relative L2-error metric accumulated over eval batches"""
def __init__(self, length, scale, batch_size): | |||
super(EvaLMetric, self).__init__() | |||
self.clear() | |||
self.length = length | |||
self.batch_size = batch_size | |||
self.t = config.t_solution | |||
self.x = config.x_solution | |||
self.y = config.y_solution | |||
self.z = config.z_solution | |||
self.predict_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32) | |||
self.label_real = np.zeros((self.length*self.t, self.x, self.y, self.z, 6), dtype=np.float32) | |||
self.scale = scale | |||
self.iter_idx = 0 | |||
def clear(self): | |||
"""clear""" | |||
self.iter_idx = 0 | |||
def update(self, *inputs): | |||
"""update""" | |||
y_pred = self._convert_data(inputs[0]) | |||
y = self._convert_data(inputs[1]) | |||
predict, label = y_pred, y | |||
self.predict_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = predict | |||
self.label_real[self.iter_idx*self.batch_size: self.iter_idx*self.batch_size + label.shape[0]] = label | |||
self.iter_idx += 1 | |||
def eval(self): | |||
"""eval""" | |||
predict_real = np.reshape(self.predict_real, (self.length, self.t, self.x, self.y, self.z, 6)) | |||
label_real = np.reshape(self.label_real, (self.length, self.t, self.x, self.y, self.z, 6)) | |||
l2_time = 0.0 | |||
for i in range(self.length): | |||
predict_real_temp = predict_real[i:i+1] | |||
label_real_temp = label_real[i:i+1] | |||
for j in range(self.t): | |||
predict_real_temp[0, j, :, :, :, 0] = predict_real_temp[0, j, :, :, :, 0] * self.scale[0][j] | |||
predict_real_temp[0, j, :, :, :, 1] = predict_real_temp[0, j, :, :, :, 1] * self.scale[1][j] | |||
predict_real_temp[0, j, :, :, :, 2] = predict_real_temp[0, j, :, :, :, 2] * self.scale[2][j] | |||
predict_real_temp[0, j, :, :, :, 3] = predict_real_temp[0, j, :, :, :, 3] * self.scale[3][j] | |||
predict_real_temp[0, j, :, :, :, 4] = predict_real_temp[0, j, :, :, :, 4] * self.scale[4][j] | |||
predict_real_temp[0, j, :, :, :, 5] = predict_real_temp[0, j, :, :, :, 5] * self.scale[5][j] | |||
l2_time += (np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / | |||
np.sqrt(np.sum(np.square(label_real_temp)))) | |||
return {'l2_error': l2_time / self.length} |
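
In formula form, the metric above is the per-sample relative L2 error averaged over the N evaluation samples, after rescaling each prediction back to physical units:

$$\mathrm{l2\_error} = \frac{1}{N}\sum_{i=1}^{N}\frac{\lVert y_i - \hat{y}_i \rVert_2}{\lVert y_i \rVert_2}$$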
@@ -0,0 +1,176 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
maxwell 3d model | |||
""" | |||
import mindspore.nn as nn | |||
from mindspore.ops import operations as P | |||
import mindspore.ops.functional as F | |||
from mindelec.architecture import get_activation | |||
class Maxwell3D(nn.Cell): | |||
"""maxwell3d""" | |||
def __init__(self, output_dim): | |||
super(Maxwell3D, self).__init__() | |||
self.output_dim = output_dim | |||
width = 64 | |||
self.net0 = ModelHead(4, width) | |||
self.net1 = ModelHead(4, width) | |||
self.net2 = ModelHead(4, width) | |||
self.net3 = ModelHead(4, width) | |||
self.net4 = ModelHead(4, width) | |||
self.fc0 = nn.Dense(width+33, 128) | |||
self.net = ModelOut(128, output_dim, (2, 2, 1), (2, 2, 1)) | |||
self.cat = P.Concat(axis=-1) | |||
def construct(self, x): | |||
"""forward""" | |||
x_location = x[..., :4] | |||
x_media = x[..., 4:] | |||
out1 = self.net0(x_location) | |||
out2 = self.net1(2*x_location) | |||
out3 = self.net2(4*x_location) | |||
out4 = self.net3(8*x_location) | |||
out5 = self.net4(16.0*x_location) | |||
out = out1 + out2 + out3 + out4 + out5 | |||
out = self.cat((out, x_media)) | |||
out = self.fc0(out) | |||
out = self.net(out) | |||
return out | |||
class ModelHead(nn.Cell): | |||
"""model_head""" | |||
def __init__(self, input_dim, output_dim): | |||
super(ModelHead, self).__init__() | |||
self.output_dim = output_dim | |||
self.fc0 = nn.Dense(input_dim, output_dim) | |||
self.fc1 = nn.Dense(output_dim, output_dim) | |||
self.act0 = get_activation('srelu') | |||
self.act1 = get_activation('srelu') | |||
def construct(self, x): | |||
"""forward""" | |||
x = self.fc0(x) | |||
x = self.act0(x) | |||
x = self.fc1(x) | |||
x = self.act1(x) | |||
return x | |||
class ModelOut(nn.Cell): | |||
"""model out""" | |||
def __init__(self, input_dim, output_dim, kernel_size=2, strides=2): | |||
super(ModelOut, self).__init__() | |||
self.input_dim = input_dim | |||
self.output_dim = output_dim | |||
self.base_channels = 64 | |||
self.inc = DoubleConv(self.input_dim, self.base_channels) | |||
self.down1 = Down(self.base_channels, self.base_channels * 2, kernel_size, strides) | |||
self.down2 = Down(self.base_channels * 2, self.base_channels * 4, kernel_size, strides) | |||
self.down3 = Down(self.base_channels * 4, self.base_channels * 8, kernel_size, strides) | |||
self.down4 = Down(self.base_channels * 8, self.base_channels * 16, kernel_size, strides) | |||
self.up1 = Up(self.base_channels * 16, self.base_channels * 8, kernel_size, strides) | |||
self.up2 = Up(self.base_channels * 8, self.base_channels * 4, kernel_size, strides) | |||
self.up3 = Up(self.base_channels * 4, self.base_channels * 2, kernel_size, strides) | |||
self.up4 = Up(self.base_channels * 2, self.base_channels, kernel_size, strides) | |||
self.fc1 = nn.Dense(self.base_channels+128, 64) | |||
self.fc2 = nn.Dense(64, output_dim) | |||
self.relu = nn.ReLU() | |||
self.transpose = P.Transpose() | |||
self.cat = P.Concat(axis=1) | |||
def construct(self, x): | |||
"""forward""" | |||
x0 = self.transpose(x, (0, 4, 1, 2, 3)) | |||
x1 = self.inc(x0) | |||
x2 = self.down1(x1) | |||
x3 = self.down2(x2) | |||
x4 = self.down3(x3) | |||
x5 = self.down4(x4) | |||
x = self.up1(x5, x4) | |||
x = self.up2(x, x3) | |||
x = self.up3(x, x2) | |||
x = self.up4(x, x1) | |||
x = self.cat((x, x0)) | |||
x = self.transpose(x, (0, 2, 3, 4, 1)) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
return x | |||
class DoubleConv(nn.Cell): | |||
"""double conv""" | |||
def __init__(self, input_dim, out_channels, mid_channels=None): | |||
super().__init__() | |||
if not mid_channels: | |||
mid_channels = out_channels | |||
self.double_conv = nn.SequentialCell( | |||
nn.Conv3d(input_dim, mid_channels, kernel_size=3), | |||
nn.BatchNorm3d(mid_channels), | |||
nn.ReLU(), | |||
nn.Conv3d(mid_channels, out_channels, kernel_size=3), | |||
nn.BatchNorm3d(out_channels), | |||
nn.ReLU() | |||
) | |||
def construct(self, x): | |||
"""forward""" | |||
return self.double_conv(x) | |||
class Down(nn.Cell): | |||
"""down""" | |||
def __init__(self, input_dim, out_channels, kernel_size=2, strides=2): | |||
super().__init__() | |||
self.conv = DoubleConv(input_dim, out_channels) | |||
self.maxpool = P.MaxPool3D(kernel_size=kernel_size, strides=strides) | |||
def construct(self, x): | |||
"""forward""" | |||
x = self.maxpool(x) | |||
return self.conv(x) | |||
class Up(nn.Cell): | |||
"""up""" | |||
def __init__(self, input_dim, out_channels, kernel_size=2, strides=2): | |||
super().__init__() | |||
self.up = nn.Conv3dTranspose(input_dim, input_dim // 2, kernel_size=kernel_size, stride=strides) | |||
self.conv = DoubleConv(input_dim, out_channels) | |||
self.cat = P.Concat(axis=1) | |||
def construct(self, x1, x2): | |||
"""forward""" | |||
x1 = self.up(x1) | |||
_, _, h1, w1, c1 = F.shape(x1) | |||
_, _, h2, w2, c2 = F.shape(x2) | |||
diff_z = c2 - c1 | |||
diff_y = w2 - w1 | |||
diff_x = h2 - h1 | |||
x1 = P.Pad(((0, 0), (0, 0), (diff_x // 2, diff_x - diff_x // 2), (diff_y // 2, diff_y - diff_y // 2), | |||
(diff_z // 2, diff_z - diff_z // 2)))(x1) | |||
x = self.cat((x2, x1)) | |||
return self.conv(x) |
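
A hedged smoke-test sketch for the model above (assumes a backend where Conv3d is supported; the 37 input features, 4 location + 33 media, match the fake data in src/sample.py):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

net = Maxwell3D(output_dim=6)
x = Tensor(np.ones((2, 50, 50, 8, 37), dtype=np.float32), mstype.float32)
print(net(x).shape)  # expected: (2, 50, 50, 8, 6)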
@@ -0,0 +1,37 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
sample fake data for train and test | |||
""" | |||
import os | |||
import numpy as np | |||
def generate_data(train_path, test_path): | |||
"""generate fake data""" | |||
if not os.path.exists(train_path): | |||
os.makedirs(train_path) | |||
if not os.path.exists(test_path): | |||
os.makedirs(test_path) | |||
inputs = np.ones((162, 50, 50, 8, 37), dtype=np.float32) | |||
label = np.ones((162, 50, 50, 8, 6), dtype=np.float32) | |||
np.save(train_path + "inputs.npy", inputs) | |||
np.save(train_path + "label.npy", label) | |||
scale = np.ones((6, 162), dtype=np.float32) | |||
np.save(test_path + "inputs.npy", inputs) | |||
np.save(test_path + "label.npy", label) | |||
np.save(test_path + "scale.npy", scale) |
@@ -0,0 +1,152 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
train | |||
""" | |||
import os | |||
import time | |||
import numpy as np | |||
import pytest | |||
import mindspore.nn as nn | |||
import mindspore.common.initializer as weight_init | |||
from mindspore.common import set_seed | |||
from mindspore import Tensor | |||
from mindspore import context | |||
from mindspore.train.callback import Callback, LossMonitor | |||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
from mindelec.solver import Solver | |||
from src.dataset import create_dataset | |||
from src.loss import MyMSELoss, EvaLMetric | |||
from src.maxwell_model import Maxwell3D | |||
from src.config import config | |||
from src.sample import generate_data | |||
set_seed(0) | |||
np.random.seed(0) | |||
train_data_path = "./train_data_em/" | |||
test_data_path = "./test_data_em/" | |||
print("pid:", os.getpid()) | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
class TimeMonitor(Callback): | |||
""" | |||
Monitor the time in training. | |||
""" | |||
def __init__(self, data_size=None): | |||
super(TimeMonitor, self).__init__() | |||
self.data_size = data_size | |||
self.epoch_time = time.time() | |||
self.per_step_time = 0 | |||
    def epoch_begin(self, run_context):
        """
        Record the start time at the beginning of each epoch.
        """
        self.epoch_time = time.time()
def epoch_end(self, run_context): | |||
""" | |||
Print process cost time at the end of epoch. | |||
""" | |||
        epoch_ms = (time.time() - self.epoch_time) * 1000
        step_size = self.data_size
        cb_params = run_context.original_args()
        if hasattr(cb_params, "batch_num"):
            batch_num = cb_params.batch_num
            if isinstance(batch_num, int) and batch_num > 0:
                step_size = batch_num
        self.per_step_time = epoch_ms / step_size
        print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_ms, self.per_step_time), flush=True)
    def get_step_time(self):
return self.per_step_time | |||
def get_lr(lr_init, steps_per_epoch, total_epochs): | |||
"""get lr""" | |||
lr_each_step = [] | |||
total_steps = steps_per_epoch * total_epochs | |||
for i in range(total_steps): | |||
epoch = i // steps_per_epoch | |||
lr_local = lr_init | |||
if epoch <= 15: | |||
lr_local = lr_init | |||
elif epoch <= 45: | |||
lr_local = lr_init * 0.5 | |||
elif epoch <= 300: | |||
lr_local = lr_init * 0.25 | |||
elif epoch <= 600: | |||
lr_local = lr_init * 0.125 | |||
lr_each_step.append(lr_local) | |||
learning_rate = np.array(lr_each_step).astype(np.float32) | |||
print(learning_rate) | |||
return learning_rate | |||
def init_weight(net): | |||
"""init weight""" | |||
for _, cell in net.cells_and_names(): | |||
if isinstance(cell, (nn.Conv3d, nn.Dense)): | |||
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(), | |||
cell.weight.shape, | |||
cell.weight.dtype)) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_full_em(): | |||
"""train""" | |||
generate_data(train_data_path, test_data_path) | |||
train_dataset, _ = create_dataset(train_data_path, batch_size=config.batch_size, shuffle=True) | |||
test_dataset, config_scale = create_dataset(test_data_path, batch_size=config.batch_size, | |||
shuffle=False, drop_remainder=False, is_train=False) | |||
model_net = Maxwell3D(6) | |||
init_weight(net=model_net) | |||
train_step_size = train_dataset.get_dataset_size() | |||
lr = get_lr(config.lr, train_step_size, config.epochs) | |||
optimizer = nn.Adam(model_net.trainable_params(), learning_rate=Tensor(lr)) | |||
loss_net = MyMSELoss() | |||
loss_scale = DynamicLossScaleManager() | |||
    test_step_size = test_dataset.get_dataset_size()
    test_batch_size = test_dataset.get_batch_size()
    data_length = test_step_size * test_batch_size // config.t_solution
evl_error_mrc = EvaLMetric(data_length, config_scale, test_batch_size) | |||
solver = Solver(model_net, | |||
optimizer=optimizer, | |||
loss_scale_manager=loss_scale, | |||
amp_level="O2", | |||
keep_batchnorm_fp32=False, | |||
loss_fn=loss_net, | |||
metrics={"evl_mrc": evl_error_mrc}) | |||
time_cb = TimeMonitor() | |||
solver.model.train(5, train_dataset, callbacks=[LossMonitor(), time_cb], dataset_sink_mode=False) | |||
res = solver.model.eval(test_dataset, dataset_sink_mode=False) | |||
per_step_time = time_cb.get_step_time() | |||
l2_s11 = res['evl_mrc']['l2_error'] | |||
print('test_res:', f'l2_error: {l2_s11:.10f} ') | |||
print(f'per step time: {per_step_time:.10f} ') | |||
assert l2_s11 <= 0.05 | |||
assert per_step_time <= 150 |
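
A worked example of the staircase schedule in get_lr above (plain arithmetic; assumes the definitions in this file are in scope):

lr = get_lr(lr_init=1e-4, steps_per_epoch=2, total_epochs=20)
# epochs 0-15 -> 1e-4 (32 steps); epochs 16-19 -> 0.5 * 1e-4 (8 steps)
assert len(lr) == 40
assert np.allclose(lr[:32], 1e-4) and np.allclose(lr[32:], 5e-5)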
@@ -0,0 +1,44 @@ | |||
{ | |||
"Description" : [ "PINNs for solve Maxwell's equations" ], | |||
"Case" : "2D_Mur_Src_Gauss_Mscale_MTL_PIAD", | |||
"coord_min" : [0, 0], | |||
"coord_max" : [1, 1], | |||
"src_pos" : [0.4975, 0.4975], | |||
"SrcFrq": 1e+9, | |||
"range_t" : 4e-9, | |||
"input_center": [0.5, 0.5, 2.0e-9], | |||
"input_scale": [2.0, 2.0, 5.0e+8], | |||
"output_scale": [37.67303, 37.67303, 0.1], | |||
"src_radius": 0.01, | |||
"input_size" : 3, | |||
"output_size" : 3, | |||
"residual" : true, | |||
"num_scales" : 4, | |||
"layers" : 7, | |||
"neurons" : 64, | |||
"amp_factor" : 2, | |||
"scale_factor" : 2, | |||
"save_ckpt" : true, | |||
"load_ckpt" : false, | |||
"save_ckpt_path" : "./ckpt", | |||
"load_ckpt_path" : "", | |||
"train_with_eval": true, | |||
"test_data_path" : "./benchmark/", | |||
"lr" : 0.001, | |||
"milestones" : [2000], | |||
"lr_gamma" : 0.25, | |||
"train_epoch" : 50, | |||
"train_batch_size" : 1024, | |||
"test_batch_size" : 8192, | |||
"predict_interval" : 500, | |||
"vision_path" : "./vision", | |||
"summary_path" : "./summary", | |||
"EPS_candidates": [1, 3, 5], | |||
"MU_candidates": [1, 3, 5], | |||
"num_scenarios": 9, | |||
"latent_vector_size": 16, | |||
"latent_reg": 1.0, | |||
"latent_init_std": 1.0 | |||
} |
@@ -0,0 +1,28 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
init | |||
""" | |||
from .dataset import get_test_data, create_random_dataset | |||
from .maxwell import Maxwell2DMur | |||
from .lr_scheduler import MultiStepLR | |||
__all__ = [ | |||
"create_random_dataset", | |||
"get_test_data", | |||
"Maxwell2DMur", | |||
"MultiStepLR", | |||
] |
@@ -0,0 +1,57 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
create dataset | |||
""" | |||
import numpy as np | |||
from mindelec.data import Dataset | |||
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime | |||
from mindelec.geometry import create_config_from_edict | |||
from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config | |||
def get_test_data(test_data_path):
    """load labeled test data for evaluation"""
# check data | |||
paths = [test_data_path + '/input.npy', test_data_path + '/output.npy'] | |||
input_data = np.load(paths[0]) | |||
label_data = np.load(paths[1]) | |||
return input_data, label_data | |||
def create_random_dataset(config): | |||
"""create training dataset by online sampling""" | |||
radius = config["src_radius"] | |||
origin = config["src_pos"] | |||
disk = Disk("src", origin, radius) | |||
rect = Rectangle("rect", config["coord_min"], config["coord_max"]) | |||
diff = rect - disk | |||
interval = TimeDomain("time", 0.0, config["range_t"]) | |||
no_src = GeometryWithTime(diff, interval) | |||
no_src.set_name("no_src") | |||
no_src.set_sampling_config(create_config_from_edict(no_src_sampling_config)) | |||
src = GeometryWithTime(disk, interval) | |||
src.set_name("src") | |||
src.set_sampling_config(create_config_from_edict(src_sampling_config)) | |||
bc = GeometryWithTime(rect, interval) | |||
bc.set_name("bc") | |||
bc.set_sampling_config(create_config_from_edict(bc_sampling_config)) | |||
geom_dict = {src: ["domain", "IC"], | |||
no_src: ["domain", "IC"], | |||
bc: ["BC"]} | |||
dataset = Dataset(geom_dict) | |||
return dataset |
@@ -0,0 +1,73 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# =========================================================================== | |||
"""Learning rate scheduler.""" | |||
from collections import Counter | |||
import numpy as np | |||
class _LRScheduler: | |||
""" | |||
Basic class for learning rate scheduler | |||
""" | |||
def __init__(self, lr, max_epoch, steps_per_epoch): | |||
self.base_lr = lr | |||
self.steps_per_epoch = steps_per_epoch | |||
self.total_steps = int(max_epoch * steps_per_epoch) | |||
def get_lr(self): | |||
# Compute learning rate using chainable form of the scheduler | |||
raise NotImplementedError | |||
class MultiStepLR(_LRScheduler): | |||
""" | |||
Multi-step learning rate scheduler | |||
    Decays the learning rate by gamma once the epoch count reaches one of the milestones.
Args: | |||
lr (float): Initial learning rate which is the lower boundary in the cycle. | |||
milestones (list): List of epoch indices. Must be increasing. | |||
gamma (float): Multiplicative factor of learning rate decay. | |||
steps_per_epoch (int): The number of steps per epoch to train for. | |||
max_epoch (int): The number of epochs to train for. | |||
Outputs: | |||
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) | |||
    Example:
        >>> # Assuming an initial lr of 0.05
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 80
        >>> # lr = 0.0005  if epoch >= 80
        >>> scheduler = MultiStepLR(lr=0.05, milestones=[30, 80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
        >>> lr = scheduler.get_lr()
""" | |||
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch): | |||
self.milestones = Counter(milestones) | |||
self.gamma = gamma | |||
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch) | |||
    def get_lr(self):
        lr_each_step = []
        current_lr = self.base_lr
        for i in range(self.total_steps):
            cur_ep = i // self.steps_per_epoch
            if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
                current_lr = current_lr * self.gamma
            lr_each_step.append(current_lr)
        return np.array(lr_each_step).astype(np.float32)
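
A hedged, runnable check of the docstring example above:

scheduler = MultiStepLR(lr=0.05, milestones=[30, 80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
lr = scheduler.get_lr()
assert lr.shape == (450000,)
assert abs(lr[0] - 0.05) < 1e-8 and abs(lr[30 * 5000] - 0.005) < 1e-8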
@@ -0,0 +1,184 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
#pylint: disable=W0613 | |||
""" | |||
2D maxwell problem with Mur bc | |||
""" | |||
import numpy as np | |||
import mindspore.numpy as ms_np | |||
from mindspore import ms_function | |||
from mindspore import ops | |||
from mindspore import Tensor | |||
import mindspore.common.dtype as ms_type | |||
from mindelec.solver import Problem | |||
from mindelec.common import MU, EPS, PI | |||
from mindelec.operators import Grad | |||
class Maxwell2DMur(Problem): | |||
r""" | |||
The 2D Maxwell's equations with 2nd-order Mur absorbed boundary condition. | |||
Args: | |||
network (Cell): The solving network. | |||
config (dict): Setting information. | |||
domain_column (str): The corresponding column name of data which governed by maxwell's equation. | |||
bc_column (str): The corresponding column name of data which governed by boundary condition. | |||
bc_normal (str): The column name of normal direction vector corresponding to specified boundary. | |||
ic_column (str): The corresponding column name of data which governed by initial condition. | |||
""" | |||
def __init__(self, network, config, domain_column=None, bc_column=None, ic_column=None): | |||
super(Maxwell2DMur, self).__init__() | |||
self.domain_column = domain_column | |||
self.bc_column = bc_column | |||
self.ic_column = ic_column | |||
self.network = network | |||
        # operations
        self.gradient = Grad(self.network)
        self.reshape = ops.Reshape()
        self.cast = ops.Cast()
        self.mul = ops.Mul()
        self.split = ops.Split(1, 3)
        self.concat = ops.Concat(1)
        self.sqrt = ops.Sqrt()
# gauss-type pulse source | |||
self.pi = Tensor(PI, ms_type.float32) | |||
        self.src_frq = config.get("SrcFrq", 1e+9)
self.tau = Tensor((2.3 ** 0.5) / (PI * self.src_frq), ms_type.float32) | |||
self.amp = Tensor(1.0, ms_type.float32) | |||
self.t0 = Tensor(3.65 * self.tau, ms_type.float32) | |||
# src space | |||
self.src_x0 = Tensor(config["src_pos"][0], ms_type.float32) | |||
self.src_y0 = Tensor(config["src_pos"][1], ms_type.float32) | |||
self.src_sigma = Tensor(config["src_radius"] / 4.0, ms_type.float32) | |||
self.src_coord_min = config["coord_min"] | |||
self.src_coord_max = config["coord_max"] | |||
        # keys follow pretrain.json ("input_scale" / "output_scale")
        input_scales = config.get("input_scale", [1.0, 1.0, 2.5e+8])
        output_scales = config.get("output_scale", [37.67303, 37.67303, 0.1])
self.s_x = Tensor(input_scales[0], ms_type.float32) | |||
self.s_y = Tensor(input_scales[1], ms_type.float32) | |||
self.s_t = Tensor(input_scales[2], ms_type.float32) | |||
self.s_ex = Tensor(output_scales[0], ms_type.float32) | |||
self.s_ey = Tensor(output_scales[1], ms_type.float32) | |||
self.s_hz = Tensor(output_scales[2], ms_type.float32) | |||
# set up eps, mu candidates | |||
eps_candidates = np.array(config["eps_list"], dtype=np.float32) * EPS | |||
mu_candidates = np.array(config["mu_list"], dtype=np.float32) * MU | |||
self.epsilon_x = Tensor(eps_candidates, ms_type.float32).view((-1, 1)) | |||
self.epsilon_y = Tensor(eps_candidates, ms_type.float32).view((-1, 1)) | |||
self.mu_z = Tensor(mu_candidates, ms_type.float32).view((-1, 1)) | |||
        self.light_speed = 1.0 / self.sqrt(self.mul(self.epsilon_x, self.mu_z))
    def smooth_src(self, x, y, t):
        """Gaussian pulse in time, localized as a spatial Gaussian around the source position."""
        source = self.amp * ops.exp(- ((t - self.t0) / self.tau)**2)
        gauss = 1 / (2 * self.pi * self.src_sigma**2) * \
                ops.exp(- ((x - self.src_x0)**2 + (y - self.src_y0)**2) / (2 * (self.src_sigma**2)))
        return self.mul(source, gauss)
@ms_function | |||
def governing_equation(self, *output, **kwargs): | |||
"""maxwell equation of TE mode wave""" | |||
out = output[0] | |||
data = kwargs[self.domain_column] | |||
x = self.reshape(data[:, 0], (-1, 1)) | |||
y = self.reshape(data[:, 1], (-1, 1)) | |||
t = self.reshape(data[:, 2], (-1, 1)) | |||
dex_dxyt = self.gradient(data, None, 0, out) | |||
_, dex_dy, dex_dt = self.split(dex_dxyt) | |||
dey_dxyt = self.gradient(data, None, 1, out) | |||
dey_dx, _, dey_dt = self.split(dey_dxyt) | |||
dhz_dxyt = self.gradient(data, None, 2, out) | |||
dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt) | |||
dex_dy = self.cast(dex_dy, ms_type.float32) | |||
dex_dt = self.cast(dex_dt, ms_type.float32) | |||
dey_dx = self.cast(dey_dx, ms_type.float32) | |||
dey_dt = self.cast(dey_dt, ms_type.float32) | |||
dhz_dx = self.cast(dhz_dx, ms_type.float32) | |||
dhz_dy = self.cast(dhz_dy, ms_type.float32) | |||
dhz_dt = self.cast(dhz_dt, ms_type.float32) | |||
loss_a1 = (self.s_hz * dhz_dy) / (self.s_ex * self.s_t * self.epsilon_x) | |||
loss_a2 = dex_dt / self.s_t | |||
loss_b1 = -(self.s_hz * dhz_dx) / (self.s_ey * self.s_t * self.epsilon_y) | |||
loss_b2 = dey_dt / self.s_t | |||
loss_c1 = (self.s_ey * dey_dx - self.s_ex * dex_dy) / (self.s_hz * self.s_t * self.mu_z) | |||
loss_c2 = - dhz_dt / self.s_t | |||
source = self.smooth_src(x, y, t) / (self.s_hz * self.s_t * self.mu_z) | |||
pde_res1 = loss_a1 - loss_a2 | |||
pde_res2 = loss_b1 - loss_b2 | |||
pde_res3 = loss_c1 - loss_c2 - source | |||
pde_r = ops.Concat(1)((pde_res1, pde_res2, pde_res3)) | |||
return pde_r | |||
@ms_function | |||
def boundary_condition(self, *output, **kwargs): | |||
"""2nd-order mur boundary condition""" | |||
u = output[0] | |||
data = kwargs[self.bc_column] | |||
coord_min = self.src_coord_min | |||
coord_max = self.src_coord_max | |||
batch_size, _ = data.shape | |||
bc_attr = ms_np.zeros(shape=(batch_size, 4)) | |||
bc_attr[:, 0] = ms_np.where(ms_np.isclose(data[:, 0], coord_min[0]), 1.0, 0.0) | |||
bc_attr[:, 1] = ms_np.where(ms_np.isclose(data[:, 0], coord_max[0]), 1.0, 0.0) | |||
bc_attr[:, 2] = ms_np.where(ms_np.isclose(data[:, 1], coord_min[1]), 1.0, 0.0) | |||
bc_attr[:, 3] = ms_np.where(ms_np.isclose(data[:, 1], coord_max[1]), 1.0, 0.0) | |||
dex_dxyt = self.gradient(data, None, 0, u) | |||
_, dex_dy, _ = self.split(dex_dxyt) | |||
dey_dxyt = self.gradient(data, None, 1, u) | |||
dey_dx, _, _ = self.split(dey_dxyt) | |||
dhz_dxyt = self.gradient(data, None, 2, u) | |||
dhz_dx, dhz_dy, dhz_dt = self.split(dhz_dxyt) | |||
dex_dy = self.cast(dex_dy, ms_type.float32) | |||
dey_dx = self.cast(dey_dx, ms_type.float32) | |||
dhz_dx = self.cast(dhz_dx, ms_type.float32) | |||
dhz_dy = self.cast(dhz_dy, ms_type.float32) | |||
dhz_dt = self.cast(dhz_dt, ms_type.float32) | |||
bc_r1 = dhz_dx / self.s_x - dhz_dt / (self.light_speed * self.s_x) + \ | |||
self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=0 | |||
bc_r2 = dhz_dx / self.s_x + dhz_dt / (self.light_speed * self.s_x) - \ | |||
self.s_ex * self.light_speed * self.epsilon_x / (2 * self.s_hz * self.s_x) * dex_dy # x=L | |||
bc_r3 = dhz_dy / self.s_y - dhz_dt / (self.light_speed * self.s_y) - \ | |||
self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=0 | |||
bc_r4 = dhz_dy / self.s_y + dhz_dt / (self.light_speed * self.s_y) + \ | |||
self.s_ey * self.light_speed * self.epsilon_y / (2 * self.s_hz * self.s_y) * dey_dx # y=L | |||
bc_r_all = self.concat((bc_r1, bc_r2, bc_r3, bc_r4)) | |||
bc_r = self.mul(bc_r_all, bc_attr) | |||
return bc_r | |||
@ms_function | |||
def initial_condition(self, *output, **kwargs): | |||
"""initial condition: u = 0""" | |||
net_out = output[0] | |||
return net_out |
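
For reference, and up to the input/output scale factors (s_x, s_y, s_t, s_ex, s_ey, s_hz), the three residuals returned by governing_equation correspond to the TE-mode system

$$\epsilon\,\partial_t E_x = \partial_y H_z, \qquad \epsilon\,\partial_t E_y = -\partial_x H_z, \qquad \partial_t H_z = \frac{1}{\mu}\left(\partial_y E_x - \partial_x E_y\right) + J(x, y, t),$$

where J is the Gaussian pulse built in smooth_src. boundary_condition masks the four 2nd-order Mur residuals onto their respective edges of the unit square, and initial_condition enforces zero fields at t = 0.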
@@ -0,0 +1,52 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""sampling information""" | |||
import copy | |||
from easydict import EasyDict as edict | |||
src_sampling_config = edict({ | |||
'domain': edict({ | |||
'random_sampling': True, | |||
'size': 262144, | |||
'sampler': 'uniform' | |||
}), | |||
'IC': edict({ | |||
'random_sampling': True, | |||
'size': 262144, | |||
'sampler': 'uniform', | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 262144, | |||
'sampler': 'uniform', | |||
}), | |||
}) | |||
no_src_sampling_config = copy.deepcopy(src_sampling_config) | |||
bc_sampling_config = edict({ | |||
'BC': edict({ | |||
'random_sampling': True, | |||
'size': 262144, | |||
'sampler': 'uniform', | |||
'with_normal': False | |||
}), | |||
'time': edict({ | |||
'random_sampling': True, | |||
'size': 262144, | |||
'sampler': 'uniform', | |||
}), | |||
}) |
@@ -0,0 +1,192 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""reconstruct process.""" | |||
import os | |||
import json | |||
import math | |||
import pytest | |||
import numpy as np | |||
from mindspore.common import set_seed | |||
from mindspore import context, Tensor, nn, Parameter | |||
from mindspore.train import DynamicLossScaleManager | |||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
import mindspore.common.dtype as ms_type | |||
from mindspore.common.initializer import HeUniform | |||
from mindelec.loss import Constraints | |||
from mindelec.solver import Solver, LossAndTimeMonitor | |||
from mindelec.common import L2 | |||
from mindelec.architecture import MultiScaleFCCell, MTLWeightedLossCell | |||
from src.dataset import create_random_dataset | |||
from src.lr_scheduler import MultiStepLR | |||
from src.maxwell import Maxwell2DMur | |||
set_seed(123456) | |||
np.random.seed(123456) | |||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_incremental_learning(): | |||
"""pretraining process""" | |||
print("pid:", os.getpid()) | |||
mode = "pretrain" | |||
    with open("./pretrain.json") as config_file:
        config = json.load(config_file)
preprocess_config(config) | |||
elec_train_dataset = create_random_dataset(config) | |||
train_dataset = elec_train_dataset.create_dataset(batch_size=config["batch_size"], | |||
shuffle=True, | |||
prebatched_data=True, | |||
drop_remainder=True) | |||
epoch_steps = len(elec_train_dataset) | |||
print("check train dataset size: ", len(elec_train_dataset)) | |||
# load ckpt | |||
if config.get("load_ckpt", False): | |||
param_dict = load_checkpoint(config["load_ckpt_path"]) | |||
if mode == "pretrain": | |||
loaded_ckpt_dict = param_dict | |||
else: | |||
loaded_ckpt_dict = {} | |||
latent_vector_ckpt = 0 | |||
for name in param_dict: | |||
if name == "model.latent_vector": | |||
latent_vector_ckpt = param_dict[name].data.asnumpy() | |||
elif "network" in name and "moment" not in name: | |||
loaded_ckpt_dict[name] = param_dict[name] | |||
# initialize latent vector | |||
num_scenarios = config["num_scenarios"] | |||
latent_size = config["latent_vector_size"] | |||
if mode == "pretrain": | |||
latent_init = np.random.randn(num_scenarios, latent_size) / np.sqrt(latent_size) | |||
else: | |||
latent_norm = np.mean(np.linalg.norm(latent_vector_ckpt, axis=1)) | |||
print("check mean latent vector norm: ", latent_norm) | |||
latent_init = np.zeros((num_scenarios, latent_size)) | |||
latent_vector = Parameter(Tensor(latent_init, ms_type.float32), requires_grad=True) | |||
network = MultiScaleFCCell(config["input_size"], | |||
config["output_size"], | |||
layers=config["layers"], | |||
neurons=config["neurons"], | |||
residual=config["residual"], | |||
weight_init=HeUniform(negative_slope=math.sqrt(5)), | |||
act="sin", | |||
num_scales=config["num_scales"], | |||
amp_factor=config["amp_factor"], | |||
scale_factor=config["scale_factor"], | |||
input_scale=config["input_scale"], | |||
input_center=config["input_center"], | |||
latent_vector=latent_vector | |||
) | |||
network = network.to_float(ms_type.float16) | |||
network.input_scale.to_float(ms_type.float32) | |||
if config.get("enable_mtl", True): | |||
mtl_cell = MTLWeightedLossCell(num_losses=elec_train_dataset.num_dataset) | |||
else: | |||
mtl_cell = None | |||
# define problem | |||
train_prob = {} | |||
for dataset in elec_train_dataset.all_datasets: | |||
train_prob[dataset.name] = Maxwell2DMur(network=network, config=config, | |||
domain_column=dataset.name + "_points", | |||
ic_column=dataset.name + "_points", | |||
bc_column=dataset.name + "_points") | |||
print("check problem: ", train_prob) | |||
train_constraints = Constraints(elec_train_dataset, train_prob) | |||
# optimizer | |||
    if mode == "pretrain":
        params = network.trainable_params() + (mtl_cell.trainable_params() if mtl_cell else [])
        if config.get("load_ckpt", False):
            load_param_into_net(network, loaded_ckpt_dict)
            if mtl_cell:
                load_param_into_net(mtl_cell, loaded_ckpt_dict)
else: | |||
if config.get("finetune_model"): | |||
model_params = network.trainable_params() | |||
else: | |||
model_params = [param for param in network.trainable_params() | |||
if ("bias" not in param.name and "weight" not in param.name)] | |||
params = model_params + mtl_cell.trainable_params() if mtl_cell else model_params | |||
load_param_into_net(network, loaded_ckpt_dict) | |||
lr_scheduler = MultiStepLR(config["lr"], config["milestones"], config["lr_gamma"], | |||
epoch_steps, config["train_epoch"]) | |||
optimizer = nn.Adam(params, learning_rate=Tensor(lr_scheduler.get_lr())) | |||
# problem solver | |||
solver = Solver(network, | |||
optimizer=optimizer, | |||
mode="PINNs", | |||
train_constraints=train_constraints, | |||
test_constraints=None, | |||
metrics={'l2': L2(), 'distance': nn.MAE()}, | |||
loss_fn='smooth_l1_loss', | |||
loss_scale_manager=DynamicLossScaleManager(), | |||
mtl_weighted_cell=mtl_cell, | |||
latent_vector=latent_vector, | |||
latent_reg=config["latent_reg"] | |||
) | |||
loss_time_callback = LossAndTimeMonitor(epoch_steps) | |||
callbacks = [loss_time_callback] | |||
if config["save_ckpt"]: | |||
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=2) | |||
prefix = 'pretrain_maxwell_frq1e9' if mode == "pretrain" else 'reconstruct_maxwell_frq1e9' | |||
ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=config["save_ckpt_path"], config=config_ck) | |||
callbacks += [ckpoint_cb] | |||
solver.train(config["train_epoch"], train_dataset, callbacks=callbacks, dataset_sink_mode=True) | |||
assert loss_time_callback.get_loss() <= 2.0 | |||
assert loss_time_callback.get_step_time() <= 115.0 | |||
def preprocess_config(config): | |||
"""preprocess to get the coefficients of electromagnetic field for each scenario""" | |||
eps_candidates = config["EPS_candidates"] | |||
mu_candidates = config["MU_candidates"] | |||
config["num_scenarios"] = len(eps_candidates) * len(mu_candidates) | |||
batch_size_single_scenario = config["train_batch_size"] | |||
config["batch_size"] = batch_size_single_scenario * config["num_scenarios"] | |||
eps_list = [] | |||
for eps in eps_candidates: | |||
eps_list.extend([eps] * (batch_size_single_scenario * len(mu_candidates))) | |||
mu_list = [] | |||
for mu in mu_candidates: | |||
mu_list.extend([mu] * batch_size_single_scenario) | |||
mu_list = mu_list * (len(eps_candidates)) | |||
exp_name = "_" + config["Case"] + '_num_scenarios_' + str(config["num_scenarios"]) \ | |||
+ "_latent_reg_" + str(config["latent_reg"]) | |||
if config["save_ckpt"]: | |||
config["save_ckpt_path"] += exp_name | |||
config["vision_path"] += exp_name | |||
config["summary_path"] += exp_name | |||
print("check config: {}".format(config)) | |||
config["eps_list"] = eps_list | |||
config["mu_list"] = mu_list |
@@ -0,0 +1,130 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
dataset | |||
""" | |||
import os | |||
import shutil | |||
import numpy as np | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
def custom_normalize(data):
    """
    Normalize data to zero mean and unit variance along each feature.
    """
    print("Custom normalization is called")
ori_shape = data.shape | |||
data = data.reshape(ori_shape[0], -1) | |||
data = np.transpose(data) | |||
mean = np.mean(data, axis=1) | |||
data = data - mean[:, None] | |||
    std = np.std(data, axis=1)
    std += (np.abs(std) < 0.0000001)  # guard against division by zero for near-constant features
data = data / std[:, None] | |||
data = np.transpose(data) | |||
data = data.reshape(ori_shape) | |||
return data | |||
def create_dataset(opt):
    """
    Load raw inputs/labels, normalize them, split train/eval (9:1), and build data loaders.
    """
data_input_path = opt.input_path | |||
data_label_path = opt.label_path | |||
data_input = np.load(data_input_path) | |||
data_label = np.load(data_label_path) | |||
frequency = data_label[0, :, 0] | |||
data_label = data_label[:, :, 1] | |||
print(data_input.shape) | |||
print(data_label.shape) | |||
print("data load finish") | |||
data_input = custom_normalize(data_input) | |||
config_data_prepare = {} | |||
config_data_prepare["scale_input"] = 0.5 * np.max(np.abs(data_input), axis=0) | |||
config_data_prepare["scale_S11"] = 0.5 * np.max(np.abs(data_label)) | |||
data_input[:, :] = data_input[:, :] / config_data_prepare["scale_input"] | |||
data_label[:, :] = data_label[:, :] / config_data_prepare["scale_S11"] | |||
    # shuffle, then hold out 10% of the samples for evaluation
    permutation = np.random.permutation(data_input.shape[0])
    data_input = data_input[permutation]
    data_label = data_label[permutation]
    length = data_input.shape[0] // 10
    train_input, train_label = data_input[length:], data_label[length:]
    eval_input, eval_label = data_input[:length], data_label[:length]
print(np.shape(train_input)) | |||
print(np.shape(train_label)) | |||
print(np.shape(eval_input)) | |||
print(np.shape(eval_label)) | |||
if not os.path.exists('./data_prepare'): | |||
os.mkdir('./data_prepare') | |||
else: | |||
shutil.rmtree('./data_prepare') | |||
os.mkdir('./data_prepare') | |||
train_input = train_input.astype(np.float32) | |||
np.save('./data_prepare/train_input', train_input) | |||
train_label = train_label.astype(np.float32) | |||
np.save('./data_prepare/train_label', train_label) | |||
eval_input = eval_input.astype(np.float32) | |||
np.save('./data_prepare/eval_input', eval_input) | |||
eval_label = eval_label.astype(np.float32) | |||
np.save('./data_prepare/eval_label', eval_label) | |||
electromagnetic_train = ExistedDataConfig(name="electromagnetic_train", | |||
data_dir=['./data_prepare/train_input.npy', | |||
'./data_prepare/train_label.npy'], | |||
columns_list=["inputs", "label"], | |||
data_format="npy") | |||
electromagnetic_eval = ExistedDataConfig(name="electromagnetic_eval", | |||
data_dir=['./data_prepare/eval_input.npy', | |||
'./data_prepare/eval_label.npy'], | |||
columns_list=["inputs", "label"], | |||
data_format="npy") | |||
train_batch_size = opt.batch_size | |||
eval_batch_size = len(eval_input) | |||
train_dataset = Dataset(existed_data_list=[electromagnetic_train]) | |||
train_loader = train_dataset.create_dataset(batch_size=train_batch_size, shuffle=True) | |||
eval_dataset = Dataset(existed_data_list=[electromagnetic_eval]) | |||
eval_loader = eval_dataset.create_dataset(batch_size=eval_batch_size, shuffle=False) | |||
data = { | |||
"train_loader": train_loader, | |||
"eval_loader": eval_loader, | |||
"train_data": train_input, | |||
"train_label": train_label, | |||
"eval_data": eval_input, | |||
"eval_label": eval_label, | |||
"train_data_length": len(train_label), | |||
"eval_data_length": len(eval_label), | |||
"frequency": frequency, | |||
} | |||
return data, config_data_prepare |
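# Usage sketch (illustrative): create_dataset expects an options object exposing
# `input_path`, `label_path` and `batch_size`; the paths below are placeholders.
#
#   from easydict import EasyDict as edict
#   opt = edict({"input_path": "./dataset/data_input.npy",
#                "label_path": "./dataset/data_label.npy",
#                "batch_size": 8})
#   data, scales = create_dataset(opt)
#   for batch in data["train_loader"].create_dict_iterator():
#       print(batch["inputs"].shape, batch["label"].shape)
#       break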
@@ -0,0 +1,98 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
loss | |||
""" | |||
import os | |||
import shutil | |||
import mindspore.nn as nn | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
import cv2 | |||
class EvalMetric(nn.Metric): | |||
""" | |||
eval metric | |||
""" | |||
def __init__(self, scale_s11, length, frequency, show_pic_number, file_path): | |||
super(EvalMetric, self).__init__() | |||
self.clear() | |||
self.scale_s11 = scale_s11 | |||
self.length = length | |||
self.frequency = frequency | |||
self.show_pic_number = show_pic_number | |||
self.file_path = file_path | |||
self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False) | |||
def clear(self): | |||
""" | |||
clear error | |||
""" | |||
self.error_sum_l2_error = 0 | |||
self.error_sum_loss_error = 0 | |||
self.pic_res = None | |||
def update(self, *inputs): | |||
""" | |||
update error | |||
""" | |||
        # reset the output directory at the start of each evaluation pass
        if not os.path.exists(self.file_path):
            os.mkdir(self.file_path)
        else:
            shutil.rmtree(self.file_path)
            os.mkdir(self.file_path)
y_pred = self._convert_data(inputs[0]) | |||
y_label = self._convert_data(inputs[1]) | |||
test_predict, test_label = y_pred, y_label | |||
test_predict[:, :] = test_predict[:, :] * self.scale_s11 | |||
test_label[:, :] = test_label[:, :] * self.scale_s11 | |||
self.pic_res = [] | |||
for i in range(len(test_label)): | |||
predict_real_temp = test_predict[i] | |||
label_real_temp = test_label[i] | |||
l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \ | |||
np.sqrt(np.sum(np.square(label_real_temp))) | |||
self.error_sum_l2_error += l2_error_temp | |||
self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2) | |||
s11_label, s11_predict = label_real_temp, predict_real_temp | |||
plt.figure(dpi=250) | |||
plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2) | |||
plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1) | |||
plt.title('s11(dB)') | |||
plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10]) | |||
plt.ylabel('dB') | |||
plt.legend() | |||
plt.savefig(self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg') | |||
plt.close() | |||
if i in self.show_pic_id: | |||
self.pic_res.append(cv2.imread( | |||
self.file_path + '/' + str(i) + '_' + str(l2_error_temp)[:10] + '.jpg')) | |||
self.pic_res = np.array(self.pic_res).astype(np.float32) | |||
def eval(self): | |||
""" | |||
compute final error | |||
""" | |||
return {'l2_error': self.error_sum_l2_error / self.length, | |||
'loss_error': self.error_sum_loss_error / self.length, | |||
'pic_res': self.pic_res} |
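# Usage sketch (illustrative): nn.Metric subclasses follow the
# clear() -> update() -> eval() protocol and can be driven by hand.
#
#   import numpy as np
#   metric = EvalMetric(scale_s11=1.0, length=4, frequency=np.linspace(1, 10, 1001),
#                       show_pic_number=2, file_path='./eval_res')
#   metric.clear()
#   metric.update(np.ones((4, 1001), np.float32), np.ones((4, 1001), np.float32))
#   print(metric.eval())  # {'l2_error': 0.0, 'loss_error': 0.0, 'pic_res': ...}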
@@ -0,0 +1,47 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
maxwell S11 model | |||
""" | |||
import mindspore.nn as nn | |||
class S11Predictor(nn.Cell): | |||
""" | |||
maxwell S11 model define | |||
""" | |||
def __init__(self, input_dimension): | |||
super(S11Predictor, self).__init__() | |||
self.fc1 = nn.Dense(input_dimension, 128) | |||
self.fc2 = nn.Dense(128, 128) | |||
self.fc3 = nn.Dense(128, 128) | |||
self.fc4 = nn.Dense(128, 128) | |||
self.fc5 = nn.Dense(128, 128) | |||
self.fc6 = nn.Dense(128, 128) | |||
self.fc7 = nn.Dense(128, 1001) | |||
self.relu = nn.ReLU() | |||
def construct(self, x): | |||
"""forward""" | |||
x0 = x | |||
x1 = self.relu(self.fc1(x0)) | |||
x2 = self.relu(self.fc2(x1)) | |||
x3 = self.relu(self.fc3(x1 + x2)) | |||
x4 = self.relu(self.fc4(x1 + x2 + x3)) | |||
x5 = self.relu(self.fc5(x1 + x2 + x3 + x4)) | |||
x6 = self.relu(self.fc6(x1 + x2 + x3 + x4 + x5)) | |||
x = self.fc7(x1 + x2 + x3 + x4 + x5 + x6) | |||
return x |
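# Usage sketch (illustrative): the dense residual stack maps a low-dimensional
# parameter vector to a 1001-point S11 curve.
#
#   import numpy as np
#   from mindspore import Tensor
#   net = S11Predictor(input_dimension=3)
#   out = net(Tensor(np.ones((8, 3), np.float32)))
#   print(out.shape)  # (8, 1001)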
@@ -0,0 +1,166 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
""" | |||
test parameterization | |||
""" | |||
import time | |||
import pytest | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore.common import set_seed | |||
import mindspore.common.dtype as mstype | |||
from mindspore import context | |||
from mindspore.train.callback import Callback | |||
from easydict import EasyDict as edict | |||
from mindelec.solver import Solver | |||
from mindelec.vision import MonitorTrain, MonitorEval | |||
from src.dataset import create_dataset | |||
from src.maxwell_model import S11Predictor | |||
from src.loss import EvalMetric | |||
set_seed(123456) | |||
np.random.seed(123456) | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
opt = edict({ | |||
'epochs': 100, | |||
'print_interval': 50, | |||
'batch_size': 8, | |||
'save_epoch': 10, | |||
'lr': 0.0001, | |||
'input_dim': 3, | |||
'device_num': 1, | |||
'device_target': "Ascend", | |||
'checkpoint_dir': './ckpt/', | |||
'save_graphs_path': './graph_result/', | |||
'input_path': './dataset/Butterfly_antenna/data_input.npy', | |||
'label_path': './dataset/Butterfly_antenna/data_label.npy', | |||
}) | |||
class TimeMonitor(Callback): | |||
""" | |||
Monitor the time in training. | |||
""" | |||
def __init__(self, data_size=None): | |||
super(TimeMonitor, self).__init__() | |||
self.data_size = data_size | |||
self.epoch_time = time.time() | |||
self.per_step_time = 0 | |||
self.t = 0 | |||
    def epoch_begin(self, run_context):
        """
        Record time at the beginning of each epoch.
        """
        self.t = run_context  # run_context is unused here; kept to match the Callback API
        self.epoch_time = time.time()
    def epoch_end(self, run_context):
        """
        Compute the average per-step time at the end of each epoch.
        """
epoch_seconds = (time.time() - self.epoch_time) * 1000 | |||
step_size = self.data_size | |||
cb_params = run_context.original_args() | |||
if hasattr(cb_params, "batch_num"): | |||
batch_num = cb_params.batch_num | |||
if isinstance(batch_num, int) and batch_num > 0: | |||
step_size = cb_params.batch_num | |||
self.per_step_time = epoch_seconds / step_size | |||
    def get_step_time(self):
        return self.per_step_time
def get_lr(data): | |||
""" | |||
get_lr | |||
""" | |||
    num_milestones = 10
    # total iterations = epochs * ceil(train_data_length / batch_size)
    if data['train_data_length'] % opt.batch_size == 0:
        iter_number = int(data['train_data_length'] / opt.batch_size)
    else:
        iter_number = int(data['train_data_length'] / opt.batch_size) + 1
    iter_number = opt.epochs * iter_number
milestones = [int(iter_number * i / num_milestones) for i in range(1, num_milestones)] | |||
milestones.append(iter_number) | |||
learning_rates = [opt.lr * 0.5 ** i for i in range(0, num_milestones - 1)] | |||
learning_rates.append(opt.lr * 0.5 ** (num_milestones - 1)) | |||
return milestones, learning_rates | |||
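# Worked example (illustrative): with 104 training samples and batch_size=8,
# iter_number is ceil(104 / 8) * epochs = 13 * 100 = 1300 steps, split into ten
# plateaus whose learning rate halves from one plateau to the next:
#
#   milestones, learning_rates = get_lr({'train_data_length': 104})
#   # milestones     == [130, 260, ..., 1170, 1300]
#   # learning_rates == [1e-4, 5e-5, ..., 1e-4 * 0.5 ** 9]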
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_parameterization(): | |||
""" | |||
test parameterization | |||
""" | |||
data, config_data = create_dataset(opt) | |||
model_net = S11Predictor(opt.input_dim) | |||
model_net.to_float(mstype.float16) | |||
milestones, learning_rates = get_lr(data) | |||
optim = nn.Adam(model_net.trainable_params(), | |||
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates)) | |||
eval_error_mrc = EvalMetric(scale_s11=config_data["scale_S11"], | |||
length=data["eval_data_length"], | |||
frequency=data["frequency"], | |||
show_pic_number=4, | |||
file_path='./eval_res') | |||
solver = Solver(network=model_net, | |||
mode="Data", | |||
optimizer=optim, | |||
metrics={'eval_mrc': eval_error_mrc}, | |||
loss_fn=nn.MSELoss()) | |||
monitor_train = MonitorTrain(per_print_times=1, | |||
summary_dir='./summary_dir_train') | |||
monitor_eval = MonitorEval(summary_dir='./summary_dir_eval', | |||
model=solver, | |||
eval_ds=data["eval_loader"], | |||
eval_interval=opt.print_interval, | |||
draw_flag=True) | |||
time_monitor = TimeMonitor() | |||
callbacks_train = [monitor_train, time_monitor, monitor_eval] | |||
solver.model.train(epoch=opt.epochs, | |||
train_dataset=data["train_loader"], | |||
callbacks=callbacks_train, | |||
dataset_sink_mode=True) | |||
loss_print, l2_s11_print = monitor_eval.loss_final, monitor_eval.l2_s11_final | |||
per_step_time = time_monitor.get_step_time() | |||
print('loss_mse:', loss_print) | |||
print('l2_s11:', l2_s11_print) | |||
print('per_step_time:', per_step_time) | |||
assert l2_s11_print <= 1.0 | |||
assert per_step_time <= 10.0 |
@@ -0,0 +1,26 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
network config setting, will be used in train.py and eval.py | |||
""" | |||
# config | |||
config = { | |||
'input_channels': 496, | |||
'epochs': 20, | |||
'batch_size': 8, | |||
'lr': 0.0001, | |||
'lr_decay_milestones': 10, | |||
} |
@@ -0,0 +1,79 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""dataset utilities""" | |||
import os | |||
import numpy as np | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
INPUT_PATH = "" | |||
LABEL_PATH = "" | |||
DATA_CONFIG_PATH = "./data_config.npz" | |||
SAVE_DATA_PATH = "./" | |||
def custom_normalize(dataset, mean=None, std=None): | |||
""" custom normalization """ | |||
ori_shape = dataset.shape | |||
dataset = dataset.reshape(ori_shape[0], -1) | |||
dataset = np.transpose(dataset) | |||
if mean is None: | |||
mean = np.mean(dataset, axis=1) | |||
dataset = dataset - mean[:, None] | |||
if std is None: | |||
std = np.std(dataset, axis=1) | |||
    std += (np.abs(std) < 0.0000001)  # avoid division by zero for constant features
dataset = dataset / std[:, None] | |||
dataset = np.transpose(dataset) | |||
dataset = dataset.reshape(ori_shape) | |||
return dataset, mean, std | |||
def generate_data(input_path, label_path): | |||
"""generate dataset for s11 parameter prediction""" | |||
data_input = np.load(input_path) | |||
    mean, std = None, None
    if os.path.exists(DATA_CONFIG_PATH):
        data_config = np.load(DATA_CONFIG_PATH)
        mean = data_config["mean"]
        std = data_config["std"]
    # reuse saved statistics when available so repeated runs normalize consistently
    data_input, mean, std = custom_normalize(data_input, mean, std)
data_label = np.load(label_path) | |||
print(data_input.shape) | |||
print(data_label.shape) | |||
    data_input = data_input.transpose((0, 4, 1, 2, 3))  # move channels to axis 1 for Conv3d
    data_label[:, :] = np.log10(-data_label[:, :] + 1.0)  # S11(dB) <= 0, so the log argument stays >= 1
    scale_s11 = 0.5 * np.max(np.abs(data_label[:, :]))
    data_label[:, :] = data_label[:, :] / scale_s11
np.savez(DATA_CONFIG_PATH, scale_s11=scale_s11, mean=mean, std=std) | |||
np.save(os.path.join(SAVE_DATA_PATH, 'data_input.npy'), data_input) | |||
np.save(os.path.join(SAVE_DATA_PATH, 'data_label.npy'), data_label) | |||
print("data saved in target path") | |||
def create_dataset(input_path, label_path, batch_size=8, shuffle=True): | |||
electromagnetic = ExistedDataConfig(name="electromagnetic", | |||
data_dir=[input_path, label_path], | |||
columns_list=["inputs", "label"], | |||
data_format=input_path.split('.')[-1]) | |||
dataset = Dataset(existed_data_list=[electromagnetic]) | |||
data_loader = dataset.create_dataset(batch_size=batch_size, shuffle=shuffle) | |||
return data_loader | |||
if __name__ == "__main__": | |||
generate_data(input_path=INPUT_PATH, label_path=LABEL_PATH) |
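# Usage sketch (illustrative): generate_data() writes the normalized arrays next
# to this script, after which create_dataset() wraps them for training.
# INPUT_PATH and LABEL_PATH above are intentionally blank and must be set first.
#
#   generate_data(input_path="raw_input.npy", label_path="raw_label.npy")
#   loader = create_dataset("./data_input.npy", "./data_label.npy", batch_size=8)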
@@ -0,0 +1,29 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""learning rate generator""" | |||
def step_lr_generator(step_size, epochs, lr, lr_decay_milestones): | |||
"""generate step decayed learning rate""" | |||
total_steps = epochs * step_size | |||
milestones = [int(total_steps * i / lr_decay_milestones) for i in range(1, lr_decay_milestones)] | |||
milestones.append(total_steps) | |||
    learning_rates = [lr * 0.5 ** i for i in range(0, lr_decay_milestones - 1)]
    learning_rates.append(lr * 0.5 ** (lr_decay_milestones - 1))
    print("total_steps: %s, milestones: %s, learning_rates: %s" % (total_steps, milestones, learning_rates))
return milestones, learning_rates |
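# Worked example (illustrative): with step_size=10, epochs=20, lr=1e-4 and
# lr_decay_milestones=10, total_steps is 200 and the rate is halved from one
# plateau to the next:
#
#   milestones, rates = step_lr_generator(step_size=10, epochs=20, lr=1e-4,
#                                         lr_decay_milestones=10)
#   # milestones == [20, 40, ..., 180, 200]
#   # rates      == [1e-4, 5e-5, ..., 1e-4 * 0.5 ** 9]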
@@ -0,0 +1,108 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""metrics""" | |||
import os | |||
import shutil | |||
import mindspore.nn as nn | |||
from mindspore.ops import functional as F | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
import cv2 | |||
class MyMSELoss(nn.LossBase): | |||
"""mse loss function""" | |||
def construct(self, base, target): | |||
x = F.square(base - target) | |||
return self.get_loss(x) | |||
class EvalMetric(nn.Metric): | |||
""" | |||
eval metric | |||
""" | |||
def __init__(self, scale_s11, length, frequency, show_pic_number, file_path): | |||
super(EvalMetric, self).__init__() | |||
self.clear() | |||
self.scale_s11 = scale_s11 | |||
self.length = length | |||
self.frequency = frequency | |||
self.show_pic_number = show_pic_number | |||
self.file_path = file_path | |||
self.show_pic_id = np.random.choice(length, self.show_pic_number, replace=False) | |||
if not os.path.exists(self.file_path): | |||
os.mkdir(self.file_path) | |||
else: | |||
shutil.rmtree(self.file_path) | |||
os.mkdir(self.file_path) | |||
def clear(self): | |||
""" | |||
clear error | |||
""" | |||
self.error_sum_l2_error = 0 | |||
self.error_sum_loss_error = 0 | |||
self.pic_res = None | |||
self.index = 0 | |||
def update(self, *inputs): | |||
""" | |||
update error | |||
""" | |||
y_pred = self._convert_data(inputs[0]) | |||
y_label = self._convert_data(inputs[1]) | |||
test_predict, test_label = y_pred, y_label | |||
test_predict[:, :] = test_predict[:, :] * self.scale_s11 | |||
test_label[:, :] = test_label[:, :] * self.scale_s11 | |||
test_predict[:, :] = 1.0 - np.power(10, test_predict[:, :]) | |||
test_label[:, :] = 1.0 - np.power(10, test_label[:, :]) | |||
self.pic_res = [] | |||
for i in range(len(test_label)): | |||
self.index += 1 | |||
predict_real_temp = test_predict[i] | |||
label_real_temp = test_label[i] | |||
l2_error_temp = np.sqrt(np.sum(np.square(label_real_temp - predict_real_temp))) / \ | |||
np.sqrt(np.sum(np.square(label_real_temp))) | |||
self.error_sum_l2_error += l2_error_temp | |||
self.error_sum_loss_error += np.mean((label_real_temp - predict_real_temp) ** 2) | |||
s11_label, s11_predict = label_real_temp, predict_real_temp | |||
plt.figure(dpi=250) | |||
plt.plot(self.frequency, s11_predict, '-', label='AI Model', linewidth=2) | |||
plt.plot(self.frequency, s11_label, '--', label='CST', linewidth=1) | |||
plt.title('s11(dB)') | |||
plt.xlabel('frequency(GHz) l2_s11:' + str(l2_error_temp)[:10]) | |||
plt.ylabel('dB') | |||
plt.legend() | |||
plt.savefig(self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg') | |||
plt.close() | |||
            if i in self.show_pic_id:
                self.pic_res.append(cv2.imread(
                    self.file_path + '/' + str(self.index) + '_' + str(l2_error_temp)[:10] + '.jpg'))
self.pic_res = np.array(self.pic_res).astype(np.float32) | |||
def eval(self): | |||
""" | |||
compute final error | |||
""" | |||
return {'l2_error': self.error_sum_l2_error / self.length, | |||
'loss_error': self.error_sum_loss_error / self.length, | |||
'pic_res': self.pic_res} |
@@ -0,0 +1,89 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""S parameter prediction model""" | |||
import mindspore.nn as nn | |||
import mindspore.ops as ops | |||
class S11Predictor(nn.Cell): | |||
""" | |||
S11Predictor architecture for MindElec. | |||
Args: | |||
input_dim (int): input channel. | |||
Returns: | |||
Tensor, output tensor. | |||
Examples: | |||
>>> S11Predictor(input_dim=128) | |||
""" | |||
def __init__(self, input_dim): | |||
super(S11Predictor, self).__init__() | |||
# input shape is [20, 40, 3] | |||
self.conv1 = nn.Conv3d(input_dim, 512, kernel_size=(3, 3, 1)) | |||
self.conv2 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1)) | |||
self.conv3 = nn.Conv3d(512, 512, kernel_size=(3, 3, 1)) | |||
self.conv4 = nn.Conv3d(512, 512, kernel_size=(2, 1, 3), pad_mode='pad', padding=0) | |||
self.down1 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1)) | |||
self.down2 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1)) | |||
self.down3 = ops.MaxPool3D(kernel_size=(2, 3, 1), strides=(2, 3, 1)) | |||
self.down_1_1 = ops.MaxPool3D(kernel_size=(1, 13, 1), strides=(1, 13, 1)) | |||
self.down_1_2 = nn.MaxPool2d(kernel_size=(10, 3)) | |||
self.down_2 = nn.MaxPool2d((5, 4*3)) | |||
self.fc1 = nn.Dense(1536, 2048) | |||
self.fc2 = nn.Dense(2048, 2048) | |||
self.fc3 = nn.Dense(2048, 1001) | |||
self.concat = ops.Concat(axis=1) | |||
self.relu = nn.ReLU() | |||
def construct(self, x): | |||
"""forward""" | |||
bs = x.shape[0] | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.down1(x) | |||
x_1 = self.down_1_1(x) | |||
x_1 = self.down_1_2(x_1.view(bs, x_1.shape[1], x_1.shape[2], -1)).view((bs, -1)) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.down2(x) | |||
x_2 = self.down_2(x.view(bs, x.shape[1], x.shape[2], -1)).view((bs, -1)) | |||
x = self.conv3(x) | |||
x = self.relu(x) | |||
x = self.down3(x) | |||
x = self.conv4(x) | |||
x = self.relu(x).view((bs, -1)) | |||
x = self.concat([x, x_1, x_2]) | |||
x = self.relu(x).view(bs, -1) | |||
x = self.relu(self.fc1(x)) | |||
x = self.relu(self.fc2(x)) | |||
x = self.fc3(x) | |||
return x |
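# Usage sketch (illustrative): the network consumes a channel-first 3-D volume
# (cf. the transpose in generate_data) and emits a 1001-point S11 curve.
#
#   import numpy as np
#   from mindspore import Tensor
#   net = S11Predictor(input_dim=496)
#   out = net(Tensor(np.ones((2, 496, 20, 40, 3), np.float32)))
#   print(out.shape)  # (2, 1001)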
@@ -0,0 +1,144 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""s parameter prediction model training test""" | |||
import time | |||
import pytest | |||
import numpy as np | |||
import mindspore.nn as nn | |||
from mindspore import context | |||
from mindspore.common import set_seed | |||
import mindspore.common.initializer as weight_init | |||
from mindspore.train.callback import LossMonitor, Callback | |||
from mindelec.solver import Solver | |||
from src.dataset import create_dataset | |||
from src.model import S11Predictor | |||
from src.lr_generator import step_lr_generator | |||
from src.config import config | |||
from src.metric import MyMSELoss, EvalMetric | |||
set_seed(0) | |||
np.random.seed(0) | |||
context.set_context(mode=context.GRAPH_MODE, | |||
save_graphs=False, | |||
device_target="Ascend") | |||
class TimeMonitor(Callback): | |||
""" | |||
Monitor the time in training. | |||
""" | |||
def __init__(self, data_size=None): | |||
super(TimeMonitor, self).__init__() | |||
self.data_size = data_size | |||
self.epoch_time = time.time() | |||
self.per_step_time = 0 | |||
self._tmp = None | |||
    def epoch_begin(self, run_context):
        """
        Record time at the beginning of each epoch.
        """
        self.epoch_time = time.time()
        self._tmp = run_context  # run_context is unused here; kept to match the Callback API
    def epoch_end(self, run_context):
        """
        Print the epoch time and per-step time at the end of each epoch.
        """
epoch_seconds = (time.time() - self.epoch_time) * 1000 | |||
step_size = self.data_size | |||
cb_params = run_context.original_args() | |||
if hasattr(cb_params, "batch_num"): | |||
batch_num = cb_params.batch_num | |||
if isinstance(batch_num, int) and batch_num > 0: | |||
step_size = cb_params.batch_num | |||
self.per_step_time = epoch_seconds / step_size | |||
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, self.per_step_time), flush=True) | |||
    def get_step_time(self):
        return self.per_step_time
def init_weight(net): | |||
"""init_weight""" | |||
for _, cell in net.cells_and_names(): | |||
if isinstance(cell, (nn.Conv3d, nn.Dense)): | |||
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(), | |||
cell.weight.shape, | |||
cell.weight.dtype)) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.env_onecard | |||
def test_s_predictor_train(): | |||
"""training""" | |||
train_input_path = "input.npy" | |||
train_label_path = "label.npy" | |||
train_input = np.ones((100, 496, 20, 40, 3), np.float32) | |||
train_label = np.ones((100, 1001), np.float32) | |||
np.save(train_input_path, train_input) | |||
np.save(train_label_path, train_label) | |||
model_net = S11Predictor(input_dim=config["input_channels"]) | |||
init_weight(net=model_net) | |||
train_dataset = create_dataset(input_path=train_input_path, | |||
label_path=train_label_path, | |||
batch_size=config["batch_size"], | |||
shuffle=True) | |||
step_size = train_dataset.get_dataset_size() | |||
milestones, learning_rates = step_lr_generator(step_size, | |||
config["epochs"], | |||
config["lr"], | |||
config["lr_decay_milestones"]) | |||
optimizer = nn.Adam(model_net.trainable_params(), | |||
learning_rate=nn.piecewise_constant_lr(milestones, learning_rates)) | |||
loss_net = MyMSELoss() | |||
eval_step_size = train_dataset.get_dataset_size() * config["batch_size"] | |||
evl_error_mrc = EvalMetric(scale_s11=1, | |||
length=eval_step_size, | |||
frequency=np.linspace(0, 4*10**8, 1001), | |||
show_pic_number=4, | |||
file_path='./eval_result') | |||
solver = Solver(model_net, | |||
train_input_map={'train': ['train_input_data']}, | |||
test_input_map={'test': ['test_input_data']}, | |||
optimizer=optimizer, | |||
metrics={'evl_mrc': evl_error_mrc,}, | |||
amp_level="O2", | |||
loss_fn=loss_net) | |||
time_cb = TimeMonitor() | |||
solver.model.train(config["epochs"], | |||
train_dataset, | |||
callbacks=[LossMonitor(), time_cb], | |||
dataset_sink_mode=False) | |||
res = solver.model.eval(train_dataset, dataset_sink_mode=False) | |||
per_step_time = time_cb.get_step_time() | |||
    l2_error = res['evl_mrc']['l2_error']
    print('test_res:', f'l2_error: {l2_error:.10f} ')
    print(f'per step time: {per_step_time:.10f} ')
    assert l2_error <= 0.01
assert per_step_time <= 100 |
@@ -0,0 +1,37 @@ | |||
{ | |||
"Description" : [ "PINNs for solve Maxwell's equations" ], | |||
"Case" : "2D_Mur_Src_Gauss_Mscale_MTL", | |||
"random_sampling" : true, | |||
"coord_min" : [0, 0], | |||
"coord_max" : [1, 1], | |||
"src_pos" : [0.4975, 0.4975], | |||
"src_frq": 1e+9, | |||
"range_t" : 4e-9, | |||
"input_scale": [1.0, 1.0, 2.5e+8], | |||
"output_scale": [37.67303, 37.67303, 0.1], | |||
"src_radius": 0.01, | |||
"input_size" : 3, | |||
"output_size" : 3, | |||
"residual" : true, | |||
"num_scales" : 4, | |||
"layers" : 7, | |||
"neurons" : 64, | |||
"amp_factor" : 10.0, | |||
"scale_factor" : 2.0, | |||
"save_ckpt" : true, | |||
"load_ckpt" : false, | |||
"save_ckpt_path" : "./ckpt", | |||
"load_ckpt_path" : "", | |||
"train_data_path" : "./input/", | |||
"test_data_path" : "./input/benchmark/", | |||
"lr" : 0.002, | |||
"milestones" : [300], | |||
"lr_gamma" : 0.1, | |||
"train_epoch" : 300, | |||
"train_batch_size" : 8192, | |||
"test_batch_size" : 32768, | |||
"predict_interval" : 6000, | |||
"vision_path" : "./vision", | |||
"summary_path" : "./summary" | |||
} |
@@ -0,0 +1,32 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
init | |||
""" | |||
from .dataset import create_train_dataset, get_test_data, create_random_dataset | |||
from .maxwell import Maxwell2DMur | |||
from .lr_scheduler import MultiStepLR | |||
from .callback import PredictCallback | |||
from .utils import visual_result | |||
__all__ = [ | |||
"create_train_dataset", | |||
"create_random_dataset", | |||
"get_test_data", | |||
"Maxwell2DMur", | |||
"MultiStepLR", | |||
"PredictCallback", | |||
"visual_result" | |||
] |
@@ -0,0 +1,127 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
call back functions | |||
""" | |||
import time | |||
import copy | |||
import numpy as np | |||
from mindspore.train.callback import Callback | |||
from mindspore.train.summary import SummaryRecord | |||
from mindspore import Tensor | |||
import mindspore.common.dtype as mstype | |||
class PredictCallback(Callback): | |||
""" | |||
Monitor the prediction accuracy in training. | |||
Args: | |||
model (Cell): Prediction network cell. | |||
inputs (Array): Input data of prediction. | |||
label (Array): Label data of prediction. | |||
config (dict): config info of prediction. | |||
visual_fn (dict): Visualization function. Default: None. | |||
""" | |||
def __init__(self, model, inputs, label, config, visual_fn=None): | |||
super(PredictCallback, self).__init__() | |||
self.model = model | |||
self.inputs = inputs | |||
self.label = label | |||
self.label_shape = label.shape | |||
self.visual_fn = visual_fn | |||
self.vision_path = config.get("vision_path", "./vision") | |||
self.summary_dir = config.get("summary_path", "./summary") | |||
self.output_size = config.get("output_size", 3) | |||
self.input_size = config.get("input_size", 3) | |||
self.output_scale = np.array(config["output_scale"], dtype=np.float32) | |||
self.predict_interval = config.get("predict_interval", 10) | |||
self.batch_size = config.get("test_batch_size", 8192*4) | |||
self.dx = inputs[0, 1, 0, 0] - inputs[0, 0, 0, 0] | |||
self.dy = inputs[0, 0, 1, 1] - inputs[0, 0, 0, 1] | |||
self.dt = inputs[1, 0, 0, 2] - inputs[0, 0, 0, 2] | |||
print("check yee delta: {}, {}, {}".format(self.dx, self.dy, self.dt)) | |||
self.ex_inputs = copy.deepcopy(inputs) | |||
self.ey_inputs = copy.deepcopy(inputs) | |||
self.hz_inputs = copy.deepcopy(inputs) | |||
self.ex_inputs = self.ex_inputs.reshape(-1, self.input_size) | |||
self.ey_inputs = self.ey_inputs.reshape(-1, self.input_size) | |||
self.hz_inputs = self.hz_inputs.reshape(-1, self.input_size) | |||
        # Yee-grid staggering: E-field components live at half-step spatial/temporal offsets
        self.ex_inputs[:, 1] += self.dy / 2.0
        self.ex_inputs[:, 2] += self.dt / 2.0
        self.ey_inputs[:, 0] += self.dx / 2.0
        self.ey_inputs[:, 2] += self.dt / 2.0
self.inputs_each = [self.ex_inputs, self.ey_inputs, self.hz_inputs] | |||
self._step_counter = 0 | |||
self.l2_error = (1.0, 1.0, 1.0) | |||
def __enter__(self): | |||
self.summary_record = SummaryRecord(self.summary_dir) | |||
return self | |||
def __exit__(self, *exc_args): | |||
self.summary_record.close() | |||
def epoch_end(self, run_context): | |||
""" | |||
Evaluate the model at the end of epoch. | |||
Args: | |||
run_context (RunContext): Context of the train running. | |||
""" | |||
cb_params = run_context.original_args() | |||
if cb_params.cur_epoch_num % self.predict_interval == 0: | |||
# predict each quantity | |||
index = 0 | |||
prediction_each = np.zeros(self.label_shape) | |||
prediction_each = prediction_each.reshape((-1, self.output_size)) | |||
time_beg = time.time() | |||
while index < len(self.inputs_each[0]): | |||
index_end = min(index + self.batch_size, len(self.inputs_each[0])) | |||
for i in range(self.output_size): | |||
test_batch = Tensor(self.inputs_each[i][index: index_end, :], mstype.float32) | |||
predict = self.model(test_batch) | |||
predict = predict.asnumpy() | |||
prediction_each[index: index_end, i] = predict[:, i] * self.output_scale[i] | |||
index = index_end | |||
print("==================================================================================================") | |||
print("predict total time: {} s".format(time.time() - time_beg)) | |||
prediction = prediction_each.reshape(self.label_shape) | |||
if self.visual_fn is not None: | |||
self.visual_fn(self.inputs, self.label, prediction, path=self.vision_path, | |||
name="epoch" + str(cb_params.cur_epoch_num)) | |||
label = self.label.reshape((-1, self.output_size)) | |||
prediction = prediction.reshape((-1, self.output_size)) | |||
self.l2_error = self._calculate_error(label, prediction) | |||
def _calculate_error(self, label, prediction): | |||
"""calculate l2-error to evaluate accuracy""" | |||
self._step_counter += 1 | |||
error = label - prediction | |||
l2_error_ex = np.sqrt(np.sum(np.square(error[..., 0]))) / np.sqrt(np.sum(np.square(label[..., 0]))) | |||
l2_error_ey = np.sqrt(np.sum(np.square(error[..., 1]))) / np.sqrt(np.sum(np.square(label[..., 1]))) | |||
l2_error_hz = np.sqrt(np.sum(np.square(error[..., 2]))) / np.sqrt(np.sum(np.square(label[..., 2]))) | |||
print("l2_error, Ex: ", l2_error_ex, ", Ey: ", l2_error_ey, ", Hz: ", l2_error_hz) | |||
self.summary_record.add_value('scalar', 'l2_ex', Tensor(l2_error_ex)) | |||
self.summary_record.add_value('scalar', 'l2_ey', Tensor(l2_error_ey)) | |||
self.summary_record.add_value('scalar', 'l2_hz', Tensor(l2_error_hz)) | |||
return l2_error_ex, l2_error_ey, l2_error_hz | |||
def get_l2_error(self): | |||
return self.l2_error |
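# Usage sketch (illustrative): PredictCallback opens its SummaryRecord through
# the context-manager protocol, so it is typically driven like this
# (surrounding names are hypothetical):
#
#   inputs, label = get_test_data(config["test_data_path"])
#   with PredictCallback(model=network, inputs=inputs, label=label,
#                        config=config, visual_fn=visual_result) as pred_cb:
#       solver.train(config["train_epoch"], train_dataset,
#                    callbacks=[pred_cb], dataset_sink_mode=True)
#       print("final l2 errors (Ex, Ey, Hz):", pred_cb.get_l2_error())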
@@ -0,0 +1,95 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
create dataset | |||
""" | |||
import numpy as np | |||
from mindelec.data import Dataset, ExistedDataConfig | |||
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime | |||
from mindelec.geometry import create_config_from_edict | |||
from .sampling_config import no_src_sampling_config, src_sampling_config, bc_sampling_config | |||
def get_test_data(test_data_path): | |||
"""load labeled data for evaluation""" | |||
# check data | |||
paths = [test_data_path + '/input.npy', test_data_path + '/output.npy'] | |||
inputs = np.load(paths[0]) | |||
label = np.load(paths[1]) | |||
return inputs, label | |||
def create_train_dataset(train_data_path): | |||
"""create training dataset from existed data files""" | |||
src_ic = ExistedDataConfig(name="src_ic", | |||
data_dir=[train_data_path + "elec_src_ic.npy"], | |||
columns_list=["points"], | |||
data_format="npy", | |||
constraint_type="IC", | |||
random_merge=False) | |||
boundary = ExistedDataConfig(name="boundary", | |||
data_dir=[train_data_path + "elec_no_src_bc.npy"], | |||
columns_list=["points"], | |||
data_format="npy", | |||
constraint_type="BC", | |||
random_merge=True) | |||
no_src_ic = ExistedDataConfig(name="no_src_ic", | |||
data_dir=[train_data_path + "elec_no_src_ic.npy"], | |||
columns_list=["points"], | |||
data_format="npy", | |||
constraint_type="IC", | |||
random_merge=True) | |||
src_domain = ExistedDataConfig(name="src_domain", | |||
data_dir=[train_data_path + "elec_src_domain.npy"], | |||
columns_list=["points"], | |||
data_format="npy", | |||
constraint_type="Equation", | |||
random_merge=True) | |||
no_src_domain = ExistedDataConfig(name="no_src_domain", | |||
data_dir=[train_data_path + "elec_no_src_domain.npy"], | |||
columns_list=["points"], | |||
data_format="npy", | |||
constraint_type="Equation", | |||
random_merge=True) | |||
dataset = Dataset(existed_data_list=[no_src_domain, no_src_ic, boundary, src_domain, src_ic]) | |||
return dataset | |||
def create_random_dataset(config): | |||
"""create training dataset by online sampling""" | |||
disk_radius = config["src_radius"] | |||
disk_origin = config["src_pos"] | |||
coord_min = config["coord_min"] | |||
coord_max = config["coord_max"] | |||
disk = Disk("src", disk_origin, disk_radius) | |||
rectangle = Rectangle("rect", coord_min, coord_max) | |||
diff = rectangle - disk | |||
time_interval = TimeDomain("time", 0.0, config["range_t"]) | |||
no_src_region = GeometryWithTime(diff, time_interval) | |||
no_src_region.set_name("no_src") | |||
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config)) | |||
src_region = GeometryWithTime(disk, time_interval) | |||
src_region.set_name("src") | |||
src_region.set_sampling_config(create_config_from_edict(src_sampling_config)) | |||
boundary = GeometryWithTime(rectangle, time_interval) | |||
boundary.set_name("bc") | |||
boundary.set_sampling_config(create_config_from_edict(bc_sampling_config)) | |||
geom_dict = {src_region: ["domain", "IC"], | |||
no_src_region: ["domain", "IC"], | |||
boundary: ["BC"]} | |||
dataset = Dataset(geom_dict) | |||
return dataset |
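# Usage sketch (illustrative): online sampling only needs the geometry-related
# keys of the JSON config shown earlier.
#
#   config = {"src_radius": 0.01, "src_pos": [0.4975, 0.4975],
#             "coord_min": [0, 0], "coord_max": [1, 1], "range_t": 4e-9}
#   dataset = create_random_dataset(config)
#   train_loader = dataset.create_dataset(batch_size=8192, shuffle=True)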
@@ -0,0 +1,73 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# =========================================================================== | |||
"""Learning rate scheduler.""" | |||
from collections import Counter | |||
import numpy as np | |||
class _LRScheduler(): | |||
""" | |||
Basic class for learning rate scheduler | |||
""" | |||
def __init__(self, lr, max_epoch, steps_per_epoch): | |||
self.base_lr = lr | |||
self.steps_per_epoch = steps_per_epoch | |||
self.total_steps = int(max_epoch * steps_per_epoch) | |||
def get_lr(self): | |||
# Compute learning rate using chainable form of the scheduler | |||
raise NotImplementedError | |||
class MultiStepLR(_LRScheduler): | |||
""" | |||
Multi-step learning rate scheduler | |||
    Decays the learning rate by gamma once the epoch count reaches one of the milestones.
Args: | |||
lr (float): Initial learning rate which is the lower boundary in the cycle. | |||
milestones (list): List of epoch indices. Must be increasing. | |||
gamma (float): Multiplicative factor of learning rate decay. | |||
steps_per_epoch (int): The number of steps per epoch to train for. | |||
max_epoch (int): The number of epochs to train for. | |||
Outputs: | |||
        numpy.ndarray, shape=(steps_per_epoch*max_epoch,)
Example: | |||
>>> # Assuming optimizer uses lr = 0.05 for all groups | |||
>>> # lr = 0.05 if epoch < 30 | |||
>>> # lr = 0.005 if 30 <= epoch < 80 | |||
>>> # lr = 0.0005 if epoch >= 80 | |||
    >>> scheduler = MultiStepLR(lr=0.05, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr() | |||
""" | |||
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch): | |||
self.milestones = Counter(milestones) | |||
self.gamma = gamma | |||
super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch) | |||
    def get_lr(self):
        lr_each_step = []
        current_lr = self.base_lr
        for i in range(self.total_steps):
            cur_ep = i // self.steps_per_epoch
            if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
                current_lr = current_lr * self.gamma
            lr_each_step.append(current_lr)
        return np.array(lr_each_step).astype(np.float32)