You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_attr_value.h 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_GE_ATTR_VALUE_H_
  17. #define INC_GRAPH_GE_ATTR_VALUE_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "graph/buffer.h"
  25. #include "detail/attributes_holder.h"
  26. #include "graph/ge_error_codes.h"
  27. #include "graph/ge_tensor.h"
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. class GeTensor;
  33. using GeTensorPtr = std::shared_ptr<GeTensor>;
  34. using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;
  35. class ComputeGraph;
  36. using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;
  37. using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>;
  38. class GeTensorDesc;
  39. class GeAttrValue;
  40. class GeAttrValueImp;
  41. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder {
  42. public:
  43. NamedAttrs();
  44. virtual ~NamedAttrs() = default;
  45. void SetName(const std::string &name);
  46. string GetName() const;
  47. GeAttrValue GetItem(const string &key) const;
  48. protected:
  49. ProtoAttrMapHelper MutableAttrMap() override;
  50. ConstProtoAttrMapHelper GetAttrMap() const override;
  51. private:
  52. // Create namedAttrs from protobuf obj
  53. NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg);
  54. GeIrProtoHelper<proto::NamedAttrs> named_attrs_;
  55. friend class GeAttrValueImp;
  56. friend class GeAttrValue;
  57. };
  58. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
  59. public:
  60. using INT = int64_t;
  61. using FLOAT = float;
  62. using BOOL = bool;
  63. using STR = std::string;
  64. using TENSOR = GeTensorPtr;
  65. using TENSOR_DESC = GeTensorDesc;
  66. using GRAPH = ComputeGraphPtr;
  67. using BYTES = Buffer;
  68. using NAMED_ATTRS = ge::NamedAttrs;
  69. using DATA_TYPE = ge::DataType;
  70. using LIST_INT = vector<INT>;
  71. using LIST_FLOAT = vector<FLOAT>;
  72. using LIST_BOOL = vector<BOOL>;
  73. using LIST_STR = vector<STR>;
  74. using LIST_TENSOR = vector<TENSOR>;
  75. using LIST_TENSOR_DESC = vector<TENSOR_DESC>;
  76. using LIST_GRAPH = vector<GRAPH>;
  77. using LIST_BYTES = vector<BYTES>;
  78. using LIST_NAMED_ATTRS = vector<NAMED_ATTRS>;
  79. using LIST_LIST_INT = vector<vector<int64_t>>;
  80. using LIST_DATA_TYPE = vector<ge::DataType>;
  81. using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs).
  82. enum ValueType {
  83. VT_NONE = 0,
  84. VT_STRING,
  85. VT_FLOAT,
  86. VT_BOOL,
  87. VT_INT,
  88. VT_TENSOR_DESC,
  89. VT_TENSOR,
  90. VT_BYTES,
  91. VT_GRAPH,
  92. VT_NAMED_ATTRS,
  93. VT_LIST_LIST_INT,
  94. VT_DATA_TYPE,
  95. VT_LIST_BASE = 1000,
  96. VT_LIST_STRING = VT_LIST_BASE + VT_STRING,
  97. VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT,
  98. VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL,
  99. VT_LIST_INT = VT_LIST_BASE + VT_INT,
  100. VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC,
  101. VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR,
  102. VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES,
  103. VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH,
  104. VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS,
  105. VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE,
  106. };
  107. template <class T>
  108. struct IsAttrTypeEnable {
  109. using DT = typename std::remove_cv<T>::type;
  110. static bool const VALUE = std::is_same<INT, DT>::value || std::is_same<FLOAT, DT>::value ||
  111. std::is_same<BOOL, DT>::value || std::is_same<STR, DT>::value ||
  112. std::is_same<GRAPH, DT>::value || std::is_same<TENSOR, DT>::value ||
  113. std::is_same<TENSOR_DESC, DT>::value || std::is_same<BYTES, DT>::value ||
  114. std::is_same<NAMED_ATTRS, DT>::value || std::is_same<DATA_TYPE, DT>::value;
  115. // Not has list type of NamedAttrs
  116. static bool const LIST_VALUE = std::is_same<LIST_INT, DT>::value || std::is_same<LIST_FLOAT, DT>::value ||
  117. std::is_same<LIST_BOOL, DT>::value || std::is_same<LIST_STR, DT>::value ||
  118. std::is_same<LIST_GRAPH, DT>::value || std::is_same<LIST_TENSOR, DT>::value ||
  119. std::is_same<LIST_TENSOR_DESC, DT>::value || std::is_same<LIST_BYTES, DT>::value ||
  120. std::is_same<LIST_NAMED_ATTRS, DT>::value ||
  121. std::is_same<LIST_LIST_INT, DT>::value || std::is_same<LIST_DATA_TYPE, DT>::value;
  122. };
  123. template <typename vector_type>
  124. // To cols
  125. using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type;
  126. template <typename one_type>
  127. using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type;
  128. template <typename val_type>
  129. using enable_if_type_valid_t =
  130. typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;
  131. template <typename seriliable_type>
  132. using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable;
  133. GeAttrValue();
  134. ~GeAttrValue() = default;
  135. // SetValue, Set initializer_list
  136. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  137. graphStatus SetValue(std::initializer_list<DT> &&val) {
  138. T vectorVal;
  139. for (auto &item : val) {
  140. vectorVal.push_back(item);
  141. }
  142. return SetValue(vectorVal);
  143. }
  144. // SetValue, Set vector
  145. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  146. graphStatus SetValue(const std::vector<DT> &val) {
  147. T vectorVal;
  148. for (auto item : val) {
  149. vectorVal.push_back(item);
  150. }
  151. return SetValue(vectorVal);
  152. }
  153. // SetValue, not list type
  154. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0>
  155. graphStatus SetValue(DT &&val) {
  156. return SetValue(T(std::forward<DT>(val)));
  157. }
  158. // GE_SERIALIZABLE
  159. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  160. graphStatus SetValue(const T &t) {
  161. return t.Save(*this);
  162. }
  163. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  164. graphStatus SetValue(const vector<T> &t) {
  165. vector<NamedAttrs> attrs;
  166. for (auto &item : t) {
  167. GeAttrValue val;
  168. item.Save(val);
  169. NamedAttrs attrsItem;
  170. (void)val.GetValue<NamedAttrs>(attrsItem);
  171. attrs.push_back(attrsItem);
  172. }
  173. return SetValue(attrs);
  174. }
  175. // GetValue, list value
  176. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0,
  177. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  178. graphStatus GetValue(std::vector<DT> &val) const {
  179. T valGet;
  180. val.clear();
  181. auto status = GetValue(valGet);
  182. if (status != GRAPH_SUCCESS) {
  183. return status;
  184. }
  185. for (auto item : valGet) {
  186. val.push_back(item);
  187. }
  188. return GRAPH_SUCCESS;
  189. }
  190. // GetValue, not list type
  191. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0,
  192. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  193. graphStatus GetValue(DT &val) const {
  194. T valGet;
  195. auto status = GetValue(valGet);
  196. if (status != GRAPH_SUCCESS) {
  197. return status;
  198. }
  199. val = DT(valGet);
  200. return GRAPH_SUCCESS;
  201. }
  202. // GE_SERIALIZABLE
  203. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  204. graphStatus GetValue(T &t) {
  205. return t.Load(*this);
  206. }
  207. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  208. graphStatus GetValue(vector<T> &t) {
  209. graphStatus status;
  210. t.clear();
  211. vector<NamedAttrs> attrs;
  212. status = this->GetValue(attrs);
  213. if (status != GRAPH_SUCCESS) {
  214. return status;
  215. }
  216. for (auto &attr : attrs) {
  217. T item;
  218. GeAttrValue val;
  219. (void)val.SetValue(attr);
  220. status = item.Load(val);
  221. if (status != GRAPH_SUCCESS) {
  222. return status;
  223. }
  224. t.push_back(item);
  225. }
  226. return GRAPH_SUCCESS;
  227. }
  228. template <typename T, typename DT, enable_if_type_valid_t<T> = 0>
  229. static GeAttrValue CreateFrom(DT &&val) {
  230. GeAttrValue valRet;
  231. (void)valRet.SetValue<T>(std::forward<DT>(val));
  232. return valRet;
  233. }
  234. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  235. static GeAttrValue CreateFrom(std::initializer_list<DT> &&val) {
  236. GeAttrValue valRet;
  237. (void)valRet.SetValue<T>(std::move(val));
  238. return valRet;
  239. }
  240. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  241. static GeAttrValue CreateFrom(const T &val) {
  242. GeAttrValue valRet;
  243. (void)valRet.SetValue(val);
  244. return valRet;
  245. }
  246. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  247. static GeAttrValue CreateFrom(const vector<T> &val) {
  248. GeAttrValue valRet;
  249. (void)valRet.SetValue(val);
  250. return valRet;
  251. }
  252. ValueType GetValueType() const;
  253. bool IsEmpty() const;
  254. GeAttrValue Copy() const;
  255. // For map key
  256. bool operator==(const GeAttrValue &other) const { return value_ == other.value_; }
  257. graphStatus MutableTensor(GeTensorPtr &tensor);
  258. graphStatus MutableListTensor(vector<GeTensorPtr> &list_tensor);
  259. private:
  260. #define VALUE_SET_GET_DEC(DT) \
  261. graphStatus SetValue(const DT &val); \
  262. graphStatus GetValue(DT &val) const;
  263. VALUE_SET_GET_DEC(GeAttrValue::STR)
  264. VALUE_SET_GET_DEC(GeAttrValue::INT)
  265. VALUE_SET_GET_DEC(GeAttrValue::FLOAT)
  266. VALUE_SET_GET_DEC(GeAttrValue::BOOL)
  267. VALUE_SET_GET_DEC(GeTensorDesc)
  268. VALUE_SET_GET_DEC(GeAttrValue::TENSOR)
  269. VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
  270. VALUE_SET_GET_DEC(BYTES)
  271. VALUE_SET_GET_DEC(NamedAttrs)
  272. VALUE_SET_GET_DEC(ge::DataType)
  273. VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
  274. VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
  275. VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
  276. VALUE_SET_GET_DEC(vector<GeAttrValue::BOOL>)
  277. VALUE_SET_GET_DEC(vector<GeTensorDesc>)
  278. VALUE_SET_GET_DEC(vector<GeAttrValue::TENSOR>)
  279. VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
  280. VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
  281. VALUE_SET_GET_DEC(vector<NamedAttrs>)
  282. VALUE_SET_GET_DEC(vector<vector<int64_t>>)
  283. VALUE_SET_GET_DEC(vector<ge::DataType>)
  284. #undef VALUE_SET_GET_DEC
  285. GeIrProtoHelper<proto::AttrDef> value_;
  286. GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val);
  287. friend class AttrHolder;
  288. friend class ModelSerializeImp;
  289. friend class OnnxUtils;
  290. };
  291. class AttrValueImpl {
  292. public:
  293. AttrValueImpl() = default;
  294. ~AttrValueImpl() = default;
  295. GeAttrValue geAttrValue_;
  296. };
  297. } // namespace ge
  298. #endif // INC_GRAPH_GE_ATTR_VALUE_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：