You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_attr_value.h 11 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_GE_ATTR_VALUE_H_
  17. #define INC_GRAPH_GE_ATTR_VALUE_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "graph/buffer.h"
  25. #include "detail/attributes_holder.h"
  26. #include "graph/ge_error_codes.h"
  27. #include "graph/ge_tensor.h"
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. class GeTensor;  // Forward declaration (graph/ge_tensor.h is also included above).
  33. using GeTensorPtr = std::shared_ptr<GeTensor>;  // Shared pointer to a mutable GeTensor.
  34. using ConstGeTensorPtr = std::shared_ptr<const GeTensor>;  // Shared pointer to a const GeTensor.
  35. class ComputeGraph;  // Forward declaration; this header only refers to it through the aliases below.
  36. using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;  // Shared pointer to a mutable ComputeGraph.
  37. using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>;  // Shared pointer to a const ComputeGraph.
  38. class GeTensorDesc;  // Forward declaration; aliased as GeAttrValue::TENSOR_DESC below.
  39. class GeAttrValueImp;  // Forward declaration; declared a friend of GeAttrValue::NamedAttrs below.
  40. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
  41. public:
  42. class NamedAttrs : public AttrHolder {
  43. public:
  44. NamedAttrs();
  45. virtual ~NamedAttrs() = default;
  46. void SetName(const std::string &name);
  47. string GetName() const;
  48. GeAttrValue GetItem(const string &key) const;
  49. protected:
  50. ProtoAttrMapHelper MutableAttrMap() override;
  51. ConstProtoAttrMapHelper GetAttrMap() const override;
  52. private:
  53. // Create namedAttrs from protobuf obj
  54. NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg);
  55. GeIrProtoHelper<proto::NamedAttrs> named_attrs_;
  56. friend class GeAttrValueImp;
  57. };
  58. using INT = int64_t;
  59. using FLOAT = float;
  60. using BOOL = bool;
  61. using STR = std::string;
  62. using TENSOR = GeTensorPtr;
  63. using TENSOR_DESC = GeTensorDesc;
  64. using GRAPH = ComputeGraphPtr;
  65. using BYTES = Buffer;
  66. using NAMED_ATTRS = NamedAttrs;
  67. using DATA_TYPE = ge::DataType;
  68. using LIST_INT = vector<INT>;
  69. using LIST_FLOAT = vector<FLOAT>;
  70. using LIST_BOOL = vector<BOOL>;
  71. using LIST_STR = vector<STR>;
  72. using LIST_TENSOR = vector<TENSOR>;
  73. using LIST_TENSOR_DESC = vector<TENSOR_DESC>;
  74. using LIST_GRAPH = vector<GRAPH>;
  75. using LIST_BYTES = vector<BYTES>;
  76. using LIST_NAMED_ATTRS = vector<NAMED_ATTRS>;
  77. using LIST_LIST_INT = vector<vector<int64_t>>;
  78. using LIST_DATA_TYPE = vector<ge::DataType>;
  79. enum ValueType {
  80. VT_NONE = 0,
  81. VT_STRING,
  82. VT_FLOAT,
  83. VT_BOOL,
  84. VT_INT,
  85. VT_TENSOR_DESC,
  86. VT_TENSOR,
  87. VT_BYTES,
  88. VT_GRAPH,
  89. VT_NAMED_ATTRS,
  90. VT_LIST_LIST_INT,
  91. VT_DATA_TYPE,
  92. VT_LIST_BASE = 1000,
  93. VT_LIST_STRING = VT_LIST_BASE + VT_STRING,
  94. VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT,
  95. VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL,
  96. VT_LIST_INT = VT_LIST_BASE + VT_INT,
  97. VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC,
  98. VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR,
  99. VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES,
  100. VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH,
  101. VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS,
  102. VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE,
  103. };
  104. template <class T>
  105. struct IsAttrTypeEnable {
  106. using DT = typename std::remove_cv<T>::type;
  107. static bool const VALUE = std::is_same<INT, DT>::value || std::is_same<FLOAT, DT>::value ||
  108. std::is_same<BOOL, DT>::value || std::is_same<STR, DT>::value ||
  109. std::is_same<GRAPH, DT>::value || std::is_same<TENSOR, DT>::value ||
  110. std::is_same<TENSOR_DESC, DT>::value || std::is_same<BYTES, DT>::value ||
  111. std::is_same<NAMED_ATTRS, DT>::value || std::is_same<DATA_TYPE, DT>::value;
  112. // Not has list type of NamedAttrs
  113. static bool const LIST_VALUE = std::is_same<LIST_INT, DT>::value || std::is_same<LIST_FLOAT, DT>::value ||
  114. std::is_same<LIST_BOOL, DT>::value || std::is_same<LIST_STR, DT>::value ||
  115. std::is_same<LIST_GRAPH, DT>::value || std::is_same<LIST_TENSOR, DT>::value ||
  116. std::is_same<LIST_TENSOR_DESC, DT>::value || std::is_same<LIST_BYTES, DT>::value ||
  117. std::is_same<LIST_NAMED_ATTRS, DT>::value ||
  118. std::is_same<LIST_LIST_INT, DT>::value || std::is_same<LIST_DATA_TYPE, DT>::value;
  119. };
  120. template <typename vector_type>
  121. // To cols
  122. using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE,
  123. int>::type;
  124. template <typename one_type>
  125. using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type;
  126. template <typename val_type>
  127. using enable_if_type_valid_t =
  128. typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;
  129. template <typename seriliable_type>
  130. using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable;
  131. GeAttrValue();
  132. ~GeAttrValue() = default;
  133. // SetValue, Set initializer_list
  134. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  135. graphStatus SetValue(std::initializer_list<DT> &&val) {
  136. T vectorVal;
  137. for (auto &item : val) {
  138. vectorVal.push_back(item);
  139. }
  140. return SetValue(vectorVal);
  141. }
  142. // SetValue, Set vector
  143. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  144. graphStatus SetValue(const std::vector<DT> &val) {
  145. T vectorVal;
  146. for (auto item : val) {
  147. vectorVal.push_back(item);
  148. }
  149. return SetValue(vectorVal);
  150. }
  151. // SetValue, not list type
  152. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0>
  153. graphStatus SetValue(DT &&val) {
  154. return SetValue(T(std::forward<DT>(val)));
  155. }
  156. // GE_SERIALIZABLE
  157. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  158. graphStatus SetValue(const T &t) {
  159. return t.Save(*this);
  160. }
  161. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  162. graphStatus SetValue(const vector<T> &t) {
  163. vector<NamedAttrs> attrs;
  164. for (auto &item : t) {
  165. GeAttrValue val;
  166. item.Save(val);
  167. NamedAttrs attrsItem;
  168. (void)val.GetValue<NamedAttrs>(attrsItem);
  169. attrs.push_back(attrsItem);
  170. }
  171. return SetValue(attrs);
  172. }
  173. // GetValue, list value
  174. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0,
  175. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  176. graphStatus GetValue(std::vector<DT> &val) const {
  177. T valGet;
  178. val.clear();
  179. auto status = GetValue(valGet);
  180. if (status != GRAPH_SUCCESS) {
  181. return status;
  182. }
  183. for (auto item : valGet) {
  184. val.push_back(item);
  185. }
  186. return GRAPH_SUCCESS;
  187. }
  188. // GetValue, not list type
  189. template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0,
  190. typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0>
  191. graphStatus GetValue(DT &val) const {
  192. T valGet;
  193. auto status = GetValue(valGet);
  194. if (status != GRAPH_SUCCESS) {
  195. return status;
  196. }
  197. val = DT(valGet);
  198. return GRAPH_SUCCESS;
  199. }
  200. // GE_SERIALIZABLE
  201. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  202. graphStatus GetValue(T &t) {
  203. return t.Load(*this);
  204. }
  205. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  206. graphStatus GetValue(vector<T> &t) {
  207. graphStatus status;
  208. t.clear();
  209. vector<NamedAttrs> attrs;
  210. status = this->GetValue(attrs);
  211. if (status != GRAPH_SUCCESS) {
  212. return status;
  213. }
  214. for (auto &attr : attrs) {
  215. T item;
  216. GeAttrValue val;
  217. (void)val.SetValue(attr);
  218. status = item.Load(val);
  219. if (status != GRAPH_SUCCESS) {
  220. return status;
  221. }
  222. t.push_back(item);
  223. }
  224. return GRAPH_SUCCESS;
  225. }
  226. template <typename T, typename DT, enable_if_type_valid_t<T> = 0>
  227. static GeAttrValue CreateFrom(DT &&val) {
  228. GeAttrValue valRet;
  229. (void)valRet.SetValue<T>(std::forward<DT>(val));
  230. return valRet;
  231. }
  232. template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0>
  233. static GeAttrValue CreateFrom(std::initializer_list<DT> &&val) {
  234. GeAttrValue valRet;
  235. (void)valRet.SetValue<T>(std::move(val));
  236. return valRet;
  237. }
  238. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  239. static GeAttrValue CreateFrom(const T &val) {
  240. GeAttrValue valRet;
  241. (void)valRet.SetValue(val);
  242. return valRet;
  243. }
  244. template <typename T, enable_if_seriliable_type_valid_t<T> = 0>
  245. static GeAttrValue CreateFrom(const vector<T> &val) {
  246. GeAttrValue valRet;
  247. (void)valRet.SetValue(val);
  248. return valRet;
  249. }
  250. ValueType GetValueType() const;
  251. bool IsEmpty() const;
  252. GeAttrValue Copy() const;
  253. // For map key
  254. bool operator==(const GeAttrValue &other) const { return value_ == other.value_; }
  255. graphStatus MutableTensor(GeTensorPtr &tensor);
  256. graphStatus MutableListTensor(vector<GeTensorPtr> &list_tensor);
  257. private:
  258. #define VALUE_SET_GET_DEC(DT) \
  259. graphStatus SetValue(const DT &val); \
  260. graphStatus GetValue(DT &val) const;
  261. VALUE_SET_GET_DEC(GeAttrValue::STR)
  262. VALUE_SET_GET_DEC(GeAttrValue::INT)
  263. VALUE_SET_GET_DEC(GeAttrValue::FLOAT)
  264. VALUE_SET_GET_DEC(GeAttrValue::BOOL)
  265. VALUE_SET_GET_DEC(GeTensorDesc)
  266. VALUE_SET_GET_DEC(GeAttrValue::TENSOR)
  267. VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
  268. VALUE_SET_GET_DEC(BYTES)
  269. VALUE_SET_GET_DEC(NamedAttrs)
  270. VALUE_SET_GET_DEC(ge::DataType)
  271. VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
  272. VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
  273. VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
  274. VALUE_SET_GET_DEC(vector<GeAttrValue::BOOL>)
  275. VALUE_SET_GET_DEC(vector<GeTensorDesc>)
  276. VALUE_SET_GET_DEC(vector<GeAttrValue::TENSOR>)
  277. VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
  278. VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
  279. VALUE_SET_GET_DEC(vector<NamedAttrs>)
  280. VALUE_SET_GET_DEC(vector<vector<int64_t>>)
  281. VALUE_SET_GET_DEC(vector<ge::DataType>)
  282. #undef VALUE_SET_GET_DEC
  283. GeIrProtoHelper<proto::AttrDef> value_;
  284. GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val);
  285. friend class AttrHolder;
  286. friend class ModelSerializeImp;
  287. friend class OnnxUtils;
  288. };
  289. class AttrValueImpl {
  290. public:
  291. AttrValueImpl() = default;
  292. ~AttrValueImpl() = default;
  293. GeAttrValue geAttrValue_;
  294. };
  295. } // namespace ge
  296. #endif // INC_GRAPH_GE_ATTR_VALUE_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示

Contributors (1)