
deep_md.h 12 kB

/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file deep_md.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_
#define OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_

#include "graph/operator_reg.h"

namespace ge {
/**
 * @brief Calculate TabulateFusion. \n
 *
 * @par Inputs:
 * Four inputs, including:
 * @li table: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li table_info: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li em_x: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li em: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Outputs:
 * descriptor: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Attributes:
 * Three attributes, including:
 * @li last_layer_size: int value.
 * @li split_count: int value.
 * @li split_index: int value. \n
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(TabulateFusion)
    .INPUT(table, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(table_info, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descriptor, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(last_layer_size, Int)
    .OP_END_FACTORY_REG(TabulateFusion)
/**
 * @brief Calculate ProdEnvMatA. \n
 *
 * @par Inputs:
 * @li coord: A Tensor. Must be one of the following types: float32, float64.
 * @li type: A Tensor. Must be one of the following types: int32.
 * @li natoms: A Tensor. Must be one of the following types: int32.
 * @li box: A Tensor. Must be one of the following types: float32, float64.
 * @li mesh: A Tensor. Must be one of the following types: int32.
 * @li davg: A Tensor. Must be one of the following types: float32, float64.
 * @li dstd: A Tensor. Must be one of the following types: float32, float64. \n
 *
 * @par Outputs:
 * @li descrpt: A Tensor. Must be one of the following types: float32, float64.
 * @li descrpt_deriv: A Tensor. Must be one of the following types: float32, float64.
 * @li rij: A Tensor. Must be one of the following types: float32, float64.
 * @li nlist: A Tensor. Must be one of the following types: int32. \n
 *
 * @par Attributes:
 * @li rcut_a: A Float.
 * @li rcut_r: A Float.
 * @li rcut_r_smth: A Float.
 * @li sel_a: A ListInt.
 * @li sel_r: A ListInt.
 * @li split_count: An Int.
 * @li split_index: An Int. \n
 */
REG_OP(ProdEnvMatA)
    .INPUT(coord, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(type, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .INPUT(box, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(mesh, TensorType({DT_INT32}))
    .INPUT(davg, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(dstd, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt_deriv, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(rij, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(nlist, TensorType({DT_INT32}))
    .ATTR(rcut_a, Float, 1.0)
    .ATTR(rcut_r, Float, 1.0)
    .ATTR(rcut_r_smth, Float, 1.0)
    .ATTR(sel_a, ListInt, {})
    .ATTR(sel_r, ListInt, {})
    .OP_END_FACTORY_REG(ProdEnvMatA)
/**
 * @brief Calculate ProdEnvMatACalcRij.
 * Using type, natoms, sel_a, and rcut_r as constraints, find the central elements in
 * the corresponding coord through mesh, and output the indices of the central elements
 * and the distance between each central element and its neighbors. \n
 *
 * @par Inputs:
 * @li coord: A Tensor. Must be one of the following types: float32, float64.
 * @li type: A Tensor. Must be one of the following types: int32.
 * @li natoms: A Tensor. Must be one of the following types: int32.
 * @li box: A Tensor. Must be one of the following types: float32, float64.
 * @li mesh: A Tensor. Must be one of the following types: int32. \n
 *
 * @par Outputs:
 * @li rij: A Tensor. Must be one of the following types: float32, float64.
 * @li nlist: A Tensor. Must be one of the following types: int32.
 * @li distance: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_x: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_y: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_z: A Tensor. Must be one of the following types: float32, float64. \n
 *
 * @par Attributes:
 * @li rcut_a: A Float.
 * @li rcut_r: A Float.
 * @li rcut_r_smth: A Float.
 * @li sel_a: A ListInt.
 * @li sel_r: A ListInt. \n
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(ProdEnvMatACalcRij)
    .INPUT(coord, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(type, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .INPUT(box, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(mesh, TensorType({DT_INT32}))
    .OUTPUT(rij, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(nlist, TensorType({DT_INT32}))
    .OUTPUT(distance, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(rij_x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(rij_y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(rij_z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(rcut_a, Float, 1.0)
    .ATTR(rcut_r, Float, 1.0)
    .ATTR(rcut_r_smth, Float, 1.0)
    .ATTR(sel_a, ListInt, {})
    .ATTR(sel_r, ListInt, {})
    .OP_END_FACTORY_REG(ProdEnvMatACalcRij)
/**
 * @brief Calculate ProdEnvMatACalcDescrpt. \n
 *
 * @par Inputs:
 * @li distance: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_x: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_y: A Tensor. Must be one of the following types: float32, float64.
 * @li rij_z: A Tensor. Must be one of the following types: float32, float64.
 * @li type: A Tensor. Must be one of the following types: int32.
 * @li natoms: A Tensor. Must be one of the following types: int32.
 * @li mesh: A Tensor. Must be one of the following types: int32.
 * @li davg: A Tensor. Must be one of the following types: float32, float64.
 * @li dstd: A Tensor. Must be one of the following types: float32, float64. \n
 *
 * @par Outputs:
 * @li descrpt: A Tensor. Must be one of the following types: float32, float64.
 * @li descrpt_deriv: A Tensor. Must be one of the following types: float32, float64. \n
 *
 * @par Attributes:
 * @li rcut_a: A Float.
 * @li rcut_r: A Float.
 * @li rcut_r_smth: A Float.
 * @li sel_a: A ListInt.
 * @li sel_r: A ListInt. \n
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(ProdEnvMatACalcDescrpt)
    .INPUT(distance, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(type, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .INPUT(mesh, TensorType({DT_INT32}))
    .INPUT(davg, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(dstd, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt_deriv, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(rcut_a, Float, 1.0)
    .ATTR(rcut_r, Float, 1.0)
    .ATTR(rcut_r_smth, Float, 1.0)
    .ATTR(sel_a, ListInt, {})
    .ATTR(sel_r, ListInt, {})
    .OP_END_FACTORY_REG(ProdEnvMatACalcDescrpt)
/**
 * @brief Calculate ProdForceSeA. \n
 *
 * @par Inputs:
 * Four inputs, including:
 * @li net_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li in_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li nlist: A Tensor. dtype is int32.
 * @li natoms: A Tensor. dtype is int32. \n
 *
 * @par Outputs:
 * atom_force: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Attributes:
 * Two attributes, including:
 * @li n_a_sel: A Scalar.
 * @li n_r_sel: A Scalar. \n
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(ProdForceSeA)
    .INPUT(net_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(in_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(nlist, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .OUTPUT(atom_force, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(n_a_sel, Int)
    .REQUIRED_ATTR(n_r_sel, Int)
    .OP_END_FACTORY_REG(ProdForceSeA)
/**
 * @brief Calculate ProdVirialSeA. \n
 *
 * @par Inputs:
 * Five inputs, including:
 * @li net_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li in_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li rij: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li nlist: A Tensor. dtype is int32.
 * @li natoms: A Tensor. dtype is int32. \n
 *
 * @par Outputs:
 * Two outputs, including:
 * @li virial: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li atom_virial: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Attributes:
 * Four attributes, including:
 * @li n_a_sel: Int value.
 * @li n_r_sel: Int value.
 * @li split_count: Int value.
 * @li split_index: Int value. \n
 */
REG_OP(ProdVirialSeA)
    .INPUT(net_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(in_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(nlist, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .OUTPUT(virial, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(atom_virial, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(n_a_sel, Int)
    .REQUIRED_ATTR(n_r_sel, Int)
    .OP_END_FACTORY_REG(ProdVirialSeA)
/**
 * @brief Calculate TabulateFusionGrad. \n
 *
 * @par Inputs:
 * Six inputs, including:
 * @li table: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li table_info: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li em_x: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li em: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li dy: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li descriptor: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Outputs:
 * @li dy_dem_x: A Tensor. Must be one of the following types: float16, float32, float64.
 * @li dy_dem: A Tensor. Must be one of the following types: float16, float32, float64. \n
 *
 * @par Attributes:
 * Two attributes, including:
 * @li split_count: A Scalar.
 * @li split_index: A Scalar. \n
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(TabulateFusionGrad)
    .INPUT(table, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(table_info, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(descriptor, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(dy_dem_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(dy_dem, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OP_END_FACTORY_REG(TabulateFusionGrad)
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_
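Each REG_OP(...) block above is expanded by graph/operator_reg.h into a C++ operator class under the ge::op namespace, with set_input_* and set_attr_* setters that mirror the declared inputs and attributes. The following is a minimal sketch, not part of deep_md.h, of wiring the TabulateFusion operator into a GE IR graph: node names and the attribute value are placeholders, tensor descriptions are omitted, and the "all_ops.h" include path is an assumption that can differ between CANN releases.

// Sketch only: building a small GE graph around the TabulateFusion operator
// registered above. Names, shapes and header locations are illustrative.
#include <vector>
#include "graph/graph.h"
#include "all_ops.h"  // generated operator classes, including ge::op::TabulateFusion (path may vary)

ge::Graph BuildTabulateFusionGraph() {
  // Graph inputs are modelled with Data ops; index matches the position passed to SetInputs.
  auto table = ge::op::Data("table").set_attr_index(0);
  auto table_info = ge::op::Data("table_info").set_attr_index(1);
  auto em_x = ge::op::Data("em_x").set_attr_index(2);
  auto em = ge::op::Data("em").set_attr_index(3);

  // The setters below are generated from the .INPUT / .REQUIRED_ATTR lines of REG_OP.
  auto tabulate = ge::op::TabulateFusion("tabulate_fusion")
                      .set_input_table(table)
                      .set_input_table_info(table_info)
                      .set_input_em_x(em_x)
                      .set_input_em(em)
                      .set_attr_last_layer_size(128);  // placeholder value

  ge::Graph graph("tabulate_fusion_graph");
  graph.SetInputs({table, table_info, em_x, em}).SetOutputs({tabulate});
  return graph;
}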
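ProdEnvMatACalcRij and ProdEnvMatACalcDescrpt read like a two-stage split of ProdEnvMatA: the first stage produces the neighbour list and per-pair distances, the second turns them into the descriptor and its derivative. The header does not state this pairing, so the sketch below is an assumption based on the matching tensor names; it reuses the includes from the previous sketch and assumes the usual two-argument set_input_* overload that selects a named output of the upstream operator. Cutoffs and selection lists are placeholders.

// Sketch only: chaining the two-stage variant; the pairing is inferred, not documented.
ge::Graph BuildEnvMatTwoStageGraph() {
  auto coord = ge::op::Data("coord").set_attr_index(0);
  auto type = ge::op::Data("type").set_attr_index(1);
  auto natoms = ge::op::Data("natoms").set_attr_index(2);
  auto box = ge::op::Data("box").set_attr_index(3);
  auto mesh = ge::op::Data("mesh").set_attr_index(4);
  auto davg = ge::op::Data("davg").set_attr_index(5);
  auto dstd = ge::op::Data("dstd").set_attr_index(6);

  auto calc_rij = ge::op::ProdEnvMatACalcRij("calc_rij")
                      .set_input_coord(coord)
                      .set_input_type(type)
                      .set_input_natoms(natoms)
                      .set_input_box(box)
                      .set_input_mesh(mesh)
                      .set_attr_rcut_r(6.0f)        // placeholder cutoff radius
                      .set_attr_rcut_r_smth(0.5f)   // placeholder smoothing radius
                      .set_attr_sel_a({46, 92});    // placeholder neighbour selection

  auto calc_descrpt = ge::op::ProdEnvMatACalcDescrpt("calc_descrpt")
                          .set_input_distance(calc_rij, "distance")  // named upstream outputs
                          .set_input_rij_x(calc_rij, "rij_x")
                          .set_input_rij_y(calc_rij, "rij_y")
                          .set_input_rij_z(calc_rij, "rij_z")
                          .set_input_type(type)
                          .set_input_natoms(natoms)
                          .set_input_mesh(mesh)
                          .set_input_davg(davg)
                          .set_input_dstd(dstd)
                          .set_attr_rcut_r(6.0f)
                          .set_attr_rcut_r_smth(0.5f)
                          .set_attr_sel_a({46, 92});

  ge::Graph graph("env_mat_two_stage");
  graph.SetInputs({coord, type, natoms, box, mesh, davg, dstd}).SetOutputs({calc_descrpt});
  return graph;
}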

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE is tuned for the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
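As a rough sketch of the GE API side mentioned above, a graph such as the ones built earlier would be handed to GE roughly as follows. The header path "ge/ge_api.h", the empty option map, and the exact signatures are assumptions based on the public GE client interface and may differ between releases.

// Hedged sketch of driving GE through its client API; option keys and types vary by release.
#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"

int RunOnGe(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs) {
  std::map<std::string, std::string> options;  // global GE options (placeholders)
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;
  }

  ge::Session session(options);  // GE API entry point
  const uint32_t graph_id = 0;
  if (session.AddGraph(graph_id, graph) != ge::SUCCESS) {  // hand the graph to GE Core
    ge::GEFinalize();
    return -1;
  }

  std::vector<ge::Tensor> outputs;
  // GE Core optimizes, compiles and executes the graph on the Ascend device.
  ge::Status ret = session.RunGraph(graph_id, inputs, outputs);

  ge::GEFinalize();
  return (ret == ge::SUCCESS) ? 0 : -1;
}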