You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

attr_value_serializable.h 10 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_
  17. #define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

#include "graph/ge_attr_value.h"
  21. namespace ge {
  22. class GeAttrValue;
  23. class _GeSerializable {
  24. public:
  25. template <typename T>
  26. struct ge_serializable_int64_t_support_type {
  27. using DT = typename std::remove_cv<T>::type;
  28. static const bool value = std::is_same<DT, uint64_t>::value // by cast
  29. || std::is_same<DT, int32_t>::value || std::is_same<DT, uint32_t>::value ||
  30. std::is_same<DT, int16_t>::value || std::is_same<DT, uint16_t>::value ||
  31. std::is_same<DT, int8_t>::value || std::is_same<DT, uint8_t>::value;
  32. };
  33. template <typename T, typename T::__ge_serializable = 0>
  34. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  35. return GeAttrValue::CreateFrom(t);
  36. }
  37. template <typename T, typename T::__ge_serializable = 0>
  38. static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
  39. return GeAttrValue::CreateFrom(t);
  40. }
  41. template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
  42. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  43. return GeAttrValue::CreateFrom<DT>(t);
  44. }
  45. // int64_t support type
  46. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  47. static GeAttrValue SaveItemAsAttrValue(const T &t) {
  48. return GeAttrValue::CreateFrom<GeAttrValue::INT>(t);
  49. }
  50. // vector int64_t support type
  51. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  52. static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) {
  53. return GeAttrValue::CreateFrom<GeAttrValue::LIST_INT>(t);
  54. }
  55. template <typename T, typename T::__ge_serializable = 0>
  56. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  57. return attrVal.GetValue(t);
  58. }
  59. template <typename T, typename T::__ge_serializable = 0>
  60. static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
  61. return attrVal.GetValue(t);
  62. }
  63. template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type>
  64. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  65. return attrVal.GetValue<DT>(t);
  66. }
  67. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  68. static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) {
  69. return attrVal.GetValue<GeAttrValue::INT>(t);
  70. }
  71. template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0>
  72. static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) {
  73. return attrVal.GetValue<GeAttrValue::LIST_INT>(t);
  74. }
  75. template <class T, class... Args>
  76. static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) {
  77. GeAttrValue itemVal = SaveItemAsAttrValue(item);
  78. (void)namedAttrs.SetAttr(itemName, itemVal);
  79. SaveItem(namedAttrs, args...);
  80. }
  81. static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) {}
  82. template <class T, class... Args>
  83. static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) {
  84. auto itemVal = namedAttrs.GetItem(itemName);
  85. auto status = LoadItemFromAttrValue(item, itemVal);
  86. if (status != GRAPH_SUCCESS) {
  87. return status;
  88. }
  89. return LoadItem(namedAttrs, args...);
  90. }
  91. static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; }
  92. };
  93. #define _GE_FI(a) #a, a
  94. #define _GE_MAP_FIELDS1(a1) _GE_FI(a1)
  95. #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2)
  96. #define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3)
  97. #define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4)
  98. #define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5)
  99. #define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6)
  100. #define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \
  101. _GE_FI(a1) \
  102. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7)
  103. #define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \
  104. _GE_FI(a1) \
  105. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8)
  106. #define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \
  107. _GE_FI(a1) \
  108. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9)
  109. #define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \
  110. _GE_FI(a1) \
  111. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10)
  112. #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \
  113. _GE_FI(a1) \
  114. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  115. _GE_FI(a11)
  116. #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \
  117. _GE_FI(a1) \
  118. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  119. _GE_FI(a11), _GE_FI(a12)
  120. #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \
  121. _GE_FI(a1) \
  122. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  123. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13)
  124. #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \
  125. _GE_FI(a1) \
  126. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  127. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14)
  128. #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \
  129. _GE_FI(a1) \
  130. , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
  131. _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15)
  132. #define _GE_PRIVATE_ARGS_GLUE(x, y) x y
  133. #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \
  134. ...) \
  135. N
  136. #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args
  137. #define _GE_COUNT_MACRO_VAR_ARGS(...) \
  138. _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
  139. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count
  140. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count)
  141. #define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count)
  142. #define _GE_INVOKE_VAR_MACRO(...) \
  143. _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \
  144. (__VA_ARGS__))
  145. #define GE_SERIALIZABLE(...) \
  146. public: \
  147. friend class ge::GeAttrValue; \
  148. using __ge_serializable = int; \
  149. \
  150. private: \
  151. ge::graphStatus Save(GeAttrValue &ar) const { \
  152. GeAttrValue::NamedAttrs named_attrs; \
  153. _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
  154. return ar.SetValue<GeAttrValue::NamedAttrs>(named_attrs); \
  155. } \
  156. ge::graphStatus Load(const GeAttrValue &ar) { \
  157. GeAttrValue::NamedAttrs named_attrs; \
  158. ge::graphStatus status = ar.GetValue<GeAttrValue::NamedAttrs>(named_attrs); \
  159. if (status != GRAPH_SUCCESS) { \
  160. return status; \
  161. } \
  162. return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \
  163. }
  164. // end NamedAttrs Helper: GE_SERIALIZABLE
  165. } // namespace ge
  166. #endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：

Contributors (1)