
hcom_ops.h
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file hcom_ops.h
 * \brief Huawei collective communication library ops.
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_

#include "graph/operator_reg.h"

namespace ge {
/**
 * @brief Outputs a tensor gathering all input tensors.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, int64,
 * uint64, float16, float32.
 * @par Attributes:
 * @li rank_size: A required integer identifying the number of ranks
 * participating in the op.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomAllGather)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(rank_size, Int)
    .REQUIRED_ATTR(group, String)
    .OP_END_FACTORY_REG(HcomAllGather)
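
// Illustrative sketch (not part of the original header): a single-process model of the
// all-gather semantics documented above. Every rank contributes one input tensor "x", and
// each rank's output "y" is the concatenation of all inputs in rank order, so "y" holds
// rank_size times as many elements as "x". The helper name is hypothetical and assumes
// <vector> is included above "namespace ge".
inline std::vector<float> SimulateAllGather(const std::vector<std::vector<float>> &per_rank_x) {
  std::vector<float> y;
  for (const auto &chunk : per_rank_x) {  // rank 0, rank 1, ... in rank order
    y.insert(y.end(), chunk.begin(), chunk.end());
  }
  return y;  // every rank receives this same concatenated tensor
}
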
/**
 * @brief Outputs a tensor containing the reduction across all input tensors
 * passed to the op.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li fusion: An optional integer identifying the fusion flag of the op.
 * 0: no fusion; 1 (default): fusion; 2: fuse the ops by fusion id.
 * @li fusion_id: An optional integer identifying the fusion id of the op.
 * The HcomAllReduce ops with the same fusion id will be fused.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomAllReduce)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 1)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomAllReduce)
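
// Illustrative sketch (not part of the original header): a single-process model of the
// all-reduce semantics documented above. The inputs of all ranks are combined element-wise
// with the selected reduction ("sum", "max", "min" or "prod"), and every rank receives the
// same reduced tensor; HcomReduce (below) applies the same math but keeps the result only on
// root_rank. The helper name is hypothetical and assumes <algorithm>, <cstddef>, <string>
// and <vector> are included above "namespace ge".
inline std::vector<float> SimulateAllReduce(const std::vector<std::vector<float>> &per_rank_x,
                                            const std::string &reduction) {
  std::vector<float> y = per_rank_x.at(0);
  for (std::size_t r = 1; r < per_rank_x.size(); ++r) {
    for (std::size_t i = 0; i < y.size(); ++i) {
      const float v = per_rank_x[r][i];
      if (reduction == "sum") {
        y[i] += v;
      } else if (reduction == "max") {
        y[i] = std::max(y[i], v);
      } else if (reduction == "min") {
        y[i] = std::min(y[i], v);
      } else if (reduction == "prod") {
        y[i] *= v;
      }
    }
  }
  return y;  // with HcomAllReduce every rank gets this tensor
}
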
/**
 * @brief Broadcasts the input tensors from the root rank to all ranks.
 * @par Inputs:
 * x: A list of dynamic input tensors. Must be one of the following types:
 * int8, int16, int32, int64, uint64, float16, float32. It's a dynamic input.
 * @par Attributes:
 * @li root_rank: A required integer identifying the root rank in the op;
 * the input of this rank will be broadcast to the other ranks.
 * @li fusion: An optional integer identifying whether the op needs to be fused;
 * the default value is 0 (no fusion).
 * @li fusion_id: An optional integer identifying the fusion id when fusion
 * is enabled.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @par Outputs:
 * y: A list of dynamic output tensors. Has the same type and length as "x".
 * It's a dynamic output.
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomBroadcast)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(root_rank, Int)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 0)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomBroadcast)
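
// Illustrative sketch (not part of the original header): a single-process model of the
// broadcast semantics documented above. The dynamic input "x" is a list of tensors per rank;
// after the op, every rank holds a copy of the root rank's tensor list. The helper name is
// hypothetical and assumes <vector> is included above "namespace ge".
inline std::vector<std::vector<float>> SimulateBroadcast(
    const std::vector<std::vector<std::vector<float>>> &per_rank_x, int root_rank) {
  // Each rank's output "y" is simply a copy of the tensor list held by root_rank.
  return per_rank_x.at(root_rank);
}
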
/**
 * @brief Performs a reduction from the other ranks to the root rank;
 * the reduction result is available on the root rank.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li root_rank: A required integer identifying the root rank that receives
 * the reduction result.
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li fusion: An optional integer identifying the fusion flag of the op.
 * 0 (default): no fusion; 1: fusion; 2: fuse the ops by fusion id.
 * @li fusion_id: An optional integer identifying the fusion id of the op.
 * The HcomReduce ops with the same fusion id will be fused.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomReduce)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(root_rank, Int)
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 0)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomReduce)
/**
 * @brief Performs a reduction across all input tensors, scattering the result in
 * equal blocks among ranks; each rank gets a chunk of data based on its rank
 * index.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li rank_size: A required integer identifying the number of ranks
 * participating in the op.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomReduceScatter)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(rank_size, Int)
    .OP_END_FACTORY_REG(HcomReduceScatter)
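
// Illustrative sketch (not part of the original header): a single-process model of the
// reduce-scatter semantics documented above, shown here for reduction == "sum". The inputs
// are first reduced element-wise across ranks, the reduced tensor is then split into
// rank_size equal blocks, and rank i keeps block i, so each output holds
// x.size() / rank_size elements. The helper name is hypothetical and assumes <cstddef> and
// <vector> are included above "namespace ge".
inline std::vector<float> SimulateReduceScatterSum(const std::vector<std::vector<float>> &per_rank_x,
                                                   int rank_size, int my_rank) {
  std::vector<float> reduced = per_rank_x.at(0);
  for (std::size_t r = 1; r < per_rank_x.size(); ++r) {
    for (std::size_t i = 0; i < reduced.size(); ++i) {
      reduced[i] += per_rank_x[r][i];  // element-wise "sum" reduction across ranks
    }
  }
  const std::size_t block = reduced.size() / static_cast<std::size_t>(rank_size);
  const std::size_t begin = block * static_cast<std::size_t>(my_rank);
  return std::vector<float>(reduced.begin() + begin, reduced.begin() + begin + block);
}
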
/**
 * @brief Sends the input tensor to the destination rank.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, int64,
 * uint64, float16, float32.
 * @par Attributes:
 * @li sr_tag: A required integer identifying the send/recv message tag. The
 * message will be received by the HcomReceive op with the same "sr_tag".
 * @li dest_rank: A required integer identifying the destination rank.
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @par Outputs:
 * None.
 * @attention Constraints:
 * @li "group" is limited to 128 characters. Use
 * "hccl_world_group" as the name of a world group.
 * @li Operators HcomSend and HcomReceive must use the same "sr_tag".
 * @see HcomReceive
 */
REG_OP(HcomSend)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(sr_tag, Int)
    .REQUIRED_ATTR(dest_rank, Int)
    .OP_END_FACTORY_REG(HcomSend)
/**
 * @brief Receives the tensor from the source rank.
 * @par Inputs:
 * None.
 * @par Attributes:
 * @li sr_tag: A required integer identifying the send/recv message tag. The
 * message will be sent by the HcomSend op with the same "sr_tag".
 * @li src_rank: A required integer identifying the source rank.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li shape: A required list identifying the shape of the tensor to be
 * received.
 * @li dtype: A required integer identifying the type of the tensor to be
 * received. The supported types are: int8, int16, int32, int64, uint64,
 * float16, float32.
 * @par Outputs:
 * y: A tensor with the type identified in "dtype".
 * @attention Constraints:
 * @li "group" is limited to 128 characters. Use
 * "hccl_world_group" as the name of a world group.
 * @li Operators HcomSend and HcomReceive must use the same "sr_tag".
 * @li "shape" should be the same as that of the input tensor of HcomSend.
 * @li "dtype" should be the same as that of the input tensor of HcomSend.
 * @see HcomSend
 */
REG_OP(HcomReceive)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(sr_tag, Int)
    .REQUIRED_ATTR(src_rank, Int)
    .REQUIRED_ATTR(shape, ListInt)
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomReceive)
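
// Illustrative sketch (not part of the original header): how HcomSend and HcomReceive are
// meant to be paired according to the constraints above. Sender and receiver must agree on
// "sr_tag", and the receiver's "shape" and "dtype" attributes must match the tensor fed into
// HcomSend. The struct and helper are hypothetical stand-ins that only restate this matching
// rule (they are not a GE API) and assume <cstdint> and <vector> are included above
// "namespace ge".
struct HcomP2PMessageSketch {
  int64_t sr_tag;              // must be identical on the HcomSend and the HcomReceive side
  int64_t peer_rank;           // dest_rank on the sender, src_rank on the receiver
  std::vector<int64_t> shape;  // HcomReceive's "shape" attr must equal the sent tensor's shape
};
inline bool HcomP2PMatchesSketch(const HcomP2PMessageSketch &send_side,
                                 const HcomP2PMessageSketch &recv_side) {
  return send_side.sr_tag == recv_side.sr_tag && send_side.shape == recv_side.shape;
}
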
/**
 * @brief Performs a remote read of the input tensors.
 * @par Inputs:
 * remote: A tensor describing the remote memory addresses to read, as triples
 * of u64 values: remoteId, addrRemote, length.
 * @par Outputs:
 * local: A tensor of type "dtype" whose element count is length / size_of(dtype).
 */
REG_OP(HcomRemoteRead)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .OUTPUT(local, TensorType::ALL())
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomRemoteRead)
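
// Illustrative sketch (not part of the original header): the "remote" input of the remote
// read/write ops is interpreted as triples of u64 values (remoteId, addrRemote, length), and
// the local tensor then carries length / size_of(dtype) elements. The struct and helper are
// hypothetical, shown only to make that layout explicit, and assume <cstdint> is included
// above "namespace ge".
struct RemoteAddrDescSketch {
  uint64_t remote_id;    // which remote peer to access
  uint64_t addr_remote;  // address in the remote peer's memory
  uint64_t length;       // number of bytes to read or write
};
inline uint64_t RemoteLocalElementCountSketch(const RemoteAddrDescSketch &desc,
                                              uint64_t element_size_bytes) {
  return desc.length / element_size_bytes;  // e.g. length / sizeof(float) for DT_FLOAT
}
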
/**
 * @brief Performs a remote ref read of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory addresses to read, as triples
 * of u64 values: remoteId, addrRemote, length.
 * @li cache_var: The local base address.
 * @li local_offset: The skip step length.
 * @par Outputs:
 * cache_var: The local base address.
 */
REG_OP(HcomRemoteRefRead)
    .INPUT(remote, TensorType({DT_UINT64}))
    .INPUT(cache_var, TensorType({DT_UINT64}))
    .INPUT(local_offset, TensorType({DT_UINT64}))
    .OUTPUT(cache_var, TensorType({DT_UINT64}))
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomRemoteRefRead)
/**
 * @brief Performs a remote write of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory addresses to write, as triples
 * of u64 values: remoteId, addrRemote, length.
 * @li local: A tensor whose element count is length / size_of(Type).
 */
REG_OP(HcomRemoteWrite)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .INPUT(local, TensorType::ALL())
    .OP_END_FACTORY_REG(HcomRemoteWrite)
/**
 * @brief Performs a remote scatter write of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory addresses to write, as triples
 * of u64 values: remoteId, addrRemote, length.
 * @li local: A tensor whose element count is length / size_of(Type).
 * @li local_offset: An optional tensor of local offsets.
 */
REG_OP(HcomRemoteScatterWrite)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .INPUT(local, TensorType::ALL())
    .OPTIONAL_INPUT(local_offset, TensorType({DT_UINT64}))
    .OP_END_FACTORY_REG(HcomRemoteScatterWrite)
/**
 * @brief All ranks send different amounts of data to, and receive different
 * amounts of data from, all ranks.
 * @par Inputs:
 * Five inputs, including:
 * @li send_data: A tensor, the memory to send.
 * @li send_counts: A list, where entry i specifies the number of elements in
 * send_data to send to rank i.
 * @li send_displacements: A list, where entry i specifies the displacement
 * (offset from send_data) from which to send data to rank i.
 * @li recv_counts: A list, where entry i specifies the number of
 * elements to receive from rank i.
 * @li recv_displacements: A list, where entry i specifies the displacement
 * (offset from recv_data) at which data from rank i should be written.
 * @par Outputs:
 * recv_data: A tensor with the same element type as send_data.
 * @par Attributes:
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @attention All ranks participating in the op should be connected in a
 * full-mesh network using RDMA.
 */
REG_OP(HcomAllToAllV)
    .INPUT(send_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .INPUT(send_counts, TensorType({DT_INT64}))
    .INPUT(send_displacements, TensorType({DT_INT64}))
    .INPUT(recv_counts, TensorType({DT_INT64}))
    .INPUT(recv_displacements, TensorType({DT_INT64}))
    .OUTPUT(recv_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(group, String)
    .OP_END_FACTORY_REG(HcomAllToAllV)
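
// Illustrative sketch (not part of the original header): a single-process model of the
// all-to-all-v exchange documented above. For sender s and receiver r, send_counts[s][r]
// elements are copied from send_data[s] starting at send_displacements[s][r] and written
// into recv_data[r] starting at recv_displacements[r][s]. The helper name is hypothetical
// and assumes <algorithm>, <cstddef>, <cstdint> and <vector> are included above
// "namespace ge".
inline std::vector<std::vector<float>> SimulateAllToAllV(
    const std::vector<std::vector<float>> &send_data,
    const std::vector<std::vector<int64_t>> &send_counts,
    const std::vector<std::vector<int64_t>> &send_displacements,
    const std::vector<std::vector<int64_t>> &recv_counts,
    const std::vector<std::vector<int64_t>> &recv_displacements) {
  const std::size_t rank_size = send_data.size();
  std::vector<std::vector<float>> recv_data(rank_size);
  for (std::size_t r = 0; r < rank_size; ++r) {
    int64_t needed = 0;  // size each receive buffer to fit the farthest incoming block
    for (std::size_t s = 0; s < rank_size; ++s) {
      needed = std::max(needed, recv_displacements[r][s] + recv_counts[r][s]);
    }
    recv_data[r].resize(static_cast<std::size_t>(needed));
  }
  for (std::size_t s = 0; s < rank_size; ++s) {
    for (std::size_t r = 0; r < rank_size; ++r) {
      for (int64_t i = 0; i < send_counts[s][r]; ++i) {
        recv_data[r][static_cast<std::size_t>(recv_displacements[r][s] + i)] =
            send_data[s][static_cast<std::size_t>(send_displacements[s][r] + i)];
      }
    }
  }
  return recv_data;
}
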
/**
 * @brief All ranks send different amounts of data to, and receive different
 * amounts of data from, all ranks, and concatenate all the data described by
 * addrinfo into the output "gathered".
 * @par Inputs:
 * Four inputs, including:
 * @li addrinfo: A tensor describing the memory info (address, length) to send.
 * @li addrinfo_count_per_rank: A list, where entry i specifies the number of
 * elements in send_data to send to rank i.
 * @li recv_counts: A list, where entry i specifies the number of
 * elements to receive from rank i.
 * @li recv_displacements: A list, where entry i specifies the displacement
 * (offset from recv_data) at which data from rank i should be written.
 * @par Outputs:
 * Two outputs, including:
 * @li recv_data: A tensor with the element type given by "dtype".
 * @li gathered: A tensor with the element type given by "dtype".
 * @par Attributes:
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @li dtype: The data type of the send buffer elements.
 * @li addr_length: Describes the memory length of the elements in addrinfo.
 * -2: all elements in addrinfo have the same memory length, but it is unknown.
 * -1: the memory lengths of the elements are unknown.
 * >0: all elements in addrinfo have the same memory length, given by this attribute value.
 * @attention All ranks participating in the op should be connected in a
 * full-mesh network using RDMA.
 */
REG_OP(HcomGatherAllToAllV)
    .INPUT(addrinfo, TensorType({DT_UINT64}))
    .INPUT(addrinfo_count_per_rank, TensorType({DT_INT64}))
    .INPUT(recv_counts, TensorType({DT_INT64}))
    .INPUT(recv_displacements, TensorType({DT_INT64}))
    .OUTPUT(recv_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .OUTPUT(gathered, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(dtype, Type)
    .REQUIRED_ATTR(addr_length, Int)
    .OP_END_FACTORY_REG(HcomGatherAllToAllV)
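
// Illustrative sketch (not part of the original header): how the "addr_length" attribute
// above can be resolved into a per-element byte length for the address-described buffers in
// "addrinfo". -2 means one common but unknown length, -1 means the lengths are unknown, and
// a positive value is the common length itself. The helper is hypothetical, only restates
// that convention, and assumes <cstdint> is included above "namespace ge".
inline bool TryResolveAddrinfoElementLengthSketch(int64_t addr_length, uint64_t &length_bytes) {
  if (addr_length > 0) {  // common, known length: the attribute value is the length in bytes
    length_bytes = static_cast<uint64_t>(addr_length);
    return true;
  }
  return false;           // -1 or -2: the length cannot be derived from the attribute alone
}
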
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module ME and the underlying hardware, serving as the bridge between them. GE takes the graph issued by ME as input, performs a series of deep graph optimizations on it, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training/inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture diagram is shown below.