/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file hcom_ops.h
 * \brief Huawei collective communication library ops.
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_
#include "graph/operator_reg.h"
namespace ge {
/**
 * @brief Outputs a tensor gathering all input tensors.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, int64,
 * float16, float32, float64, uint8, uint16, uint32, uint64.
 * @par Attributes:
 * @li rank_size: A required integer identifying the number of ranks
 * participating in the op.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomAllGather)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                          DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                           DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(rank_size, Int)
    .REQUIRED_ATTR(group, String)
    .OP_END_FACTORY_REG(HcomAllGather)
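/*
 * Example (illustrative only, not part of the op definition): a minimal single-process
 * sketch of the all-gather semantics described above, written in plain C++ with a
 * hypothetical helper name; it does not call HCCL. Every rank contributes a tensor of
 * identical shape and every rank receives the rank_size inputs concatenated in rank order.
 *
 *   #include <vector>
 *
 *   std::vector<float> SimulateAllGather(const std::vector<std::vector<float>> &per_rank_x) {
 *     std::vector<float> y;  // every rank would receive this same concatenation
 *     for (const auto &x : per_rank_x) {
 *       y.insert(y.end(), x.begin(), x.end());
 *     }
 *     return y;  // y.size() == rank_size * x.size()
 *   }
 */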
/**
 * @brief Outputs a tensor containing the reduction across all input tensors
 * passed to the op.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li fusion: An optional integer identifying the fusion flag of the op.
 * 0: no fusion; 1 (default): fusion; 2: fuse ops that share the same fusion id.
 * @li fusion_id: An optional integer identifying the fusion id of the op.
 * HcomAllReduce ops with the same fusion id will be fused.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomAllReduce)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 1)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomAllReduce)
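/*
 * Example (illustrative only): a minimal single-process sketch of the reduction == "sum"
 * case described above, in plain C++ with a hypothetical helper name; it does not call
 * HCCL. The inputs of all ranks are reduced element-wise and every rank ends up with the
 * same reduced tensor.
 *
 *   #include <vector>
 *
 *   std::vector<float> SimulateAllReduceSum(const std::vector<std::vector<float>> &per_rank_x) {
 *     std::vector<float> y(per_rank_x.front().size(), 0.0f);
 *     for (const auto &x : per_rank_x) {
 *       for (size_t i = 0; i < y.size(); ++i) {
 *         y[i] += x[i];  // element-wise "sum" reduction across ranks
 *       }
 *     }
 *     return y;  // identical on every rank
 *   }
 */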
/**
 * @brief Broadcasts the input tensors of the root rank to all ranks.
 * @par Inputs:
 * x: A list of dynamic input tensors. Must be one of the following types:
 * int8, int16, int32, int64, float16, float32, float64, uint8, uint16, uint32,
 * uint64. It's a dynamic input.
 * @par Attributes:
 * @li root_rank: A required integer identifying the root rank of the op.
 * The input of this rank will be broadcast to the other ranks.
 * @li fusion: An optional integer identifying the fusion flag of the op.
 * The default value is 0 (no fusion).
 * @li fusion_id: An optional integer identifying the fusion id when fusion is enabled.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @par Outputs:
 * y: A list of dynamic output tensors. Has the same type and length as "x".
 * It's a dynamic output.
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomBroadcast)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                  DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                   DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(root_rank, Int)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 0)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomBroadcast)
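/*
 * Example (illustrative only): a minimal single-process sketch of the broadcast semantics
 * described above, in plain C++ with a hypothetical helper name; it does not call HCCL.
 * After the op, every rank holds a copy of the root rank's tensor.
 *
 *   #include <vector>
 *
 *   std::vector<float> SimulateBroadcast(const std::vector<std::vector<float>> &per_rank_x,
 *                                        int root_rank) {
 *     return per_rank_x[root_rank];  // this value replaces x on every rank
 *   }
 */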
/**
 * @brief Performs a reduction from all other ranks to the root rank.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li root_rank: A required integer identifying the root rank of the op.
 * The reduction result will be placed on this root rank.
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li fusion: An optional integer identifying the fusion flag of the op.
 * 0 (default): no fusion; 1: fusion; 2: fuse ops that share the same fusion id.
 * @li fusion_id: An optional integer identifying the fusion id of the op.
 * HcomReduce ops with the same fusion id will be fused.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomReduce)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(root_rank, Int)
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .ATTR(fusion, Int, 0)
    .ATTR(fusion_id, Int, -1)
    .OP_END_FACTORY_REG(HcomReduce)
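/*
 * Example (illustrative only): HcomReduce behaves like HcomAllReduce, except that only
 * root_rank holds the reduced result. A minimal single-process sketch for the "sum" case,
 * with a hypothetical helper name and no HCCL calls:
 *
 *   #include <vector>
 *
 *   std::vector<float> SimulateReduceSum(const std::vector<std::vector<float>> &per_rank_x,
 *                                        int my_rank, int root_rank) {
 *     std::vector<float> y(per_rank_x.front().size(), 0.0f);
 *     for (const auto &x : per_rank_x) {
 *       for (size_t i = 0; i < y.size(); ++i) y[i] += x[i];
 *     }
 *     // only the root rank receives the reduced tensor in this sketch
 *     return (my_rank == root_rank) ? y : std::vector<float>{};
 *   }
 */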
/**
 * @brief Performs a reduction across all input tensors, scattering the result
 * in equal blocks among ranks; each rank gets a chunk of data based on its rank
 * index.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, float16,
 * float32.
 * @par Attributes:
 * @li reduction: A required string identifying the reduction operation to
 * perform. The supported operations are: "sum", "max", "min", "prod".
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li rank_size: A required integer identifying the number of ranks
 * participating in the op.
 * @par Outputs:
 * y: A Tensor. Has the same type as "x".
 * @attention Constraints:
 * "group" is limited to 128 characters. Use "hccl_world_group"
 * as the name of a world group.
 */
REG_OP(HcomReduceScatter)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16}))
    .REQUIRED_ATTR(reduction, String)
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(rank_size, Int)
    .OP_END_FACTORY_REG(HcomReduceScatter)
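/*
 * Example (illustrative only): a minimal single-process sketch of the reduce-scatter
 * semantics described above for reduction == "sum", in plain C++ with a hypothetical
 * helper name; it does not call HCCL. The inputs of all ranks are reduced element-wise
 * and rank i keeps only the i-th of rank_size equal chunks.
 *
 *   #include <vector>
 *
 *   std::vector<float> SimulateReduceScatterSum(const std::vector<std::vector<float>> &per_rank_x,
 *                                               size_t my_rank) {
 *     const size_t rank_size = per_rank_x.size();
 *     const size_t chunk = per_rank_x.front().size() / rank_size;  // assumes even division
 *     std::vector<float> y(chunk, 0.0f);
 *     for (const auto &x : per_rank_x) {
 *       for (size_t i = 0; i < chunk; ++i) {
 *         y[i] += x[my_rank * chunk + i];
 *       }
 *     }
 *     return y;  // rank my_rank's chunk of the reduced tensor
 *   }
 */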
/**
 * @brief Sends the input tensor to the destination rank.
 * @par Inputs:
 * x: A tensor. Must be one of the following types: int8, int16, int32, int64,
 * float16, float32, float64, uint8, uint16, uint32, uint64.
 * @par Attributes:
 * @li sr_tag: A required integer identifying the send/recv message tag. The
 * message will be received by the HcomReceive op with the same "sr_tag".
 * @li dest_rank: A required integer identifying the destination rank.
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @par Outputs:
 * None.
 * @attention Constraints:
 * @li "group" is limited to 128 characters. Use
 * "hccl_world_group" as the name of a world group.
 * @li The paired HcomSend and HcomReceive operators must have the same "sr_tag".
 * @see HcomReceive
 */
REG_OP(HcomSend)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                          DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(sr_tag, Int)
    .REQUIRED_ATTR(dest_rank, Int)
    .OP_END_FACTORY_REG(HcomSend)
/**
 * @brief Receives a tensor from the source rank.
 * @par Inputs:
 * None.
 * @par Attributes:
 * @li sr_tag: A required integer identifying the send/recv message tag. The
 * message will be sent by the HcomSend op with the same "sr_tag".
 * @li src_rank: A required integer identifying the source rank.
 * @li group: A required string identifying the group name of ranks
 * participating in the op.
 * @li shape: A required list identifying the shape of the tensor to be
 * received.
 * @li dtype: A required integer identifying the type of the tensor to be
 * received. The supported types are: int8, int16, int32, float16, float32.
 * @par Outputs:
 * y: A tensor with the type identified by "dtype".
 * @attention Constraints:
 * @li "group" is limited to 128 characters. Use
 * "hccl_world_group" as the name of a world group.
 * @li The paired HcomSend and HcomReceive operators must have the same "sr_tag".
 * @li "shape" should be the same as that of the input tensor of the paired HcomSend.
 * @li "dtype" should be the same as that of the input tensor of the paired HcomSend.
 * @see HcomSend
 */
REG_OP(HcomReceive)
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                           DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(sr_tag, Int)
    .REQUIRED_ATTR(src_rank, Int)
    .REQUIRED_ATTR(shape, ListInt)
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomReceive)
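/*
 * Example (illustrative sketch, under assumptions): pairing HcomSend/HcomReceive when
 * building a GE graph. It assumes the operator wrapper classes and set_input_* /
 * set_attr_* accessors that graph/operator_reg.h generates from the REG_OP definitions
 * above, as well as an "all_ops.h" include; the variable names are hypothetical. The two
 * ops are matched by the same "sr_tag", and the receive side must declare the same shape
 * and dtype as the tensor being sent.
 *
 *   #include "all_ops.h"
 *
 *   void BuildSendRecvPair() {
 *     auto data = ge::op::Data("x").set_attr_index(0);
 *
 *     // on the sending rank (rank 0 here)
 *     auto send = ge::op::HcomSend("send")
 *                     .set_input_x(data)
 *                     .set_attr_group("hccl_world_group")
 *                     .set_attr_sr_tag(0)      // must match the receive side
 *                     .set_attr_dest_rank(1);
 *
 *     // on the receiving rank (rank 1 here): same sr_tag, matching shape and dtype
 *     auto recv = ge::op::HcomReceive("recv")
 *                     .set_attr_group("hccl_world_group")
 *                     .set_attr_sr_tag(0)
 *                     .set_attr_src_rank(0)
 *                     .set_attr_shape({2, 3})
 *                     .set_attr_dtype(ge::DT_FLOAT);
 *     (void)send;
 *     (void)recv;
 *   }
 */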
/**
 * @brief Performs a remote read of the input tensors.
 * @par Inputs:
 * remote: A tensor describing the remote memory to read, as triples of
 * (u64 remoteId, u64 addrRemote, u64 length).
 * @par Outputs:
 * local: A tensor of type "dtype" whose element count is length / sizeof(dtype).
 * @par Attributes:
 * dtype: A required attribute identifying the type of the tensor to be read.
 */
REG_OP(HcomRemoteRead)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .OUTPUT(local, TensorType::ALL())
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomRemoteRead)
/**
 * @brief Performs a remote reference read of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory to read, as triples of
 * (u64 remoteId, u64 addrRemote, u64 length).
 * @li cache_var: The local base address.
 * @li local_offset: The skip step length.
 * @par Outputs:
 * cache_var: The local base address.
 * @par Attributes:
 * dtype: A required attribute identifying the type of the tensor to be read.
 */
REG_OP(HcomRemoteRefRead)
    .INPUT(remote, TensorType({DT_UINT64}))
    .INPUT(cache_var, TensorType({DT_UINT64}))
    .INPUT(local_offset, TensorType({DT_UINT64}))
    .OUTPUT(cache_var, TensorType({DT_UINT64}))
    .REQUIRED_ATTR(dtype, Type)
    .OP_END_FACTORY_REG(HcomRemoteRefRead)
/**
 * @brief Performs a remote write of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory to write, as triples of
 * (u64 remoteId, u64 addrRemote, u64 length).
 * @li local: A tensor whose element count is length / sizeof(type); its data is
 * written to the remote memory.
 */
REG_OP(HcomRemoteWrite)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .INPUT(local, TensorType::ALL())
    .OP_END_FACTORY_REG(HcomRemoteWrite)
/**
 * @brief Performs a remote scatter write of the input tensors.
 * @par Inputs:
 * @li remote: A tensor describing the remote memory to write, as triples of
 * (u64 remoteId, u64 addrRemote, u64 length).
 * @li local: A tensor whose element count is length / sizeof(type); its data is
 * written to the remote memory.
 * @li local_offset: An optional tensor specifying the offset into the local data.
 */
REG_OP(HcomRemoteScatterWrite)
    .INPUT(remote, TensorType({DT_INT64, DT_UINT64}))
    .INPUT(local, TensorType::ALL())
    .OPTIONAL_INPUT(local_offset, TensorType({DT_UINT64}))
    .OP_END_FACTORY_REG(HcomRemoteScatterWrite)
/**
 * @brief All ranks send different amounts of data to, and receive different
 * amounts of data from, all ranks.
 * @par Inputs:
 * Five inputs, including:
 * @li send_data: A tensor, the memory to send.
 * @li send_counts: A list, where entry i specifies the number of elements in
 * send_data to send to rank i.
 * @li send_displacements: A list, where entry i specifies the displacement
 * (offset from send_data) from which to send data to rank i.
 * @li recv_counts: A list, where entry i specifies the number of
 * elements to receive from rank i.
 * @li recv_displacements: A list, where entry i specifies the displacement
 * (offset from recv_data) to which data from rank i should be written.
 * @par Outputs:
 * recv_data: A tensor with the same element type as send_data.
 * @par Attributes:
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @attention All ranks participating in the op should be connected in a
 * full-mesh network using RDMA.
 */
REG_OP(HcomAllToAllV)
    .INPUT(send_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                  DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .INPUT(send_counts, TensorType({DT_INT64}))
    .INPUT(send_displacements, TensorType({DT_INT64}))
    .INPUT(recv_counts, TensorType({DT_INT64}))
    .INPUT(recv_displacements, TensorType({DT_INT64}))
    .OUTPUT(recv_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                   DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(group, String)
    .OP_END_FACTORY_REG(HcomAllToAllV)
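/*
 * Example (illustrative only): a minimal single-process sketch of what one rank receives
 * from the all-to-all-v exchange described above, in plain C++ with a hypothetical helper
 * name; it does not call HCCL. For every peer rank r, recv_counts[r] elements are taken
 * from r's send_data at r's send_displacements[my_rank] and written into recv_data at
 * recv_displacements[r].
 *
 *   #include <cstdint>
 *   #include <vector>
 *
 *   std::vector<float> SimulateAllToAllV(
 *       const std::vector<std::vector<float>> &per_rank_send_data,
 *       const std::vector<std::vector<int64_t>> &per_rank_send_displacements,
 *       const std::vector<int64_t> &recv_counts,
 *       const std::vector<int64_t> &recv_displacements,
 *       size_t my_rank) {
 *     std::vector<float> recv_data;
 *     for (size_t r = 0; r < per_rank_send_data.size(); ++r) {
 *       const int64_t count = recv_counts[r];  // equals rank r's send_counts[my_rank]
 *       const int64_t src = per_rank_send_displacements[r][my_rank];
 *       const int64_t dst = recv_displacements[r];
 *       if (recv_data.size() < static_cast<size_t>(dst + count)) {
 *         recv_data.resize(dst + count);
 *       }
 *       for (int64_t i = 0; i < count; ++i) {
 *         recv_data[dst + i] = per_rank_send_data[r][src + i];
 *       }
 *     }
 *     return recv_data;
 *   }
 */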
/**
 * @brief All ranks send different amounts of data to, and receive different
 * amounts of data from, all ranks. In addition, all the data described by
 * addrinfo is concatenated together into the output "gathered".
 * @par Inputs:
 * Four inputs, including:
 * @li addrinfo: A tensor describing the memory info (address, length) to send.
 * @li addrinfo_count_per_rank: A list, where entry i specifies the number of
 * addrinfo entries to send to rank i.
 * @li recv_counts: A list, where entry i specifies the number of
 * elements to receive from rank i.
 * @li recv_displacements: A list, where entry i specifies the displacement
 * (offset from recv_data) to which data from rank i should be written.
 * @par Outputs:
 * Two outputs, including:
 * @li recv_data: A tensor whose element type is "dtype".
 * @li gathered: A tensor whose element type is "dtype".
 * @par Attributes:
 * @li group: A required string identifying the group name of ranks participating in
 * the op.
 * @li dtype: The data type of the send buffer elements.
 * @li addr_length: Describes the memory length of the elements in addrinfo.
 * -2: all elements in addrinfo have the same, but unknown, memory length.
 * -1: the memory lengths of the elements are unknown.
 * >0: all elements in addrinfo have the same memory length, given by this attribute value.
 * @attention All ranks participating in the op should be connected in a
 * full-mesh network using RDMA.
 */
REG_OP(HcomGatherAllToAllV)
    .INPUT(addrinfo, TensorType({DT_UINT64}))
    .INPUT(addrinfo_count_per_rank, TensorType({DT_INT64}))
    .INPUT(recv_counts, TensorType({DT_INT64}))
    .INPUT(recv_displacements, TensorType({DT_INT64}))
    .OUTPUT(recv_data, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                   DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .OUTPUT(gathered, TensorType({DT_FLOAT, DT_INT32, DT_INT8, DT_INT16, DT_FLOAT16, DT_INT64, DT_UINT64,
                                  DT_UINT8, DT_UINT16, DT_UINT32, DT_FLOAT64}))
    .REQUIRED_ATTR(group, String)
    .REQUIRED_ATTR(dtype, Type)
    .REQUIRED_ATTR(addr_length, Int)
    .OP_END_FACTORY_REG(HcomGatherAllToAllV)
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_INC_HCOM_OPS_H_

The Graph Engine (GE) is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware, bridging the two. GE takes the graph handed down by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.