You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

op_parser_util.h 15 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_
  17. #define INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_
  18. #include <cce/dnn.h>
  19. #include <limits.h>
  20. #include <math.h>
  21. #include <stdint.h>
  22. namespace domi {
  23. // General
  24. const float DEFAULT_ALPHA_VALUE = 1.0;
  25. const float DEFAULT_BETA_VALUE = 0.0;
  26. const uint32_t NORMAL_INPUT_NUM = 1;
  27. const uint32_t NORMAL_OUTPUT_NUM = 1;
  28. const uint32_t NORMAL_WORKSPACE_NUM = 0;
  29. const int32_t NORMAL_1D_DIM_NUM = 1;
  30. const int32_t NORMAL_SCALE_DIM_NUM = 0;
  31. const int NORMAL_TENSOR_FORMAT = static_cast<const int>(cce::CC_TENSOR_NC1HWC0);
  32. const int NORMAL_TENSOR_SIZE = 4;
  33. const int NORMAL_DEVICE_DATA_TYPE = static_cast<const int>(cce::CC_DATA_HALF);
  34. const int DEFAULT_POOLING_MODE = static_cast<const int>(cce::CC_POOLING_MAX);
  35. const uint32_t DEFAULT_REAL_DIM_CNT = 4;
  36. // Const
  37. const uint32_t CONST_OP_INPUT_NUM = 0;
  38. const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1;
  39. // MatMul
  40. const uint32_t MATMUL_INPUT_NUM = 2;
  41. // ActivationGrad
  42. const int32_t ACTIVATIONGRAD_INPUT_NUM = 2;
  43. // FusedBatchNorm
  44. const int32_t FUSED_BATCH_NORM_WORKSPACE_NUM = 1;
  45. const int32_t FUSED_BATCH_NORM_INPUT_NUM = 5;
  46. const int32_t FUSED_BATCH_NORM_OUTPUT_NUM = 5;
  47. // FusedBatchNormGrad
  48. const int32_t FUSEDBATCHNORMGRAD_WORKSPACE_NUM = 1;
  49. const int32_t FUSEDBATCHNORMGRAD_INPUT_NUM = 5;
  50. const int32_t FUSEDBATCHNORMGRAD_OUTPUT_NUM = 3;
  51. // Conv
  52. const uint32_t CONVOLUTION_WORKSPACE_NUM = 1;
  53. const uint32_t CONVOLUTION_PAD_SIZE = 4;
  54. const uint32_t CONVOLUTION_STRIDE_SIZE = 2;
  55. const uint32_t CONVOLUTION_DILATION_SIZE = 2;
  56. const int32_t CONVOLUTION_ADJ_SIZE = 2;
  57. const int32_t CONVOLUTION_TARGET_SHAPE_SIZE = 2;
  58. // ConvGradFilter
  59. const uint32_t CONVGRADFILTER_WORKSPACE_NUM = 1;
  60. const uint32_t CONVGRADFILTER_INPUT_NUM = 3;
  61. // Pooling
  62. const uint32_t POOLING_WINDOW_SIZE = 2;
  63. const uint32_t POOLING_STRIDE_SIZE = 2;
  64. const uint32_t POOLING_PAD_SIZE = 4;
  65. // Add Sub Mul
  66. const uint32_t ADD_INPUT_NUM = 2;
  67. const uint32_t SUB_INPUT_NUM = 2;
  68. const uint32_t MUL_INPUT_NUM = 2;
  69. const uint32_t DIV_INPUT_NUM = 2;
  70. const uint32_t ADD_WORKSPACE_NUM = 1;
  71. const uint32_t SUB_WORKSPACE_NUM = 1;
  72. const uint32_t MUL_WORKSPACE_NUM = 1;
  73. const uint32_t DIV_WORKSPACE_NUM = 1;
  74. const int32_t DEFAULT_AXIS_VALUE = -1;
  75. const int32_t RESHAPE_AXIS_DEFAULT_VALUE = 0;
  76. const int32_t RESHAPE_NUM_AXES_DEFAULT_VALUE = -1;
  77. const uint32_t RESHAPE_WORKSPACE_NUM = 1;
  78. const uint32_t FLATTEN_WORKSPACE_NUM = 1;
  79. const int32_t CONCAT_MIN_INPUT_SIZE = 1;
  80. const int32_t CONCAT_DEFAULT_AXIS = 1;
  81. const uint32_t CONCAT_WORKSPACE_NUM = 1;
  82. // The value for LRN parameters
  83. const uint32_t LRN_DEFAULT_NORM_REGION = 0;
  84. const float LRN_DEFAULT_K = 1.0;
  85. const uint32_t LRN_DEFAULT_LOCAL_SIZE = 5;
  86. const float LRN_DEFAULT_ALPHA = 1.0;
  87. const float LRN_DEFAULT_BETA = 0.75;
  88. ///
  89. /// @ingroup domi_common
  90. /// @brief default value of roipooling
  91. ///
  92. const uint32_t ROIPOOLING_DEFAULT_POOLED_H = 0;
  93. const uint32_t ROIPOOLING_DEFAULT_POOLED_W = 0;
  94. const float ROIPOOLING_DEFAULT_SPATIAL_SCALE = 1;
  95. const int32_t ROIPOOLING_DEFAULT_SAMPLING_RATIO = -1;
  96. // DetectionOutput
  97. const int32_t DETECTIONOUTPUT_INPUT_SIZE = 3;
  98. const int32_t DETECTIONOUTPUT_OUTPUT_SIZE = 2;
  99. const int32_t DETECTIONOUTPUT_WORKSPACE_NUM = 1;
  100. const int DETECTIONOUTPUT_CLASS_NUM = 20;
  101. const int DETECTIONOUTPUT_NUM_CLASSES_DEFAULT_VALUE = 21;
  102. const float DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3;
  103. const float DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.8;
  104. // Proposal
  105. const int32_t PROPOSAL_INPUT_SIZE = 3;
  106. const int32_t PROPOSAL_OUTPUT_MAX_SIZE = 2;
  107. const int32_t PROPOSAL_WORKSPACE_NUM = 1;
  108. const float PROPOSAL_BASE_SIZE_DEFAULT_VALUE = 16;
  109. const float PROPOSAL_RATIO_DIM_0_DEFAULT_VALUE = 0.5;
  110. const float PROPOSAL_RATIO_DIM_1_DEFAULT_VALUE = 1;
  111. const float PROPOSAL_RATIO_DIM_2_DEFAULT_VALUE = 2;
  112. const float PROPOSAL_SCALE_DIM_0_DEFAULT_VALUE = 8;
  113. const float PROPOSAL_SCALE_DIM_1_DEFAULT_VALUE = 16;
  114. const float PROPOSAL_SCALE_DIM_2_DEFAULT_VALUE = 32;
  115. const float PROPOSAL_MIN_SIZE_DEFAULT_VALUE = 16;
  116. const int PROPOSAL_PRE_NMS_TOPN_DEFAULT_VALUE = 6000;
  117. const int PROPOSAL_POST_NMS_TOPN_DEFAULT_VALUE = 304;
  118. const float PROPOSAL_NMS_THRESH_DEFAULT_VALUE = 0.7;
  119. const float PROPOSAL_FILTER_THRESH_DEFAULT_VALUE = 0;
  120. // TVM OP
  121. const uint32_t DEFAULT_KERNEL_BLOCK_DIM = 1;
  122. // Softmax
  123. const int32_t SOFTMAX_WORKSPACE_NUM = 1;
  124. // SoftmaxCrossEntropy
  125. const int32_t SOFTMAXCROSSENTROPY_INPUT_NUM = 2;
  126. const int32_t SOFTMAXCROSSENTROPY_OUTPUT_NUM = 2;
  127. // Permute
  128. const int32_t PERMUTE_INPUT_NUM = 1;
  129. const int32_t PERMUTE_OUTPUT_NUM = 1;
  130. const int32_t PERMUTE_WORKSPACE_NUM = 1;
  131. const int32_t PERMUTE_ORDER_NUM = 4;
  132. // Ssd normalize
  133. const int SSD_NORMALIZE_INPUT_SIZE = 1;
  134. const float SSD_NORMALIZE_EPS_DEFAULT_VALUE = 2e-7;
  135. // SsdPriroBox
  136. const int32_t SSD_PRIOR_BOX_WORKSPACE_NUM = 1;
  137. const int32_t SSD_PRIOR_BOX_INPUT_NUM = 2;
  138. const bool SSD_PRIOR_BOX_FLIP_VALUE = true;
  139. const bool SSD_PRIOR_BOX_CLIP_VALUE = false;
  140. const double SSD_PRIOR_BOX_ASPECT_OFFSET_VALUE = 0.5;
  141. const double SSD_PRIORBOX_VARIANCE_VALUE = 0.1;
  142. const double SSD_PRIORBOX_VARIANCE_SIZE_ONE = 1;
  143. const double SSD_PRIORBOX_VARIANCE_SIZE_FOUR = 4;
  144. const double SSD_PRIORBOX_ASPECT_RATIO_VALUE = 1.0;
  145. const int SSD_PRIOR_BOX_CODETYPE_CORNER_VALUE = 1;
  146. const int SSD_PRIOR_BOX_CODETYPE_CENTER_SIZE_VALUE = 2;
  147. const int SSD_PRIOR_BOX_CODETYPE_CORNER_SIZE_VALUE = 3;
  148. // Ssd DetectionOutput
  149. const int32_t SSD_DETECTIONOUTPUT_INPUT_SIZE = 3;
  150. const int32_t SSD_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2;
  151. const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE = 2;
  152. const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3;
  153. const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM = 1;
  154. const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM_AFTER_FUSION = 0;
  155. const bool SSD_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true;
  156. const int32_t SSD_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0;
  157. const float SSD_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3;
  158. const int32_t SSD_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200;
  159. const float SSD_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0;
  160. const int SSD_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast<const int>(cce::CC_BOX_CENTER_SIZE);
  161. const int32_t SSD_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200;
  162. const bool SSD_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false;
  163. const float SSD_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1;
  164. // Refinedet DetectionOutput
  165. const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE = 5;
  166. const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2;
  167. const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE = 2;
  168. const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3;
  169. const int32_t REFINEDET_DETECTIONOUTPUT_WORKSPACE_NUM = 1;
  170. const bool REFINEDET_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true;
  171. const int32_t REFINEDET_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0;
  172. const float REFINEDET_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3;
  173. const int32_t REFINEDET_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200;
  174. const float REFINEDET_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0;
  175. const bool REFINEDET_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false;
  176. const int REFINEDET_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast<const int>(cce::CC_BOX_CENTER_SIZE);
  177. const int32_t REFINEDET_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200;
  178. const float REFINEDET_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1;
  179. const float REFINEDET_DETECTIONOUTPUT_OBJECTNESS_SCORE_DEFAULT_VALUE = 0;
  180. // Channel axpy
  181. const int32_t CHANNEL_AXPY_INPUT_NUM = 3;
  182. const int32_t CHANNEL_AXPY_INPUT_DIM_SIZE = 4;
  183. const int32_t CHANNEL_AXPY_WORKSPACE_NUM = 1;
  184. // Psroi pooling
  185. const int PSROI_POOLING_INPUT_COUNT = 2;
  186. const int PSROI_POOLING_WORKSPACE_NUM = 1;
  187. // MaxPoolWithArgmax
  188. const uint32_t MAX_POOL_WITH_ARGMAX_OUTPUT_NUM = 2;
  189. const uint32_t MAX_POOL_GRAD_WITH_ARGMAX_INPUT_NUM = 3;
  190. // AvgPoolGrad
  191. const uint32_t AVG_POOL_GRAD_INPUT_NUM = 2;
  192. // ROIAlign
  193. const int32_t ROIALIGN_INPUT_SIZE = 2;
  194. const int32_t ROIALIGN_WORKSPACE_NUM = 1;
  195. const int32_t ROIALIGN_DEFAULT_POOLED_H = 1;
  196. const int32_t ROIALIGN_DEFAULT_POOLED_W = 1;
  197. // Correlation
  198. const uint32_t CORRELATION_INPUT_NUM = 2;
  199. const int CORRELATION_WORKSPACE_NUM = 1;
  200. // Detectionpostprocess
  201. const int32_t POSTPROCESS_INPUT_SIZE = 4;
  202. const int32_t POSTPROCESS_OUTPUT_SIZE = 2;
  203. const int32_t POSTPROCESS_WORKSPACE_NUM = 1;
  204. const uint32_t POSTPROCESS_CLS_NUM_DEFAULT_VALUE = 12;
  205. const uint32_t POSTPROCESS_POST_NMS_TOPN_DEFAULT_VALUE = 100;
  206. const float POSTPROCESS_NMS_THRESH_DEFAULT_VALUE = 0.3;
  207. const float POSTPROCESS_CONF_THRESH_DEFAULT_VALUE = 0.5;
  208. const float POSTPROCESS_BBOX_REG_WEIGHT_DIM_DEFAULT_VALUE = 1.0;
  209. const int32_t POSTPROCESS_BBOX_REG_WEIGHT_SIZE_DEFAULT_VALUE = 4;
  210. // Split
  211. const int32_t SPLIT_INPUT_NUM = 2;
  212. const int32_t SPLIT_DEFAULT_AXIS_VALUE = 1;
  213. const int32_t SPLIT_MIN_OUTPUT_SIZE = 1;
  214. const uint32_t STRIDEDSLICE_INPUT_NUM = 4;
  215. // Slice
  216. const int32_t SLICE_INPUT_NUM = 3;
  217. const int32_t SLICE_WEIGHT_NUM = 2;
  218. // GatherNd
  219. const int32_t GATHERND_INPUT_NUM = 2;
  220. // ArgMax
  221. const int32_t ARGMAX_INPUT_NUM = 2;
  222. const int32_t ARGMAX_REAL_INPUT_NUM = 1;
  223. // HighWay
  224. const int32_t HIGHWAY_INPUT_NUM = 4;
  225. const int32_t HIGHWAY_WORKSPACE_NUM = 1;
  226. // RealDiv
  227. const int32_t REALDIV_INPUT_NUM = 2;
  228. // Range
  229. const int32_t RANGE_INPUT_NUM = 3;
  230. const int32_t RANGE_OUTPUT_NUM = 1;
  231. const int32_t RANGE_INPUT_DIM_SIZE = 0;
  232. // Pad
  233. const int32_t PAD_WEIGHT_NUM = 1;
  234. const int32_t PAD_DIM_SIZE = 2;
  235. const int32_t PAD_DIM0 = 4;
  236. const int32_t PAD_DIM1 = 2;
  237. const int32_t PAD_WEIGHT_WITH_CONSTANT_NUM = 2;
  238. const int32_t PAD_CONSTATNT_DEFAULT_VALUE = 0;
  239. const int32_t PAD_PADDINGS_SIZE = 8;
  240. // Tile
  241. const int32_t TILE_WEIGHT_NUM = 1;
  242. const int32_t TILE_MULTIPLES_DIM_SIZE = 1;
  243. // DecodeBbox
  244. const int32_t DECODE_BBOX_INPUT_NUM = 2;
  245. // GenerateRpnProposals
  246. const int32_t GENERATE_RPN_PROPOSAL_INPUT_SIZE = 2;
  247. const int32_t GENERATE_RPN_PROPOSAL_OUTPUT_SIZE = 3;
  248. // Decode_BBox
  249. const int32_t DECODE_BBOX_INPUT_SIZE = 2;
  250. const int32_t DEFAULT_DECODE_CLIP_VALUE = 0;
  251. // FastRcnnPredictions
  252. const int32_t FASTRCNN_PREDICTIONS_INPUT_SIZE = 2;
  253. const int32_t FASTRCNN_PREDICTIONS_OUTPUT_SIZE = 4;
  254. const int32_t CLIP_BOXES_INPUT_NUM = 1;
  255. const int32_t CLIP_BOXES_WEIGHT_SIZE = 1;
  256. const int32_t CLIP_BOXES_WEIGHT_ITEM_SIZE = 2;
  257. const int32_t CLIP_BOXES_OUTPUT_NUM = 1;
  258. const int32_t FLOORDIV_INPUT_NUM = 2;
  259. // Mean
  260. const int32_t MEAN_WEIGHT_SIZE = 1;
  261. const int32_t MEAN_WEIGHT_DIM_SIZE = 1;
  262. const int32_t MEAN_WEIGHT_DIM = 2;
  263. const int32_t MEAN_FIRST_AXIS = 2;
  264. const int32_t MEAN_SECOND_AXIS = 3;
  265. const int32_t MEAN_STRIDE_PLACE_HOLD = 1;
  266. // Switch
  267. const uint32_t SWITCH_INPUT_NUM = 2;
  268. const uint32_t SWITCH_OUTPUT_NUM = 2;
  269. // Merge
  270. const uint32_t MERGE_INPUT_NUM = 2;
  271. // Greater
  272. const uint32_t GREATER_OUTPUT_NUM = 1;
  273. const uint32_t GREATER_INPUT_NUM = 0;
  274. const uint32_t GREATER_WEIGHT_NUM = 2;
  275. // Yolo region
  276. const uint32_t YOLO_REGION_OUTPUT_NUM = 3;
  277. const uint32_t YOLO_REGION_WORKSPACE_NUM = 1;
  278. const uint32_t YOLO_REGION_COORDS = 4;
  279. const uint32_t YOLO_REGION_CLASSES = 20;
  280. const uint32_t YOLO_REGION_BOXES = 1;
  281. const bool YOLO_REGION_BACKGROUND = false;
  282. const bool YOLO_REGION_SOFTMAX = false;
  283. const bool YOLO_REGION_SOFTMAX_TREE = false;
  284. // Yolo detectionoutput
  285. const uint32_t YOLO_DETECTIONOUTPUT_INPUT_SIZE = 4;
  286. const uint32_t YOLO_DETECTIONOUTPUT_OUTPUT_SIZE = 2;
  287. const uint32_t YOLO_DETECTION_OUTPUT_WORKSPACE_NUM = 1;
  288. const uint32_t YOLO_DETECTION_OUTPUT_CLASSES = 20;
  289. const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V2 = 5;
  290. const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V3 = 3;
  291. const bool YOLO_DETECTION_OUTPUT_RELATIVE = true;
  292. const float YOLO_DETECTION_OUTPUT_OBJECTNESS_THRESHOLD = 0.5;
  293. const float YOLO_DETECTION_OUTPUT_CLASS_THRESHOLD = 0.5;
  294. const uint32_t YOLO_DETECTION_OUTPUT_POST_TOP_K = UINT_MAX;
  295. const float YOLO_DETECTION_OUTPUT_NMS_THRESHOLD = 0;
  296. const float YOLO_DETECTION_OUTPUT_IOU_THRESHOLD_DECAY = 1.0;
  297. const float YOLO_DETECTION_OUTPUT_COOR_SCALE_FACTOR = 1.0;
  298. // Reorg
  299. const int32_t REORG_DEFAULT_STRIDE = 2;
  300. const uint32_t REORG_INPUT_COUNT = 1;
  301. // Reshape
  302. const int32_t RESHAPE_INPUT_NUM = 2;
  303. // Maximum
  304. const int32_t MAXIMUM_INPUT_NUM = 2;
  305. // Spatialtf
  306. const int32_t SPATIALTF_WORKSPACE_NUM = 1;
  307. const int32_t REVERSE_DEFAULT_AXIS = 1;
  308. // Crop
  309. const int32_t CROP_AXIS = 2;
  310. const int32_t CROP_INPUT_NUM = 2;
  311. // ConvGradInput
  312. const uint32_t CONVGRADINPUT_WORKSPACE_NUM = 1;
  313. const uint32_t CONVGRADINPUT_INPUT_NUM = 3;
  314. // RNN
  315. const uint32_t RNN_WORKSPACE_NUM = 1;
  316. // Cropandresize
  317. const int32_t CROPANDRESIZE_WEIGHT_NUM = 1;
  318. const int32_t CROPANDRESIZE_CROP_DIM_SIZE = 1;
  319. const int32_t CROP_DIM0 = 2;
  320. // Attention decoder weight index
  321. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENW0 = 0;
  322. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_KERNEL = 1;
  323. const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_KERNEL = 2;
  324. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_KERNEL = 3;
  325. const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_KERNEL = 4;
  326. const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_KERNEL = 5;
  327. const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_KERNEL = 6;
  328. const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_KERNEL = 7;
  329. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_BIAS = 8;
  330. const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_BIAS = 9;
  331. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_BIAS = 10;
  332. const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_BIAS = 11;
  333. const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_BIAS = 12;
  334. const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_BIAS = 13;
  335. const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_BIAS = 14;
  336. const uint32_t ATTENTION_DECODER_WEIGHT_EMBEDDING = 15;
  337. const uint32_t ATTENTION_DECODER_WEIGHT_ATTENVA = 16;
  338. const uint32_t ATTENTION_DECODER_WEIGHT_DECODER_INITIAL = 17;
  339. // Attention decoder weight size
  340. const uint32_t ATTENTION_DECODER_WEIGHT_SIZE = 18;
  341. const uint32_t ATTENTION_DECODER_INPUT_SIZE = 2;
  342. const uint32_t ATTENTION_DECODER_WORKSPACE_NUM = 1;
  343. const uint32_t ATTENTION_DECODER_INPUT_DECODER_INPUTS = 0;
  344. const uint32_t ATTENTION_DECODER_INPUT_DECODER_INITIAL_HIDDEN = 1;
  345. const int ATTENTION_DECODER_ALGO_NORMAL = 0;
  346. const int ATTENTION_DECODER_SYMBOLS = 10000;
  347. const int ATTENTION_DECODER_EMBEDDING_SIZE = 128;
  348. const int ATTENTION_DECODER_ATTENTION_NUM_HIDDEN = 256;
  349. const int ATTENTION_DECODER_DECODER_NUM_HIDDEN = 128;
  350. const int ATTENTION_DECODER_DECODER_NUM_LAYERS = 2;
  351. const int ATTENTION_DECODER_RNN_UNBIDIRECTIONAL = 0;
  352. const int ATTENTION_DECODER_SEQLEN_VALUE = 57;
  353. const int ATTENTION_DECODER_GRU = 3;
  354. // Logicaland
  355. const int32_t LOGICAL_AND_INPUT_NUM = 2;
  356. const int32_t EQUAL_INPUT_NUM = 2;
  357. static const int32_t OP_WEIGHT_MEM_BASE_OFFSET = 512;
  358. // MultiShape
  359. const uint32_t MULTI_SHAPE_INPUT_NUM = 2;
  360. // Shufflechannel
  361. const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1;
  362. } // namespace domi
  363. #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。

Contributors (1)