You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ge_tensor.h 5.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_GE_TENSOR_H_
  17. #define INC_GRAPH_GE_TENSOR_H_
  18. #include <atomic>
  19. #include <memory>
  20. #include <string>
  21. #include <vector>
  22. #include "detail/attributes_holder.h"
  23. #include "graph/buffer.h"
  24. #include "graph/ge_error_codes.h"
  25. #include "graph/types.h"
  26. namespace ge {
  27. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
  28. public:
  29. GeShape();
  30. ~GeShape() = default;
  31. explicit GeShape(std::vector<int64_t> s);
  32. size_t GetDimNum() const;
  33. // If the idx is invalid, return 0
  34. int64_t GetDim(size_t idx) const;
  35. graphStatus SetDim(size_t idx, int64_t value);
  36. std::vector<int64_t> GetDims() const;
  37. int64_t GetShapeSize() const;
  38. std::string ToString() const;
  39. ///
  40. /// @brief Check is unknown shape
  41. /// @return bool
  42. ///
  43. bool IsUnknownShape() const;
  44. ///
  45. /// @brief Check is a scalar
  46. /// @return bool
  47. ///
  48. bool IsScalar() const;
  49. GeShape(const GeShape &other);
  50. GeShape(GeShape &&other);
  51. GeShape &operator=(const GeShape &other);
  52. GeShape &operator=(GeShape &&other);
  53. private:
  54. GeIrProtoHelper<proto::ShapeDef> shape_def_;
  55. friend class GeTensorDesc;
  56. // Create from proto obj
  57. GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg);
  58. void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; }
  59. };
  60. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder {
  61. friend class TensorUtils;
  62. friend class GeAttrValue;
  63. friend class ModelSerialize;
  64. public:
  65. GeTensorDesc();
  66. explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
  67. GeTensorDesc(const GeTensorDesc &desc);
  68. GeTensorDesc(GeTensorDesc &&desc);
  69. ~GeTensorDesc() = default;
  70. bool operator==(const GeTensorDesc &r_ge_tensor_desc) const;
  71. void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);
  72. GeShape GetShape() const;
  73. GeShape &MutableShape();
  74. void SetShape(GeShape shape);
  75. // set shape with -2, it stand for unknown shape
  76. void SetUnknownDimNumShape();
  77. // for unknown shape
  78. graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range);
  79. graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const;
  80. GeShape GetOriginShape() const;
  81. void SetOriginShape(const GeShape &originShape);
  82. Format GetFormat() const;
  83. void SetFormat(Format format);
  84. Format GetOriginFormat() const;
  85. void SetOriginFormat(Format originFormat);
  86. DataType GetDataType() const;
  87. void SetDataType(DataType dt);
  88. void SetOriginDataType(DataType originDataType);
  89. DataType GetOriginDataType() const;
  90. GeTensorDesc Clone() const;
  91. GeTensorDesc &operator=(const GeTensorDesc &desc);
  92. GeTensorDesc &operator=(GeTensorDesc &&desc);
  93. graphStatus IsValid() const;
  94. protected:
  95. ProtoAttrMapHelper MutableAttrMap() override;
  96. ConstProtoAttrMapHelper GetAttrMap() const override;
  97. private:
  98. bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const;
  99. using AttrHolder::DelAttr;
  100. using AttrHolder::GetAllAttrs;
  101. using AttrHolder::GetAttr;
  102. using AttrHolder::HasAttr;
  103. using AttrHolder::SetAttr;
  104. void Init();
  105. // Create from proto obj
  106. GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg);
  107. friend class GeTensor;
  108. friend class GeAttrValueImp;
  109. friend class ModelSerializeImp;
  110. friend class OnnxUtils;
  111. GeIrProtoHelper<proto::TensorDescriptor> tensor_descriptor_;
  112. // Reference from tensorDescriptor_, do not direct use
  113. mutable GeShape __shape_;
  114. void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; }
  115. GeShape &ShapeReference() const;
  116. };
  117. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
  118. public:
  119. GeTensor();
  120. explicit GeTensor(const GeTensorDesc &tensorDesc);
  121. explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector<uint8_t> &data);
  122. explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data);
  123. explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size);
  124. explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector<uint8_t> &&data);
  125. ~GeTensor() = default;
  126. GeTensorDesc GetTensorDesc() const;
  127. GeTensorDesc &MutableTensorDesc();
  128. void SetTensorDesc(const GeTensorDesc &tensorDesc);
  129. const Buffer GetData() const;
  130. Buffer MutableData();
  131. graphStatus SetData(std::vector<uint8_t> &&data);
  132. graphStatus SetData(const std::vector<uint8_t> &data);
  133. graphStatus SetData(const Buffer &data);
  134. graphStatus SetData(const uint8_t *data, size_t size);
  135. GeTensor Clone() const;
  136. // Share value
  137. GeTensor(const GeTensor &other);
  138. // Share value
  139. GeTensor &operator=(const GeTensor &other);
  140. private:
  141. friend class GeAttrValueImp;
  142. friend class ModelSerializeImp;
  143. friend class OnnxUtils;
  144. // Create from proto obj
  145. GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg);
  146. GeIrProtoHelper<proto::TensorDef> tensor_def_;
  147. // Reference from tensorDef_, do not direct use
  148. mutable GeTensorDesc __desc_;
  149. GeTensorDesc &DescReference() const;
  150. };
  151. } // namespace ge
  152. #endif // INC_GRAPH_GE_TENSOR_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：