You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ge_tensor.h 5.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_GRAPH_GE_TENSOR_H_
  17. #define INC_GRAPH_GE_TENSOR_H_
  18. #include <atomic>
  19. #include <memory>
  20. #include <string>
  21. #include <vector>
  22. #include "detail/attributes_holder.h"
  23. #include "graph/buffer.h"
  24. #include "graph/ge_error_codes.h"
  25. #include "graph/types.h"
  26. namespace ge {
  27. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {  // Shape value type: int64_t dims backed by a proto::ShapeDef (see shape_def_).
  28. public:
  29. GeShape();  // Default shape; its initial dims are set in the implementation, which is not visible in this header.
  30. ~GeShape() = default;
  31. explicit GeShape(std::vector<int64_t> s);  // Builds a shape from s; presumably one entry per dimension -- TODO confirm.
  32. size_t GetDimNum() const;  // Presumably the number of dimensions -- TODO confirm in the implementation.
  33. // Reads the dim at position idx; if idx is invalid, 0 is returned.
  34. int64_t GetDim(size_t idx) const;
  35. graphStatus SetDim(size_t idx, int64_t value);  // Outcome is reported through the graphStatus return value.
  36. std::vector<int64_t> GetDims() const;  // Returned by value, so the caller receives its own vector.
  37. int64_t GetShapeSize() const;  // NOTE(review): presumably the element count (product of dims) -- confirm.
  38. std::string ToString() const;  // Textual form of the shape; exact format is defined in the implementation.
  39. GeShape(const GeShape &other);
  40. GeShape(GeShape &&other);  // NOTE(review): not declared noexcept; changing that requires touching the out-of-view definition too.
  41. GeShape &operator=(const GeShape &other);
  42. GeShape &operator=(GeShape &&other);
  43. private:
  44. GeIrProtoHelper<proto::ShapeDef> shape_def_;  // Helper wrapping the proto::ShapeDef that stores this shape.
  45. friend class GeTensorDesc;  // Gives GeTensorDesc access to the private constructor and RefTo() below.
  46. // Creates a GeShape from an existing proto object.
  47. GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg);  // NOTE(review): "protoOnwer" looks like a typo for "protoOwner".
  48. void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; }  // Assigns shape's helper to ours; share-vs-copy depends on GeIrProtoHelper's assignment (not visible here).
  49. };
  50. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder {  // Tensor description: shape, format and data type (plus "origin" variants), backed by a proto::TensorDescriptor.
  51. friend class TensorUtils;
  52. friend class GeAttrValue;
  53. friend class ModelSerialize;
  54. public:
  55. GeTensorDesc();
  56. explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);  // format defaults to FORMAT_ND and dt to DT_FLOAT when omitted.
  57. GeTensorDesc(const GeTensorDesc &desc);
  58. GeTensorDesc(GeTensorDesc &&desc);
  59. ~GeTensorDesc() = default;
  60. bool operator==(const GeTensorDesc &r_ge_tensor_desc) const;  // Equality rules live in the implementation; presumably uses GeTensorDescAttrsAreEqual() -- confirm.
  61. void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT);  // Same parameters and defaults as the three-argument constructor.
  62. GeShape GetShape() const;  // Returns a GeShape by value; use MutableShape() for a reference.
  63. GeShape &MutableShape();  // Returns a non-const reference, so callers can modify the shape through it.
  64. void SetShape(GeShape shape);
  65. GeShape GetOriginShape() const;  // NOTE(review): the meaning of "origin" is not defined in this header -- confirm with the implementation.
  66. void SetOriginShape(const GeShape &originShape);
  67. Format GetFormat() const;
  68. void SetFormat(Format format);
  69. Format GetOriginFormat() const;
  70. void SetOriginFormat(Format originFormat);
  71. DataType GetDataType() const;
  72. void SetDataType(DataType dt);
  73. void SetOriginDataType(DataType originDataType);
  74. DataType GetOriginDataType() const;
  75. GeTensorDesc Clone() const;  // Returns a new GeTensorDesc by value; presumably an independent copy -- TODO confirm.
  76. GeTensorDesc &operator=(const GeTensorDesc &desc);
  77. GeTensorDesc &operator=(GeTensorDesc &&desc);
  78. graphStatus IsValid() const;  // Note: returns a graphStatus, not a bool; compare against the status codes rather than testing truthiness.
  79. protected:
  80. ProtoAttrMapHelper MutableAttrMap() override;  // AttrHolder override: mutable access to the attribute map.
  81. ConstProtoAttrMapHelper GetAttrMap() const override;  // AttrHolder override: read-only access to the attribute map.
  82. private:
  83. bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const;
  84. using AttrHolder::DelAttr;  // These using-declarations sit in the private section, making the AttrHolder members private when named through GeTensorDesc.
  85. using AttrHolder::GetAllAttrs;
  86. using AttrHolder::GetAttr;
  87. using AttrHolder::HasAttr;
  88. using AttrHolder::SetAttr;
  89. void Init();
  90. // Creates a GeTensorDesc from an existing proto object.
  91. GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg);  // NOTE(review): "protoOnwer" looks like a typo for "protoOwner".
  92. friend class GeTensor;
  93. friend class GeAttrValueImp;
  94. friend class ModelSerializeImp;
  95. friend class OnnxUtils;
  96. GeIrProtoHelper<proto::TensorDescriptor> tensor_descriptor_;  // Helper wrapping the proto::TensorDescriptor that stores this description.
  97. // Refers to data in tensor_descriptor_; do not use directly (presumably accessed via ShapeReference() -- confirm).
  98. mutable GeShape __shape_;  // mutable: may be modified from const members. NOTE(review): identifiers containing "__" are reserved in C++.
  99. void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; }  // Assigns tensorDesc's helper to ours; share-vs-copy depends on GeIrProtoHelper's assignment (not visible here).
  100. GeShape &ShapeReference() const;  // const member returning a non-const GeShape&; presumably yields __shape_ -- TODO confirm.
  101. };
  102. class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {  // Tensor: a GeTensorDesc plus a data Buffer, backed by a proto::TensorDef (see tensor_def_).
  103. public:
  104. GeTensor();
  105. explicit GeTensor(const GeTensorDesc &tensorDesc);  // Description only; no data argument.
  106. explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector<uint8_t> &data);  // Data supplied as a byte vector.
  107. explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data);  // Data supplied as a Buffer.
  108. explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size);  // Data supplied as pointer + size; presumably size is in bytes -- TODO confirm.
  109. explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector<uint8_t> &&data);  // Rvalue overload; presumably moves from both arguments -- TODO confirm.
  110. ~GeTensor() = default;
  111. GeTensorDesc GetTensorDesc() const;  // Returns a GeTensorDesc by value; use MutableTensorDesc() for a reference.
  112. GeTensorDesc &MutableTensorDesc();  // Returns a non-const reference, so callers can modify the description through it.
  113. void SetTensorDesc(const GeTensorDesc &tensorDesc);
  114. const Buffer GetData() const;  // Returns a const Buffer by value; whether it aliases the tensor's storage depends on Buffer (not visible here).
  115. Buffer MutableData();  // Returns a Buffer by value; presumably refers to the tensor's own storage -- TODO confirm.
  116. graphStatus SetData(std::vector<uint8_t> &&data);  // SetData overloads mirror the constructor data forms; outcome is reported as a graphStatus.
  117. graphStatus SetData(const std::vector<uint8_t> &data);
  118. graphStatus SetData(const Buffer &data);
  119. graphStatus SetData(const uint8_t *data, size_t size);
  120. GeTensor Clone() const;  // Returns a new GeTensor by value; presumably an independent copy, unlike the sharing copy operations below -- TODO confirm.
  121. // Copy constructor: shares the value with other; exactly what is shared is defined in the implementation.
  122. GeTensor(const GeTensor &other);
  123. // Copy assignment: shares the value with other; exactly what is shared is defined in the implementation.
  124. GeTensor &operator=(const GeTensor &other);
  125. private:
  126. friend class GeAttrValueImp;
  127. friend class ModelSerializeImp;
  128. friend class OnnxUtils;
  129. // Creates a GeTensor from an existing proto object.
  130. GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg);  // NOTE(review): "protoOnwer" looks like a typo for "protoOwner".
  131. GeIrProtoHelper<proto::TensorDef> tensor_def_;  // Helper wrapping the proto::TensorDef that stores this tensor.
  132. // Refers to data in tensor_def_; must not be used directly (presumably accessed via DescReference() -- confirm).
  133. mutable GeTensorDesc __desc_;  // mutable: may be modified from const members. NOTE(review): identifiers containing "__" are reserved in C++.
  134. GeTensorDesc &DescReference() const;  // const member returning a non-const GeTensorDesc&; presumably yields __desc_ -- TODO confirm.
  135. };
  136. } // namespace ge
  137. #endif // INC_GRAPH_GE_TENSOR_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示

Contributors (1)