
deep_md.h 10 kB

/**
* Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file deep_md.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_
#define OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_
#include "graph/operator_reg.h"
namespace ge {
/**
* @brief Calculate TabulateFusion. \n
*
* @par Inputs:
* Four inputs, including:
* @li table: A Tensor. Must be one of the following types: float16, float32, float64.
* @li table_info: A Tensor. Must be one of the following types: float16, float32, float64.
* @li em_x: A Tensor. Must be one of the following types: float16, float32, float64.
* @li em: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Outputs:
* descriptor: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Attributes:
* Three attributes, including:
* @li last_layer_size: int value.
* @li split_count: int value.
* @li split_index: int value. \n
*
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TabulateFusion)
    .INPUT(table, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(table_info, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descriptor, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(last_layer_size, Int)
    .ATTR(split_count, Int, 1)
    .ATTR(split_index, Int, 0)
    .OP_END_FACTORY_REG(TabulateFusion)
/**
* @brief Calculate ProdEnvMatA. \n
*
* @par Inputs:
* @li coord: A Tensor. Must be one of the following types: float32, float64.
* @li type: A Tensor. Must be one of the following types: int32.
* @li natoms: A Tensor. Must be one of the following types: int32.
* @li box: A Tensor. Must be one of the following types: float32, float64.
* @li mesh: A Tensor. Must be one of the following types: int32.
* @li davg: A Tensor. Must be one of the following types: float32, float64.
* @li dstd: A Tensor. Must be one of the following types: float32, float64.
*
* @par Outputs:
* descrpt: A Tensor. Must be one of the following types: float32, float64.
* descrpt_deriv: A Tensor. Must be one of the following types: float32, float64.
* rij: A Tensor. Must be one of the following types: float32, float64.
* nlist: A Tensor. Must be one of the following types: int32. \n
*
* @par Attributes:
* @li rcut_a: A Float.
* @li rcut_r: A Float.
* @li rcut_r_smth: A Float.
* @li sel_a: A ListInt.
* @li sel_r: A ListInt.
* @li split_count: An int.
* @li split_index: An int. \n
*
*/
REG_OP(ProdEnvMatA)
    .INPUT(coord, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(type, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .INPUT(box, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(mesh, TensorType({DT_INT32}))
    .INPUT(davg, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(dstd, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt_deriv, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(rij, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(nlist, TensorType({DT_INT32}))
    .ATTR(rcut_a, Float, 1.0)
    .ATTR(rcut_r, Float, 1.0)
    .ATTR(rcut_r_smth, Float, 1.0)
    .ATTR(sel_a, ListInt, {})
    .ATTR(sel_r, ListInt, {})
    .ATTR(split_count, Int, 1)
    .ATTR(split_index, Int, 0)
    .OP_END_FACTORY_REG(ProdEnvMatA)
/**
* @brief Calculate ProdEnvMatACalcDescrpt. \n
*
* @par Inputs:
* @li distance: A Tensor. Must be one of the following types: float32, float64.
* @li rij_x: A Tensor. Must be one of the following types: float32, float64.
* @li rij_y: A Tensor. Must be one of the following types: float32, float64.
* @li rij_z: A Tensor. Must be one of the following types: float32, float64.
* @li type: A Tensor. Must be one of the following types: int32.
* @li natoms: A Tensor. Must be one of the following types: int32.
* @li mesh: A Tensor. Must be one of the following types: int32.
* @li davg: A Tensor. Must be one of the following types: float32, float64.
* @li dstd: A Tensor. Must be one of the following types: float32, float64. \n
*
* @par Outputs:
* @li descrpt: A Tensor. Must be one of the following types: float32, float64.
* @li descrpt_deriv: A Tensor. Must be one of the following types: float32, float64. \n
*
* @par Attributes:
* @li rcut_a: A Float.
* @li rcut_r: A Float.
* @li rcut_r_smth: A Float.
* @li sel_a: A ListInt.
* @li sel_r: A ListInt. \n
*
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(ProdEnvMatACalcDescrpt)
    .INPUT(distance, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_x, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_y, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij_z, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(type, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .INPUT(mesh, TensorType({DT_INT32}))
    .INPUT(davg, TensorType({DT_FLOAT, DT_DOUBLE}))
    .INPUT(dstd, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt, TensorType({DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(descrpt_deriv, TensorType({DT_FLOAT, DT_DOUBLE}))
    .ATTR(rcut_a, Float, 1.0)
    .ATTR(rcut_r, Float, 1.0)
    .ATTR(rcut_r_smth, Float, 1.0)
    .ATTR(sel_a, ListInt, {})
    .ATTR(sel_r, ListInt, {})
    .OP_END_FACTORY_REG(ProdEnvMatACalcDescrpt)
/**
* @brief Calculate ProdForceSeA. \n
*
* @par Inputs:
* Four inputs, including:
* @li net_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
* @li in_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
* @li nlist: A Tensor. dtype is int32.
* @li natoms: A Tensor. dtype is int32. \n
*
* @par Outputs:
* atom_force: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Attributes:
* Four attributes, including:
* @li n_a_sel: Int value.
* @li n_r_sel: Int value.
* @li split_count: Int value.
* @li split_index: Int value. \n
*
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(ProdForceSeA)
    .INPUT(net_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(in_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(nlist, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .OUTPUT(atom_force, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(n_a_sel, Int)
    .REQUIRED_ATTR(n_r_sel, Int)
    .ATTR(split_count, Int, 1)
    .ATTR(split_index, Int, 0)
    .OP_END_FACTORY_REG(ProdForceSeA)
/**
* @brief Calculate ProdVirialSeA. \n
*
* @par Inputs:
* Five inputs, including:
* @li net_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
* @li in_deriv: A Tensor. Must be one of the following types: float16, float32, float64.
* @li rij: A Tensor. Must be one of the following types: float16, float32, float64.
* @li nlist: A Tensor. dtype is int32.
* @li natoms: A Tensor. dtype is int32. \n
*
* @par Outputs:
* Two outputs, including:
* @li virial: A Tensor. Must be one of the following types: float16, float32, float64.
* @li atom_virial: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Attributes:
* Four attributes, including:
* @li n_a_sel: Int value.
* @li n_r_sel: Int value.
* @li split_count: Int value.
* @li split_index: Int value. \n
*/
REG_OP(ProdVirialSeA)
    .INPUT(net_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(in_deriv, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(rij, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(nlist, TensorType({DT_INT32}))
    .INPUT(natoms, TensorType({DT_INT32}))
    .OUTPUT(virial, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(atom_virial, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .REQUIRED_ATTR(n_a_sel, Int)
    .REQUIRED_ATTR(n_r_sel, Int)
    .ATTR(split_count, Int, 1)
    .ATTR(split_index, Int, 0)
    .OP_END_FACTORY_REG(ProdVirialSeA)
/**
* @brief Calculate TabulateFusionGrad. \n
*
* @par Inputs:
* Six inputs, including:
* @li table: A Tensor. Must be one of the following types: float16, float32, float64.
* @li table_info: A Tensor. Must be one of the following types: float16, float32, float64.
* @li em_x: A Tensor. Must be one of the following types: float16, float32, float64.
* @li em: A Tensor. Must be one of the following types: float16, float32, float64.
* @li dy: A Tensor. Must be one of the following types: float16, float32, float64.
* @li descriptor: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Outputs:
* @li dy_dem_x: A Tensor. Must be one of the following types: float16, float32, float64.
* @li dy_dem: A Tensor. Must be one of the following types: float16, float32, float64. \n
*
* @par Attributes:
* Two attributes, including:
* @li split_count: Int value.
* @li split_index: Int value. \n
*
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TabulateFusionGrad)
    .INPUT(table, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(table_info, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(em, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .INPUT(descriptor, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(dy_dem_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .OUTPUT(dy_dem, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
    .ATTR(split_count, Int, 1)
    .ATTR(split_index, Int, 0)
    .OP_END_FACTORY_REG(TabulateFusionGrad)
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_INC_DEEP_MD_H_
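
The operators above use GE's REG_OP registration macro, which in the standard CANN/GE pattern generates a ge::op::<OpName> class whose set_input_* / set_attr_* accessors mirror the .INPUT/.ATTR declarations. The following is a minimal, hypothetical sketch of how TabulateFusion might be wired into a ge::Graph under that assumption; the Data placeholder nodes, the "all_ops.h" header path, and the attribute value 128 are illustrative choices, not taken from this repository.

#include <vector>
#include "graph/graph.h"
#include "all_ops.h"  // generated operator classes, e.g. ge::op::TabulateFusion (header path assumed)

ge::Graph BuildTabulateFusionGraph() {
  // Placeholder graph inputs; real tensor descriptions would be set on each Data node.
  auto table = ge::op::Data("table").set_attr_index(0);
  auto table_info = ge::op::Data("table_info").set_attr_index(1);
  auto em_x = ge::op::Data("em_x").set_attr_index(2);
  auto em = ge::op::Data("em").set_attr_index(3);

  // Wire the operator exactly as declared by REG_OP(TabulateFusion) above.
  auto tabulate = ge::op::TabulateFusion("tabulate_fusion")
                      .set_input_table(table)
                      .set_input_table_info(table_info)
                      .set_input_em_x(em_x)
                      .set_input_em(em)
                      .set_attr_last_layer_size(128)  // required attribute; 128 is an arbitrary example value
                      .set_attr_split_count(1)        // optional, defaults shown in the registration
                      .set_attr_split_index(0);

  std::vector<ge::Operator> inputs{table, table_info, em_x, em};
  std::vector<ge::Operator> outputs{tabulate};
  ge::Graph graph("tabulate_fusion_graph");
  graph.SetInputs(inputs).SetOutputs(outputs);
  return graph;
}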

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph-optimization passes, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
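
As a rough illustration of that flow, the sketch below follows the typical GE API call sequence (GEInitialize, Session creation, AddGraph, RunGraph, GEFinalize). The option maps are left empty, the "ge/ge_api.h" header path is assumed, and BuildTabulateFusionGraph refers to the hypothetical helper from the previous sketch, so treat this as an outline rather than a working recipe.

#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"  // GE API: GEInitialize, Session, GEFinalize (header path assumed)

int RunWithGe() {
  // Global GE options (SoC version, run mode, ...) would normally be filled in here.
  std::map<std::string, std::string> global_options;
  if (ge::GEInitialize(global_options) != ge::SUCCESS) {
    return -1;
  }

  std::map<std::string, std::string> session_options;
  ge::Session session(session_options);

  // Hand the constructed graph to GE; GE Core performs optimization, compilation and execution.
  ge::Graph graph = BuildTabulateFusionGraph();  // hypothetical helper from the previous sketch
  uint32_t graph_id = 0;
  session.AddGraph(graph_id, graph);

  std::vector<ge::Tensor> inputs;   // real input tensors go here
  std::vector<ge::Tensor> outputs;
  ge::Status ret = session.RunGraph(graph_id, inputs, outputs);

  ge::GEFinalize();
  return (ret == ge::SUCCESS) ? 0 : -1;
}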