You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

attr_value_serializable.h 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_
  17. #define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_
  18. #include <string>
  19. #include <vector>
  20. #include "graph/ge_attr_value.h"
  21. #include "graph/compiler_options.h"
  22. namespace ge {
  23. class GeAttrValue;
  24. class _GeSerializable {
  25. public:
  26. template <typename T>
  27. struct ge_serializable_int64_t_support_type {
  28. using DT = typename std::remove_cv<T>::type;
  29. static const bool value = std::is_same<DT, uint64_t>::value // by cast
  30. || std::is_same<DT, int32_t>::value || std::is_same<DT, uint32_t>::value ||
  31. std::is_same<DT, int16_t>::value || std::is_same<DT, uint16_t>::value ||
  32. std::is_same<DT, int8_t>::value || std::is_same<DT, uint8_t>::value;
  33. };
  34. template <typename T, typename T::__ge_serializable = 0>
  35. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  36. return GeAttrValue::CreateFrom(t);
  37. }
  38. template <typename T, typename T::__ge_serializable = 0>
  39. static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
  40. return GeAttrValue::CreateFrom(t);
  41. }
  42. template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
  43. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  44. return GeAttrValue::CreateFrom<DT>(t);
  45. }
  46. // int64_t support type
  47. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  48. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  49. return GeAttrValue::CreateFrom<GeAttrValue::INT>(t);
  50. }
  51. // vector int64_t support type
  52. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  53. static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
  54. return GeAttrValue::CreateFrom<GeAttrValue::LIST_INT>(t);
  55. }
  56. template <typename T, typename T::__ge_serializable = 0>
  57. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  58. return attrVal.GetValue(t);
  59. }
  60. template <typename T, typename T::__ge_serializable = 0>
  61. static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
  62. return attrVal.GetValue(t);
  63. }
  64. template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
  65. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  66. return attrVal.GetValue<DT>(t);
  67. }
  68. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  69. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  70. return attrVal.GetValue<GeAttrValue::INT>(t);
  71. }
  72. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  73. static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
  74. return attrVal.GetValue<GeAttrValue::LIST_INT>(t);
  75. }
  76. template <class T, class... Args>
  77. static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) {
  78. GeAttrValue itemVal = SaveItemAsAttrValue(item);
  79. (void)namedAttrs.SetAttr(itemName, itemVal);
  80. SaveItem(namedAttrs, args...);
  81. }
  82. static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs METADEF_ATTRIBUTE_UNUSED) {}
  83. template <class T, class... Args>
  84. static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) {
  85. auto itemVal = namedAttrs.GetItem(itemName);
  86. auto status = LoadItemFromAttrValue(item, itemVal);
  87. if (status != GRAPH_SUCCESS) {
  88. return status;
  89. }
  90. return LoadItem(namedAttrs, args...);
  91. }
  92. static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs METADEF_ATTRIBUTE_UNUSED) { return GRAPH_SUCCESS; }
  93. };
  94. #define _GE_FI(a) #a, a
  95. #define _GE_MAP_FIELDS1(a1) _GE_FI(a1)
  96. #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2)
  97. #define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3)
  98. #define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4)
  99. #define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5)
  100. #define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6)
  101. #define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \
  102. _GE_FI(a1) \
  103. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7)
  104. #define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \
  105. _GE_FI(a1) \
  106. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8)
  107. #define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \
  108. _GE_FI(a1) \
  109. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9)
  110. #define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \
  111. _GE_FI(a1) \
  112. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10)
  113. #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \
  114. _GE_FI(a1) \
  115. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  116. _GE_FI(a11)
  117. #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \
  118. _GE_FI(a1) \
  119. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  120. _GE_FI(a11), _GE_FI(a12)
  121. #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \
  122. _GE_FI(a1) \
  123. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  124. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13)
  125. #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \
  126. _GE_FI(a1) \
  127. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  128. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14)
  129. #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \
  130. _GE_FI(a1) \
  131. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  132. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15)
  133. #define _GE_PRIVATE_ARGS_GLUE(x, y) x y
  134. #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \
  135. ...) \
  136. N
  137. #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args
  138. #define _GE_COUNT_MACRO_VAR_ARGS(...) \
  139. _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
  140. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count
  141. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count)
  142. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count)
  143. #define _GE_INVOKE_VAR_MACRO(...) \
  144. _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \
  145. (__VA_ARGS__))
  146. #define GE_SERIALIZABLE(...) \
  147. public: \
  148. friend class ge::GeAttrValue; \
  149. using __ge_serializable = int; \
  150. \
  151. private: \
  152. ge::graphStatus Save(GeAttrValue &ar) const { \
  153. GeAttrValue::NAMED_ATTRS named_attrs; \
  154. _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
  155. return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \
  156. } \
  157. ge::graphStatus Load(const GeAttrValue &ar) { \
  158. GeAttrValue::NAMED_ATTRS named_attrs; \
  159. ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \
  160. if (status != GRAPH_SUCCESS) { \
  161. return status; \
  162. } \
  163. return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
  164. }
  165. // end NamedAttrs Helper: GE_SERIALIZABLE
  166. } // namespace ge
  167. #endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。