You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

attr_value_util.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "framework/common/op/attr_value_util.h"
  17. #include "framework/common/debug/log.h"
  18. #include "framework/common/util.h"
  19. #include "external/register/register_types.h"
  20. namespace ge {
  // Emits one SetAttrDef(ARG_TYPE, AttrDef *) overload.
  //   ARG_TYPE - C++ type of the value accepted by the overload.
  //   FIELD    - suffix of the AttrDef setter to call, i.e. out->set_<FIELD>(value).
  // `out` is passed through GE_CHECK_NOTNULL_JUST_RETURN (defined elsewhere; presumably
  // an early return on null) before the setter is invoked.
  #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)                        \
    FMK_FUNC_DEV_VISIBILITY void SetAttrDef(ARG_TYPE value, AttrDef *out) { \
      GE_CHECK_NOTNULL_JUST_RETURN(out);                                    \
      (*out).set_##FIELD(value);                                            \
    }
  // Emits one SetAttrList(ARG_TYPE, AttrDef *) overload that appends `value` to the
  // list stored in `out`.
  //   ARG_TYPE - C++ type of the appended element.
  //   FIELD    - suffix of the list adder to call, i.e. out->mutable_list()->add_<FIELD>(value).
  // Both `out` and out->mutable_list() go through GE_CHECK_NOTNULL_JUST_RETURN
  // (defined elsewhere; presumably an early return on null) before the append.
  #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                        \
    FMK_FUNC_DEV_VISIBILITY void SetAttrList(ARG_TYPE value, AttrDef *out) { \
      GE_CHECK_NOTNULL_JUST_RETURN(out);                                     \
      GE_CHECK_NOTNULL_JUST_RETURN(out->mutable_list());                     \
      auto *list_value = out->mutable_list();                                \
      list_value->add_##FIELD(value);                                        \
    }
  32. DEFINE_SET_ATTR_VALUE_ONE(const std::string &, s);
  33. DEFINE_SET_ATTR_VALUE_ONE(const char *, s);
  34. DEFINE_SET_ATTR_VALUE_ONE(const uint32_t, u);
  35. DEFINE_SET_ATTR_VALUE_ONE(const int32_t, i);
  36. DEFINE_SET_ATTR_VALUE_ONE(const int64_t, i);
  37. DEFINE_SET_ATTR_VALUE_ONE(const float, f);
  38. DEFINE_SET_ATTR_VALUE_ONE(const double, f);
  39. DEFINE_SET_ATTR_VALUE_ONE(const bool, b);
  40. DEFINE_SET_ATTR_VALUE_LIST(float, f);
  41. DEFINE_SET_ATTR_VALUE_LIST(double, f);
  42. DEFINE_SET_ATTR_VALUE_LIST(uint32_t, u);
  43. DEFINE_SET_ATTR_VALUE_LIST(int32_t, i);
  44. DEFINE_SET_ATTR_VALUE_LIST(bool, b);
  45. DEFINE_SET_ATTR_VALUE_LIST(int64_t, i);
  46. DEFINE_SET_ATTR_VALUE_LIST(const std::string &, s);
  // Statement macro: stores VALUE under KEY in the attribute map pointed to by ATTR_MAP.
  //   - KEY already present: the existing AttrDef is updated in place via SetAttrDef.
  //   - KEY absent: a fresh AttrDef is filled via SetAttrDef and inserted as AttrDefPair(KEY, ...).
  // ATTR_MAP is first passed through GE_CHECK_NOTNULL_JUST_RETURN (defined elsewhere).
  // NOTE: the expansion already ends in ';', so call sites in this file invoke the
  // macro without a trailing semicolon.
  #define ADD_TO_ATTR_MAP(KEY, VALUE, ATTR_MAP)           \
    do {                                                  \
      GE_CHECK_NOTNULL_JUST_RETURN(ATTR_MAP);             \
      auto attr_iter = ATTR_MAP->find(KEY);               \
      if (attr_iter == ATTR_MAP->end()) {                 \
        AttrDef new_attr;                                 \
        SetAttrDef(VALUE, &new_attr);                     \
        ATTR_MAP->insert(AttrDefPair(KEY, new_attr));     \
      } else {                                            \
        SetAttrDef(VALUE, &(attr_iter->second));          \
      }                                                   \
    } while (0);
  // Statement macro: appends VALUE to the list attribute stored under KEY in the map
  // pointed to by ATTR_MAP.
  //   - KEY already present: SetAttrList appends to the existing AttrDef.
  //   - KEY absent: a fresh AttrDef receives the element via SetAttrList and is
  //     inserted as AttrDefPair(KEY, ...).
  // ATTR_MAP is first passed through GE_CHECK_NOTNULL_JUST_RETURN (defined elsewhere).
  // NOTE: the expansion already ends in ';', so call sites in this file invoke the
  // macro without a trailing semicolon.
  #define ADD_TO_ATTR_MAP_LIST(KEY, VALUE, ATTR_MAP)      \
    do {                                                  \
      GE_CHECK_NOTNULL_JUST_RETURN(ATTR_MAP);             \
      auto attr_iter = ATTR_MAP->find(KEY);               \
      if (attr_iter == ATTR_MAP->end()) {                 \
        AttrDef new_attr;                                 \
        SetAttrList(VALUE, &new_attr);                    \
        ATTR_MAP->insert(AttrDefPair(KEY, new_attr));     \
      } else {                                            \
        SetAttrList(VALUE, &(attr_iter->second));         \
      }                                                   \
    } while (0);
  // Emits three overloads that set a scalar attribute `value` under `map_key`:
  //   AddOpAttr(key, value, OpDef *)       - targets op_def->mutable_attr()
  //   AddOpAttr(key, value, AttrDefMap *)  - targets the given map directly
  //   AddModelAttr(key, value, ModelDef *) - targets model_def->mutable_attr()
  // All three delegate the insert-or-update logic to ADD_TO_ATTR_MAP, which is
  // invoked without a trailing ';' because its expansion already supplies one.
  #define DEFINE_ADD_ATTR_VALUE(KEY_TYPE, VALUE_TYPE)                                                                  \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value,                \
                                                                    OpDef *op_def) {                                   \
      GE_CHECK_NOTNULL_JUST_RETURN(op_def);                                                                            \
      auto *attr = op_def->mutable_attr();                                                                             \
      ADD_TO_ATTR_MAP(map_key, value, attr)                                                                            \
    }                                                                                                                  \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value,                \
                                                                    AttrDefMap *attr_map) {                            \
      ADD_TO_ATTR_MAP(map_key, value, attr_map)                                                                        \
    }                                                                                                                  \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(KEY_TYPE map_key, VALUE_TYPE value,             \
                                                                       ModelDef *model_def) {                          \
      GE_CHECK_NOTNULL_JUST_RETURN(model_def);                                                                         \
      auto *attr = model_def->mutable_attr();                                                                          \
      ADD_TO_ATTR_MAP(map_key, value, attr)                                                                            \
    }
  // Emits three overloads that append one element `value` to the list attribute
  // stored under `map_key`:
  //   AddOpAttrList(key, value, OpDef *)       - targets op_def->mutable_attr()
  //   AddOpAttrList(key, value, AttrDefMap *)  - targets the given map directly
  //   AddModelAttrList(key, value, ModelDef *) - targets model_def->mutable_attr()
  // All three delegate to ADD_TO_ATTR_MAP_LIST, which is invoked without a trailing
  // ';' because its expansion already supplies one.
  #define DEFINE_ADD_ATTR_VALUE_LIST(KEY_TYPE, VALUE_TYPE)                                                     \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value,    \
                                                                        OpDef *op_def) {                       \
      GE_CHECK_NOTNULL_JUST_RETURN(op_def);                                                                    \
      auto *attr = op_def->mutable_attr();                                                                     \
      ADD_TO_ATTR_MAP_LIST(map_key, value, attr)                                                               \
    }                                                                                                          \
    FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) {     \
      ADD_TO_ATTR_MAP_LIST(map_key, value, attr_map)                                                           \
    }                                                                                                          \
    FMK_FUNC_DEV_VISIBILITY void AddModelAttrList(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) {   \
      GE_CHECK_NOTNULL_JUST_RETURN(model_def);                                                                 \
      auto *attr = model_def->mutable_attr();                                                                  \
      ADD_TO_ATTR_MAP_LIST(map_key, value, attr)                                                               \
    }
  104. DEFINE_ADD_ATTR_VALUE(const std::string &, const std::string &);
  105. DEFINE_ADD_ATTR_VALUE(const char *, const char *);
  106. DEFINE_ADD_ATTR_VALUE(const std::string &, const char *);
  107. DEFINE_ADD_ATTR_VALUE(const std::string &, const uint32_t);
  108. DEFINE_ADD_ATTR_VALUE(const std::string &, const int32_t);
  109. DEFINE_ADD_ATTR_VALUE(const std::string &, const int64_t);
  110. DEFINE_ADD_ATTR_VALUE(const std::string &, const float);
  111. DEFINE_ADD_ATTR_VALUE(const std::string &, const double);
  112. DEFINE_ADD_ATTR_VALUE(const std::string &, const bool);
  113. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const uint32_t);
  114. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const float);
  115. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const double);
  116. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const int32_t);
  117. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const bool);
  118. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const int64_t);
  119. DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const std::string &);
  120. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(const std::string &map_key, AttrDef &attr,
  121. OpDef *op_def) {
  122. GE_CHECK_NOTNULL_JUST_RETURN(op_def);
  123. GE_CHECK_NOTNULL_JUST_RETURN(op_def->mutable_attr());
  124. (void)op_def->mutable_attr()->insert(AttrDefPair(map_key, attr));
  125. }
  // Emits one GetAttrDefValue(key, value, const AttrDefMap &) overload that reads a
  // scalar field from the attribute stored under `map_key`.
  //   ARG_TYPE_KEY   - key type used for the map lookup.
  //   ARG_TYPE_VALUE - pointer type of the output parameter; it is dereferenced.
  //   FIELD          - AttrDef accessor whose result is copied into *value.
  // Returns true and writes *value when the key exists; returns false when the key
  // is missing. `value` is validated with GE_RT_FALSE_CHECK_NOTNULL before it is
  // dereferenced, matching the check GetBytesValue performs on its output pointer.
  #define DEFINE_GET_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD)                                                   \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value,  \
                                                                          const AttrDefMap &attr) {                    \
      GE_RT_FALSE_CHECK_NOTNULL(value);                                                                                \
      auto it = attr.find(map_key);                                                                                    \
      if (it == attr.end()) {                                                                                          \
        return false;                                                                                                  \
      }                                                                                                                \
      *value = it->second.FIELD();                                                                                     \
      return true;                                                                                                     \
    }
  // Emits one GetAttrDefValue(key, T *&value, AttrDefMap *) overload that hands out a
  // mutable pointer to a sub-message of the attribute stored under `map_key`.
  //   FIELD - suffix of the accessor used, i.e. it->second.mutable_<FIELD>().
  // Returns true and sets `value` when the key exists; returns false and leaves
  // `value` untouched when the key is missing. `attr` is validated with
  // GE_RT_FALSE_CHECK_NOTNULL (defined elsewhere) before the lookup.
  #define DEFINE_GET_ATTR_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD)                                                 \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE *&value,  \
                                                                          AttrDefMap *attr) {                            \
      GE_RT_FALSE_CHECK_NOTNULL(attr);                                                                                   \
      auto found = attr->find(map_key);                                                                                  \
      if (found == attr->end()) {                                                                                        \
        return false;                                                                                                    \
      }                                                                                                                  \
      value = found->second.mutable_##FIELD();                                                                           \
      return true;                                                                                                       \
    }
  // Emits one GetAttrDefValue(key, const T *&value, const AttrDefMap &) overload that
  // hands out a read-only pointer to a sub-message of the attribute stored under
  // `map_key`.
  //   FIELD - AttrDef accessor whose result's address is stored in `value`.
  // Returns true and sets `value` when the key exists; otherwise returns false and
  // leaves `value` untouched. The pointer refers into `attr`, so it is only usable
  // while that map entry stays alive.
  #define DEFINE_GET_ATTR_CONST_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD)            \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(                \
        ARG_TYPE_KEY map_key, const ARG_TYPE_VALUE *&value, const AttrDefMap &attr) {     \
      auto found = attr.find(map_key);                                                    \
      if (found != attr.end()) {                                                          \
        value = &(found->second.FIELD());                                                 \
        return true;                                                                      \
      }                                                                                   \
      return false;                                                                       \
    }
  // Emits one GetBytesValue(key, value, const AttrDefMap &) overload that copies the
  // `bt` field of the attribute stored under `key` into *value.
  // Returns true on success and false when the key is missing. `value` is validated
  // with GE_RT_FALSE_CHECK_NOTNULL (defined elsewhere) before it is dereferenced.
  #define DEFINE_GET_BYTES_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE)                       \
    bool GetBytesValue(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, const AttrDefMap &attr) {  \
      GE_RT_FALSE_CHECK_NOTNULL(value);                                                   \
      auto found = attr.find(key);                                                        \
      if (found == attr.end()) {                                                          \
        return false;                                                                     \
      }                                                                                   \
      *value = found->second.bt();                                                        \
      return true;                                                                        \
    }
  // Emits one GetAttrDefListValue(key, idx, value, const AttrDefMap &) overload that
  // reads element `idx` of a repeated field from the list attribute stored under
  // `map_key`.
  //   ARG_TYPE_VALUE - pointer type of the output parameter; it is dereferenced.
  //   FIELD          - repeated field of the list: <FIELD>_size() and <FIELD>(idx) are used.
  // Returns false when the key is missing or `idx` is outside [0, <FIELD>_size());
  // otherwise writes the element to *value and returns true. `value` is validated
  // with GE_RT_FALSE_CHECK_NOTNULL before it is dereferenced, matching the check
  // GetBytesValue performs on its output pointer.
  #define DEFINE_GET_ATTR_LIST_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD)                                      \
    FMK_FUNC_DEV_VISIBILITY bool GetAttrDefListValue(ARG_TYPE_KEY map_key, int idx, ARG_TYPE_VALUE value,      \
                                                     const AttrDefMap &attr) {                                 \
      GE_RT_FALSE_CHECK_NOTNULL(value);                                                                        \
      auto it = attr.find(map_key);                                                                            \
      if (it == attr.end()) {                                                                                  \
        return false;                                                                                          \
      }                                                                                                        \
                                                                                                               \
      const auto &list = it->second.list();                                                                    \
      if (idx < 0 || idx >= list.FIELD##_size()) {                                                             \
        return false;                                                                                          \
      }                                                                                                        \
                                                                                                               \
      *value = list.FIELD(idx);                                                                                \
      return true;                                                                                             \
    }
  184. DEFINE_GET_ATTR_VALUE(const std::string &, std::string *, s);
  185. DEFINE_GET_ATTR_VALUE(const std::string &, int32_t *, i);
  186. DEFINE_GET_ATTR_VALUE(const std::string &, int64_t *, i);
  187. DEFINE_GET_ATTR_VALUE(const std::string &, uint32_t *, u);
  188. DEFINE_GET_ATTR_VALUE(const std::string &, float *, f);
  189. DEFINE_GET_ATTR_VALUE(const std::string &, double *, f);
  190. DEFINE_GET_ATTR_VALUE(const std::string &, bool *, b);
  191. DEFINE_GET_ATTR_VALUE(const std::string &, AttrDef_ListValue *, list);
  192. DEFINE_GET_ATTR_LIST_VALUE(const std::string &, int32_t *, i);
  193. DEFINE_GET_ATTR_LIST_VALUE(const std::string &, uint32_t *, u);
  194. DEFINE_GET_ATTR_LIST_VALUE(const std::string &, float *, f);
  195. DEFINE_GET_ATTR_LIST_VALUE(const std::string &, double *, f);
  196. DEFINE_GET_ATTR_POINT_REF(const std::string &, NamedAttrs, func);
  197. DEFINE_GET_ATTR_CONST_POINT_REF(const std::string &, NamedAttrs, func);
  198. DEFINE_GET_BYTES_ATTR_VALUE(const std::string &, std::string *);
  // Emits two thin wrappers around GetAttrDefValue:
  //   GetOpAttr(key, value, const OpDef *)       - looks the key up in op_def->attr()
  //   GetModelAttr(key, value, const ModelDef *) - looks the key up in model_def->attr()
  // Each wrapper validates its definition pointer with GE_RT_FALSE_CHECK_NOTNULL
  // (defined elsewhere) and otherwise returns whatever GetAttrDefValue returns.
  #define DEFINE_GET_OP_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE)                                                       \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetOpAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value,  \
                                                                    const OpDef *op_def) {                       \
      GE_RT_FALSE_CHECK_NOTNULL(op_def);                                                                         \
      const auto &op_attrs = op_def->attr();                                                                     \
      return GetAttrDefValue(map_key, value, op_attrs);                                                          \
    }                                                                                                            \
    FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetModelAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \
                                                                       const ModelDef *model_def) {              \
      GE_RT_FALSE_CHECK_NOTNULL(model_def);                                                                      \
      const auto &model_attrs = model_def->attr();                                                               \
      return GetAttrDefValue(map_key, value, model_attrs);                                                       \
    }
  210. DEFINE_GET_OP_ATTR(const std::string &, std::string *);
  211. DEFINE_GET_OP_ATTR(const std::string &, int32_t *);
  212. DEFINE_GET_OP_ATTR(const std::string &, int64_t *);
  213. DEFINE_GET_OP_ATTR(const std::string &, uint32_t *);
  214. DEFINE_GET_OP_ATTR(const std::string &, float *);
  215. DEFINE_GET_OP_ATTR(const std::string &, double *);
  216. DEFINE_GET_OP_ATTR(const std::string &, bool *);
  217. DEFINE_GET_OP_ATTR(const std::string &, AttrDef_ListValue *);
  218. #define DEFINE_GET_BT_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \
  219. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, \
  220. const OpDef *op_def) { \
  221. GE_RT_FALSE_CHECK_NOTNULL(op_def); \
  222. return GetBytesValue(key, value, op_def->attr()); \
  223. } \
  224. FMK_FUNC_DEV_VISIBILITY bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, const ModelDef *model_def) { \
  225. GE_RT_FALSE_CHECK_NOTNULL(model_def); \
  226. return GetBytesValue(key, value, model_def->attr()); \
  227. }
  228. DEFINE_GET_BT_ATTR(const std::string &, std::string *);
  229. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool HasOpAttr(const OpDef *op_def, const std::string &attr_name) {
  230. if (op_def == nullptr) {
  231. return false;
  232. }
  233. const AttrDefMap &attr = op_def->attr();
  234. const AttrDefMap::const_iterator it = attr.find(attr_name);
  235. if (it != attr.end()) {
  236. return true;
  237. }
  238. return false;
  239. }
  240. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(const std::string &map_key, const void *value,
  241. size_t size, ModelDef *model_def) {
  242. if (model_def == nullptr) {
  243. return;
  244. }
  245. AttrDef out;
  246. auto attr = model_def->mutable_attr();
  247. auto it = attr->find(map_key);
  248. if (it != attr->end()) {
  249. auto &attr_value = it->second;
  250. attr_value.set_bt(value, size);
  251. } else {
  252. out.set_bt(value, size);
  253. attr->insert(AttrDefPair(map_key, out));
  254. }
  255. }
  256. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpBytesAttr(const std::string &key, const void *value,
  257. size_t size, OpDef *op_def) {
  258. if (op_def == nullptr) {
  259. return;
  260. }
  261. AttrDef out;
  262. auto attr = op_def->mutable_attr();
  263. auto it = attr->find(key);
  264. if (it != attr->end()) {
  265. auto &attr_value = it->second;
  266. attr_value.set_bt(value, size);
  267. } else {
  268. out.set_bt(value, size);
  269. attr->insert(AttrDefPair(key, out));
  270. }
  271. }
  // Emits one GetOpAttrListSize(key, value, const OpDef *) overload that reports how
  // many elements the repeated field <FIELD> holds in the list attribute stored
  // under `key`.
  //   ARG_TYPE_VALUE - type of the `value` parameter; the body never reads it, so it
  //                    only distinguishes the generated overloads from one another.
  //   FIELD          - repeated field whose <FIELD>_size() is returned.
  // Returns 0 when the key is missing. A null op_def is handled by
  // GE_CHK_BOOL_RET_STATUS_NOLOG(op_def != nullptr, 0) (defined elsewhere).
  #define DEFINE_GET_ATTR_LIST_SIZE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD)                                               \
    FMK_FUNC_DEV_VISIBILITY uint32_t GetOpAttrListSize(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, const OpDef *op_def) {  \
      GE_CHK_BOOL_RET_STATUS_NOLOG(op_def != nullptr, 0);                                                              \
      const AttrDefMap &op_attrs = op_def->attr();                                                                     \
      auto found = op_attrs.find(key);                                                                                 \
      if (found != op_attrs.end()) {                                                                                   \
        return found->second.list().FIELD##_size();                                                                    \
      }                                                                                                                \
      return 0;                                                                                                        \
    }
  283. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, const std::string &, s);
  284. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, int32_t, i);
  285. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, int64_t, i);
  286. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, uint32_t, u);
  287. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, float, f);
  288. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, double, f);
  289. DEFINE_GET_ATTR_LIST_SIZE(const std::string &, bool, b);
  290. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。