You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

attr_utils.h 9.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_
  17. #define INC_GRAPH_UTILS_ATTR_UTILS_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "graph/detail/attributes_holder.h"
  22. #include "graph/ge_attr_value.h"
  23. #include "graph/types.h"
  24. namespace ge {
  25. class OpDesc;
  26. using OpDescPtr = std::shared_ptr<OpDesc>;
  27. using ConstOpDescPtr = std::shared_ptr<const OpDesc>;
  28. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {
  29. public:
  30. class ConstAttrHolderAdapter;
  31. class AttrHolderAdapter;
  32. // Set
  33. static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name);
  34. static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value);
  35. static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int64_t> &value);
  36. static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<uint32_t> &value);
  37. static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int32_t> &value);
  38. static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value);
  39. static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value);
  40. static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector<float> &value);
  41. static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value);
  42. static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector<bool> &value);
  43. static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value);
  44. static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector<string> &value);
  45. static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value);
  46. static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorDesc> &value);
  47. static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value);
  48. static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value);
  49. static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value);
  50. static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorPtr> &value);
  51. static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<ConstGeTensorPtr> &value);
  52. static bool SetListTensor(AttrHolderAdapter &&obj, const string &name,
  53. std::initializer_list<ConstGeTensorPtr> &&value);
  54. static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensor> &value);
  55. static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value);
  56. static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value);
  57. static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value);
  58. static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value);
  59. static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value);
  60. static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name,
  61. const vector<GeAttrValue::NAMED_ATTRS> &value);
  62. static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value);
  63. static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value);
  64. // Get
  65. static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value);
  66. static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value);
  67. static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value);
  68. static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int64_t> &value);
  69. static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int32_t> &value);
  70. static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<uint32_t> &value);
  71. static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value);
  72. static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector<float> &value);
  73. static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value);
  74. static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector<bool> &value);
  75. static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value);
  76. static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector<string> &value);
  77. static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value);
  78. static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<GeTensorDesc> &value);
  79. static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value);
  80. static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value);
  81. static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value);
  82. static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value);
  83. static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value);
  84. static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value);
  85. static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value);
  86. static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value);
  87. static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value);
  88. static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name,
  89. vector<GeAttrValue::NAMED_ATTRS> &value);
  90. static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value);
  91. // Value will be moved
  92. static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer);
  93. static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer);
  94. // Value will be moved
  95. static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
  96. static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
  97. static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value);
  98. static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<vector<int64_t>> &value);
  99. static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector<ge::DataType> &value);
  100. static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector<ge::DataType> &value);
  101. static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value);
  102. static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value);
  103. static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc);
  104. static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc);
  105. static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj);
  106. class AttrHolderAdapter {
  107. public:
  108. AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {}
  109. ~AttrHolderAdapter() {}
  110. template <class T>
  111. AttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {}
  112. AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {}
  113. operator bool() const { return obj_ != nullptr; }
  114. AttrHolder *operator->() { return obj_; }
  115. AttrHolder *get() { return obj_; }
  116. AttrHolder *obj_;
  117. };
  118. class ConstAttrHolderAdapter {
  119. public:
  120. ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {}
  121. ~ConstAttrHolderAdapter() {}
  122. template <class T>
  123. ConstAttrHolderAdapter(const std::shared_ptr<T> obj) : obj_(obj.get()) {}
  124. ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {}
  125. operator bool() const { return obj_ != nullptr; }
  126. const AttrHolder *operator->() const { return obj_; }
  127. const AttrHolder *get() const { return obj_; }
  128. private:
  129. const AttrHolder *obj_;
  130. };
  131. };
  132. } // namespace ge
  133. #endif // INC_GRAPH_UTILS_ATTR_UTILS_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：