You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ge_tensor.cc 36 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ge_tensor.h"
  17. #include <cstdlib>
  18. #include <cstring>
  19. #include <iostream>
  20. #include <map>
  21. #include "debug/ge_attr_define.h"
  22. #include "debug/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "graph/ge_attr_value.h"
  25. #include "graph/model_serialize.h"
  26. #include "proto/ge_ir.pb.h"
  27. #include "utils/attr_utils.h"
  28. #include "utils/ge_ir_utils.h"
  29. #include "utils/tensor_utils.h"
  30. #include "utils/type_utils.h"
namespace ge {
// Attr-map key under which "self-defined" data types (those serialized via
// kDataTypeSelfDefinedMap rather than the proto dtype field) are stored.
static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__";

// Mapping from ge::DataType to the serialized proto::DataType enum.
static const std::map<DataType, ::ge::proto::DataType> kDataTypeMap = {
    {DT_UNDEFINED, proto::DT_UNDEFINED},
    {DT_FLOAT, proto::DT_FLOAT},
    {DT_FLOAT16, proto::DT_FLOAT16},
    {DT_INT8, proto::DT_INT8},
    {DT_UINT8, proto::DT_UINT8},
    {DT_INT16, proto::DT_INT16},
    {DT_UINT16, proto::DT_UINT16},
    {DT_INT32, proto::DT_INT32},
    {DT_INT64, proto::DT_INT64},
    {DT_UINT32, proto::DT_UINT32},
    {DT_UINT64, proto::DT_UINT64},
    {DT_BOOL, proto::DT_BOOL},
    {DT_DOUBLE, proto::DT_DOUBLE},
    {DT_DUAL, proto::DT_DUAL},
    {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8},
    {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8},
    {DT_COMPLEX64, proto::DT_COMPLEX64},
    {DT_COMPLEX128, proto::DT_COMPLEX128},
    {DT_QINT8, proto::DT_QINT8},
    {DT_QINT16, proto::DT_QINT16},
    {DT_QINT32, proto::DT_QINT32},
    {DT_QUINT8, proto::DT_QUINT8},
    {DT_QUINT16, proto::DT_QUINT16},
    {DT_RESOURCE, proto::DT_RESOURCE},
    {DT_STRING_REF, proto::DT_STRING_REF},
    {DT_STRING, proto::DT_STRING},
};

// Data types that may be serialized through the attr map under
// kKeyDataTypeSelfDefined; values are the integer codes stored in the attr.
// NOTE(review): these types also appear in kDataTypeMap, which SetDataType
// consults first — presumably this map exists for legacy models; confirm.
static const std::map<DataType, int> kDataTypeSelfDefinedMap = {
    {DT_DUAL, 13},  {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17},
    {DT_QINT8, 18}, {DT_QINT16, 19},        {DT_QINT32, 20},         {DT_QUINT8, 21},    {DT_QUINT16, 22},
};
// Default constructor: a rank-0 (scalar) shape backed by a freshly
// initialized ShapeDef proto message.
GeShape::GeShape() { shape_def_.InitDefault(); }
// Default
// Builds a shape from an explicit dim list; dims are appended in order.
// A null proto message silently yields an empty shape.
GeShape::GeShape(std::vector<int64_t> s) : GeShape() {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    for (auto i : s) {
      proto_msg->add_dim(i);
    }
  }
}
  75. size_t GeShape::GetDimNum() const {
  76. auto proto_msg = shape_def_.GetProtoMsg();
  77. if (proto_msg != nullptr) {
  78. if (proto_msg->dim_size() >= 0) {
  79. // check whether contain -2, if true, return -1
  80. for (auto i : proto_msg->dim()) {
  81. if (i == UNKNOWN_DIM_NUM) {
  82. return 0;
  83. }
  84. }
  85. return proto_msg->dim_size();
  86. } else {
  87. return 0;
  88. }
  89. }
  90. return 0;
  91. }
  92. int64_t GeShape::GetDim(size_t idx) const {
  93. auto proto_msg = shape_def_.GetProtoMsg();
  94. if (proto_msg != nullptr) {
  95. if (proto_msg->dim_size() > static_cast<int>(idx)) {
  96. return proto_msg->dim(static_cast<int>(idx));
  97. }
  98. }
  99. return 0;
  100. }
// Overwrites the dim at |idx| with |value|.
// Fails (with a log) when the shape is empty or |idx| is out of range.
// When the proto message is missing this is a silent no-op that reports
// success.
graphStatus GeShape::SetDim(size_t idx, int64_t value) {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    auto dims = proto_msg->mutable_dim();
    GE_CHECK_NOTNULL(dims);
    if (dims->empty()) {
      GELOGE(GRAPH_FAILED, "shape is empty");
      return GRAPH_FAILED;
    }
    if (static_cast<int>(idx) >= dims->size()) {
      GELOGE(GRAPH_FAILED, "idx is out of range");
      return GRAPH_FAILED;
    }
    proto_msg->set_dim(static_cast<int>(idx), value);
  }
  return GRAPH_SUCCESS;
}
  118. std::vector<int64_t> GeShape::GetDims() const {
  119. vector<int64_t> dims;
  120. auto proto_msg = shape_def_.GetProtoMsg();
  121. if (proto_msg != nullptr) {
  122. for (auto i : proto_msg->dim()) {
  123. dims.push_back(i);
  124. }
  125. }
  126. return dims;
  127. }
  128. std::string GeShape::ToString() const {
  129. auto proto_msg = shape_def_.GetProtoMsg();
  130. if (proto_msg == nullptr) {
  131. return "";
  132. }
  133. std::stringstream ss;
  134. bool first = true;
  135. for (auto i : proto_msg->dim()) {
  136. if (first) {
  137. first = false;
  138. } else {
  139. ss << ",";
  140. }
  141. ss << i;
  142. }
  143. return ss.str();
  144. }
// Returns the total element count (product of all dims).
// Returns 0 for an empty (rank-0) shape, UNKNOWN_DIM (-1) as soon as any
// dim is unknown (-1 or -2), and 1 when the proto message is missing.
// NOTE(review): the product accumulates in int64_t with no overflow check —
// confirm callers tolerate wrap-around for extreme shapes.
int64_t GeShape::GetShapeSize() const {
  int64_t res = 1;
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    if (proto_msg->dim().empty()) {
      return 0;
    }
    for (auto i : proto_msg->dim()) {
      // if unknown shape, return -1
      if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) {
        return UNKNOWN_DIM;
      }
      res *= i;
    }
  }
  return res;
}
///
/// @brief Check is unknown shape: true when any dim is negative
///        (the UNKNOWN_DIM / UNKNOWN_DIM_NUM markers).
/// @return bool
///
bool GeShape::IsUnknownShape() const {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    for (auto i : proto_msg->dim()) {
      if (i < 0) {
        return true;
      }
    }
  }
  return false;
}
///
/// @brief Check is a scalar: a shape with zero dims.
///        Returns false when the proto message is missing.
/// @return bool
///
bool GeShape::IsScalar() const {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    return proto_msg->dim().empty();
  }
  return false;
}
// Attr-map keys for tensor-descriptor fields. Legacy models store these
// values in the attr map; the TensorDescriptor-wrapping constructor below
// migrates them into first-class proto fields when has_out_attr is unset.
const string TENSOR_UTILS_SIZE = "size";
const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size";
const string TENSOR_UTILS_REUSE_INPUT = "reuse_input";
const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor";
const string TENSOR_UTILS_DEVICE_TYPE = "device_type";
const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor";
const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt";
const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index";
const string TENSOR_UTILS_DATA_OFFSET = "data_offset";
const string TENSOR_UTILS_CMPS_SIZE = "cmps_size";
const string TENSOR_UTILS_CMPS_TAB = "cmps_tab";
const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset";
const string TENSOR_UTILS_CMPSINFO = "cmps_info";
const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info";
const string TENSOR_UTILS_RC = "rc";
const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
// Wraps an externally owned ShapeDef proto message without copying it.
GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}
// Copy construction deep-copies the other shape's proto contents.
GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); }
// Move construction takes over the other shape's proto contents.
GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); }
GeShape &GeShape::operator=(const GeShape &other) {
  if (&other != this) {
    shape_def_.CopyValueFrom(other.shape_def_);
  }
  return *this;
}
// NOTE(review): unlike the move constructor this calls CopyValueFrom, so
// move-assignment actually copies — presumably intentional; confirm against
// GeIrProtoHelper semantics before changing.
GeShape &GeShape::operator=(GeShape &&other) {
  if (&other != this) {
    shape_def_.CopyValueFrom(std::move(other.shape_def_));
  }
  return *this;
}
// Default descriptor: fresh proto, DT_FLOAT dtype, then Init() applies the
// remaining defaults (ND format/origin-format, NPU device, has_out_attr).
GeTensorDesc::GeTensorDesc() {
  tensor_descriptor_.InitDefault();
  SetDataType(DT_FLOAT);
  Init();
}
// Default
// Descriptor with an explicit shape, format and data type.
GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() {
  SetFormat(format);
  SetDataType(dt);
  ShapeReference() = std::move(shape);
}
// Default
// Copy construction deep-copies the other descriptor's proto contents.
GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() {
  tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_);
}
// Default
// Move construction takes over the other descriptor's proto contents.
GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() {
  tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_));
}
// Wraps an externally owned TensorDescriptor proto message.
// When the message predates first-class output attributes
// (has_out_attr == false), legacy values stored in the attr map are migrated
// into the corresponding proto fields. Each AttrUtils::Get* call is
// best-effort: when the attr is absent the default initialized just above it
// is written instead.
GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg)
    : tensor_descriptor_(proto_owner, proto_msg) {
  if (proto_msg != nullptr && !proto_msg->has_out_attr()) {
    proto_msg->set_has_out_attr(true);
    int64_t size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size);
    proto_msg->set_size(size);
    int64_t weight_size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size);
    proto_msg->set_weight_size(weight_size);
    bool reuse_input = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input);
    proto_msg->set_reuse_input(reuse_input);
    bool output_tensor = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor);
    proto_msg->set_output_tensor(output_tensor);
    // Device type defaults to NPU when the attr is missing.
    string device_type = "NPU";
    (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type);
    proto_msg->set_device_type(device_type);
    bool input_tensor = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor);
    proto_msg->set_input_tensor(input_tensor);
    int64_t real_dim_cnt = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt);
    proto_msg->set_real_dim_cnt(real_dim_cnt);
    int64_t reuse_input_index = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index);
    proto_msg->set_reuse_input_index(reuse_input_index);
    int64_t data_offset = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset);
    proto_msg->set_data_offset(data_offset);
    int64_t cmps_size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size);
    proto_msg->set_cmps_size(cmps_size);
    string cmps_tab;
    (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab);
    proto_msg->set_cmps_tab(cmps_tab);
    int64_t cmps_tab_offset = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset);
    proto_msg->set_cmps_tab_offset(cmps_tab_offset);
  }
}
// Field-by-field equality over the underlying TensorDescriptor protos.
// Two descriptors whose protos are both missing compare equal; exactly one
// missing proto makes them unequal. IsEqual logs the first mismatching
// field's label on failure.
bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const {
  const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg();
  const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg();
  if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) {
    // Message TensorDescriptor in ge_ir.proto
    return (
      IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") &&
      IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") &&
      // Message ShapeDef in ge_ir.proto: dims are compared via their string form
      IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()),
              "TensorDescriptor.shape().dim()") &&
      IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") &&
      IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(),
              "TensorDescriptor.has_out_attr()") &&
      IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") &&
      IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), "TensorDescriptor.weight_size()") &&
      IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), "TensorDescriptor.reuse_input()") &&
      IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(),
              "TensorDescriptor.output_tensor()") &&
      IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), "TensorDescriptor.device_type()") &&
      IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(),
              "TensorDescriptor.input_tensor()") &&
      IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(),
              "TensorDescriptor.real_dim_cnt()") &&
      IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(),
              "TensorDescriptor.reuse_input_index()") &&
      IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), "TensorDescriptor.data_offset()") &&
      IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") &&
      IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") &&
      IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(),
              "TensorDescriptor.cmps_tab_offset()"));
  } else {
    // Equal only when both sides have no proto message.
    return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr));
  }
}
bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const {
  return GeTensorDescAttrsAreEqual(r_ge_tensor_desc);
}
// Rebinds the cached __shape_ member to this descriptor's shape sub-message
// and returns it; callers read/write the proto through that reference.
// When no proto is attached the reference is backed by a null message.
GeShape &GeTensorDesc::ShapeReference() const {
  if (tensor_descriptor_.GetProtoMsg() != nullptr) {
    GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape());
    __shape_.RefTo(refShape);
  } else {
    GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr);
    __shape_.RefTo(refShape);
  }
  return __shape_;
}
// Common constructor tail: ND format and origin-format, NPU device type,
// and the has_out_attr marker on the proto.
void GeTensorDesc::Init() {
  SetFormat(FORMAT_ND);
  SetOriginFormat(FORMAT_ND);
  TensorUtils::SetDeviceType(*this, DeviceType::NPU);
  if (tensor_descriptor_.GetProtoMsg() == nullptr) {
    GELOGE(GRAPH_FAILED, "ProtoType nullptr.");
    return;
  }
  tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true);
}
// Mutable view over the proto's attr map (null-backed when no proto).
ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() {
  if (tensor_descriptor_.GetProtoMsg() != nullptr) {
    return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr());
  }
  return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr);
}
// Read-only view over the proto's attr map (null-backed when no proto).
ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const {
  if (tensor_descriptor_.GetProtoMsg() != nullptr) {
    return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(),
                                   tensor_descriptor_.GetProtoMsg()->mutable_attr());
  }
  return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr);
}
// Replaces shape, format and data type in a single call.
void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) {
  ShapeReference() = std::move(shape);
  SetFormat(format);
  SetDataType(dt);
}
GeShape GeTensorDesc::GetShape() const { return ShapeReference(); }
GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); }
void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); }
// Set a shape of [-2], which stands for "unknown rank".
void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); }
// for unknown shape
// Stores per-dim [min, max] ranges as a list-of-list-int attr.
graphStatus GeTensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) {
  std::vector<vector<int64_t>> shape_range;
  for (const auto &ele : range) {
    shape_range.emplace_back(std::vector<int64_t>({ele.first, ele.second}));
  }
  auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range);
  return ret ? GRAPH_SUCCESS : GRAPH_FAILED;
}
// Reads the ranges written by SetShapeRange back into pairs. A missing attr
// leaves |range| untouched and reports success; an inner list whose length
// is not exactly 2 is an error.
graphStatus GeTensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const {
  std::vector<vector<int64_t>> shape_range;
  (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range);
  for (const auto &ele : shape_range) {
    // here must be only two elemenet because pair
    if (ele.size() != 2) {
      GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size());
      return GRAPH_FAILED;
    }
    std::pair<int64_t, int64_t> pair({ele[0], ele[1]});
    range.push_back(pair);
  }
  return GRAPH_SUCCESS;
}
// Origin shape is stored as a list-int attr; a missing attr yields an
// empty shape.
GeShape GeTensorDesc::GetOriginShape() const {
  vector<int64_t> origin_shape;
  if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) {
    return GeShape();
  }
  return GeShape(origin_shape);
}
void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) {
  std::vector<int64_t> origin_shape_tmp = origin_shape.GetDims();
  (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp);
}
// Format is persisted as the proto's layout string.
Format GeTensorDesc::GetFormat() const {
  auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout());
  }
  return FORMAT_RESERVED;
}
void GeTensorDesc::SetFormat(Format format) {
  auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format));
  }
}
// Origin format is stored as a string attr; "RESERVED" encodes
// FORMAT_RESERVED (see SetOriginFormat below).
Format GeTensorDesc::GetOriginFormat() const {
  std::string origin_format_str;
  if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) {
    // Attr missing: origin format was never set, return directly.
    return FORMAT_RESERVED;
  }
  if (origin_format_str == "RESERVED") {
    return FORMAT_RESERVED;
  }
  return TypeUtils::SerialStringToFormat(origin_format_str);
}
void GeTensorDesc::SetOriginFormat(Format origin_format) {
  std::string origin_format_str = "RESERVED";
  if (origin_format != FORMAT_RESERVED) {
    origin_format_str = TypeUtils::FormatToSerialString(origin_format);
  }
  (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str);
}
// Resolves the data type. A self-defined entry in the attr map
// (kKeyDataTypeSelfDefined) takes precedence over the proto dtype field;
// unknown codes on either path map to DT_UNDEFINED.
DataType GeTensorDesc::GetDataType() const {
  auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg == nullptr) {
    return DT_UNDEFINED;
  }
  auto &attr_map = *(tensor_descriptor_msg->mutable_attr());
  // Data type
  auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined);
  if (it_data_type != attr_map.end()) {
    // Reverse-lookup of the self-defined integer code.
    int64_t data_type_proto = it_data_type->second.i();
    for (auto it : kDataTypeSelfDefinedMap) {
      if (it.second == data_type_proto) {
        return it.first;
      }
    }
  } else {
    // Reverse-lookup of the serialized proto enum.
    auto data_type_proto = tensor_descriptor_msg->dtype();
    for (auto it : kDataTypeMap) {
      if (it.second == data_type_proto) {
        return it.first;
      }
    }
  }
  return DT_UNDEFINED;
}
// Stores the data type, clearing any stale self-defined attr entry first.
// Types present in kDataTypeMap go to the proto dtype field; otherwise a
// self-defined code is written into the attr map.
void GeTensorDesc::SetDataType(DataType dataType) {
  auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg == nullptr) {
    return;
  }
  auto &attr_maps = *(tensor_descriptor_msg->mutable_attr());
  (void)attr_maps.erase(kKeyDataTypeSelfDefined);
  // Data type
  auto it = kDataTypeMap.find(dataType);
  if (it != kDataTypeMap.end()) {
    tensor_descriptor_msg->set_dtype(it->second);
    return;
  }
  auto it2 = kDataTypeSelfDefinedMap.find(dataType);
  if (it2 != kDataTypeSelfDefinedMap.end()) {
    attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second);
  }
}
// Origin data type is stored as a string attr; "RESERVED" encodes
// DT_UNDEFINED.
void GeTensorDesc::SetOriginDataType(DataType origin_data_type) {
  std::string origin_data_type_str = "RESERVED";
  if (origin_data_type != DT_UNDEFINED) {
    origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type);
  }
  (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str);
}
DataType GeTensorDesc::GetOriginDataType() const {
  std::string origin_data_type_str;
  if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) {
    return DT_UNDEFINED;
  }
  if (origin_data_type_str == "RESERVED") {
    return DT_UNDEFINED;
  }
  return TypeUtils::SerialStringToDataType(origin_data_type_str);
}
// A descriptor is invalid only when both its data type and format are unset.
graphStatus GeTensorDesc::IsValid() const {
  auto dtype = this->GetDataType();
  auto format = this->GetFormat();
  if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) {
    return GRAPH_PARAM_INVALID;
  }
  return GRAPH_SUCCESS;
}
GeTensorDesc GeTensorDesc::Clone() const { return *this; }
GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) {
  if (&desc != this) {
    tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_);
  }
  return *this;
}
// NOTE(review): move-assignment copies via CopyValueFrom while the move
// constructor uses MoveValueFrom — presumably intentional; confirm.
GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) {
  if (&desc != this) {
    tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_));
  }
  return *this;
}
  510. GeTensor::GeTensor::GeTensor() {
  511. tensor_def_.InitDefault();
  512. // Default init desc
  513. DescReference() = GeTensorDesc();
  514. }
// Tensor with the given descriptor and no data.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; }
// Tensor with descriptor and data copied from a byte vector.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const vector<uint8_t> &data) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_data(data.data(), data.size());
  }
}
// Tensor with descriptor and |size| bytes copied from |data|; a null |data|
// silently leaves the tensor empty.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr && data != nullptr) {
    proto_msg->set_data(data, size);
  }
}
// Move-descriptor overload.
// NOTE(review): despite the rvalue |data|, set_data copies the bytes.
GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector<uint8_t> &&data) : GeTensor() {
  DescReference() = std::move(tensor_desc);
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_data(data.data(), data.size());
  }
}
// Tensor with descriptor and data copied from a Buffer; empty or null
// buffers are tolerated and only logged.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    if (data.size() == 0) {
      GELOGI("GetSize res is 0.");
    }
    if (data.data() == nullptr) {
      GELOGI("data addr is null.");
    }
    proto_msg->set_data(data.GetData(), data.GetSize());
  }
}
// Wraps an externally owned TensorDef proto message without copying it.
GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg)
    : tensor_def_(proto_owner, proto_msg) {}
GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); }
GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); }
// Rebinds the cached __desc_ member to the proto's desc sub-message and
// returns it (mirrors GeTensorDesc::ShapeReference).
GeTensorDesc &GeTensor::DescReference() const {
  if (tensor_def_.GetProtoMsg() != nullptr) {
    GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc());
    __desc_.RefTo(tensor_desc);
  } else {
    GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr);
    __desc_.RefTo(tensor_desc);
  }
  return __desc_;
}
void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; }
// Buffer view sharing storage with the proto's data field; an empty Buffer
// when no proto is attached.
// NOTE(review): even this const accessor goes through mutable_data().
const Buffer GeTensor::GetData() const {
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data());
  }
  return Buffer();
}
// Mutable buffer view over the proto's data field.
Buffer GeTensor::MutableData() {
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data());
  }
  return Buffer();
}
// Replaces the tensor data with the vector contents.
// NOTE(review): despite the rvalue parameter, set_data copies the bytes.
graphStatus GeTensor::SetData(vector<uint8_t> &&data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
graphStatus GeTensor::SetData(const vector<uint8_t> &data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
// Replaces the tensor data with |size| bytes from |data|; null |data| fails.
graphStatus GeTensor::SetData(const uint8_t *data, size_t size) {
  GE_CHECK_NOTNULL(data);
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data, size);
  return GRAPH_SUCCESS;
}
// Replaces the tensor data from a Buffer; empty or null buffers are
// tolerated and only logged.
graphStatus GeTensor::SetData(const Buffer &data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  if (data.size() == 0) {
    GELOGI("GetSize res is 0.");
  }
  if (data.data() == nullptr) {
    GELOGI("data addr is null.");
  }
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
// Deep copy: the proto contents are copied into a fresh tensor.
GeTensor GeTensor::Clone() const {
  GeTensor tensor;
  tensor.tensor_def_.CopyValueFrom(tensor_def_);
  return tensor;
}
// NOTE(review): copy ctor/assignment assign the proto helper directly
// (not CopyValueFrom as Clone does) — whether that aliases or deep-copies
// depends on GeIrProtoHelper::operator=; confirm before relying on it.
GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; }
GeTensor &GeTensor::operator=(const GeTensor &other) {
  if (&other != this) {
    tensor_def_ = other.tensor_def_;
  }
  return *this;
}
// --- TensorUtils: thin accessors over TensorDescriptor proto fields.
// Getters returning graphStatus fail when no proto message is attached;
// setters are silent no-ops in that case. ---

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc,
                                                                                int64_t &size) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  GE_CHECK_NOTNULL(tensor_descriptor_msg);
  size = static_cast<int64_t>(tensor_descriptor_msg->size());
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    tensor_descriptor_msg->set_size(size);
  }
}
// Weight size from the proto, or 0 when no proto message is attached.
uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    return static_cast<uint32_t>(tensor_descriptor_msg->weight_size());
  }
  return 0;
}
uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); }
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) {
  if (tensor_ptr == nullptr) {
    return 0;
  }
  return GetWeightSize(*tensor_ptr);
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr,
                                                                                   uint8_t *base) {
  if (tensor_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "tensor_ptr is null.");
    return nullptr;
  }
  return GetWeightAddr(*tensor_ptr, base);
}
// Resolves the weight address: a data offset of 0 means the bytes still live
// inside the tensor itself (const op), otherwise the weight sits at
// base + offset. Returns nullptr on a null base or a missing offset.
uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) {
  if (base == nullptr) {
    GELOGE(GRAPH_FAILED, "base is null.");
    return nullptr;
  }
  int64_t weight_data_offset = 0;
  if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr;
  if (weight_data_offset == 0) {
    // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS.
    return const_cast<uint8_t *>(tensor.GetData().data());
  }
  return base + weight_data_offset;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc,
                                                                               uint32_t size) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    tensor_descriptor_msg->set_weight_size(size);
  }
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc,
                                                                                      bool &flag) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  GE_CHECK_NOTNULL(tensor_descriptor_msg);
  flag = tensor_descriptor_msg->reuse_input();
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    tensor_descriptor_msg->set_reuse_input(flag);
  }
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc,
                                                                                        bool &flag) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  GE_CHECK_NOTNULL(tensor_descriptor_msg);
  flag = tensor_descriptor_msg->output_tensor();
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) {
  auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  if (tensor_descriptor_msg != nullptr) {
    tensor_descriptor_msg->set_output_tensor(flag);
  }
}
// Lookup tables translating between DeviceType enum values and the string
// representation stored in the tensor descriptor proto.
// NOTE(review): kept non-const because callers access them via operator[],
// which requires a mutable map; they are file-scope mutable globals, so
// concurrent first access is presumably not expected here — confirm.
static map<uint32_t, string> device_to_str_map{
    {0, "NPU"},
    {1, "CPU"},
};
static map<string, uint32_t> str_to_device_map{
    {"NPU", 0},
    {"CPU", 1},
};
  711. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc,
  712. DeviceType &type) {
  713. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  714. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  715. string type_str = tensor_descriptor_msg->device_type();
  716. type = DeviceType(str_to_device_map[type_str]);
  717. return GRAPH_SUCCESS;
  718. }
  719. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc,
  720. DeviceType type) {
  721. auto type_str = device_to_str_map[type];
  722. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  723. if (tensor_descriptor_msg != nullptr) {
  724. tensor_descriptor_msg->set_device_type(type_str);
  725. }
  726. }
  727. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc,
  728. bool &flag) {
  729. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  730. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  731. flag = tensor_descriptor_msg->input_tensor();
  732. return GRAPH_SUCCESS;
  733. }
  734. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) {
  735. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  736. if (tensor_descriptor_msg != nullptr) {
  737. tensor_descriptor_msg->set_input_tensor(flag);
  738. }
  739. }
  740. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc,
  741. uint32_t &cnt) {
  742. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  743. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  744. cnt = static_cast<uint32_t>(tensor_descriptor_msg->real_dim_cnt());
  745. return GRAPH_SUCCESS;
  746. }
  747. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc,
  748. uint32_t cnt) {
  749. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  750. if (tensor_descriptor_msg != nullptr) {
  751. tensor_descriptor_msg->set_real_dim_cnt(cnt);
  752. }
  753. }
  754. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  755. TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) {
  756. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  757. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  758. idx = static_cast<uint32_t>(tensor_descriptor_msg->reuse_input_index());
  759. return GRAPH_SUCCESS;
  760. }
  761. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc,
  762. uint32_t idx) {
  763. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  764. if (tensor_descriptor_msg != nullptr) {
  765. tensor_descriptor_msg->set_reuse_input_index(idx);
  766. }
  767. }
  768. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc,
  769. int64_t &offset) {
  770. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  771. if (tensor_descriptor_msg != nullptr) {
  772. offset = tensor_descriptor_msg->data_offset();
  773. return GRAPH_SUCCESS;
  774. } else {
  775. GELOGW("tensor_descriptor_msg is nullptr.");
  776. return GRAPH_FAILED;
  777. }
  778. }
  779. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc,
  780. int64_t offset) {
  781. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  782. if (tensor_descriptor_msg != nullptr) {
  783. tensor_descriptor_msg->set_data_offset(offset);
  784. }
  785. }
  786. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc,
  787. uint32_t &cmp_size) {
  788. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  789. if (tensor_descriptor_msg != nullptr) {
  790. cmp_size = static_cast<uint32_t>(tensor_descriptor_msg->cmps_size());
  791. }
  792. return GRAPH_SUCCESS;
  793. }
  794. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc,
  795. uint32_t cmp_size) {
  796. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  797. if (tensor_descriptor_msg != nullptr) {
  798. tensor_descriptor_msg->set_cmps_size(cmp_size);
  799. }
  800. }
  801. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc,
  802. vector<uint8_t> &vec) {
  803. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  804. if (tensor_descriptor_msg != nullptr) {
  805. string str = tensor_descriptor_msg->cmps_tab();
  806. vec.assign(str.begin(), str.end());
  807. }
  808. return GRAPH_SUCCESS;
  809. }
  810. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc,
  811. const uint8_t *data, size_t size) {
  812. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  813. if (tensor_descriptor_msg != nullptr) {
  814. GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null.");
  815. string str((const char *)data, size);
  816. tensor_descriptor_msg->set_cmps_tab(str);
  817. }
  818. }
  819. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  820. TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) {
  821. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  822. if (tensor_descriptor_msg != nullptr) {
  823. tab_offset = tensor_descriptor_msg->cmps_tab_offset();
  824. }
  825. return GRAPH_SUCCESS;
  826. }
  827. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc,
  828. int64_t tab_offset) {
  829. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  830. if (tensor_descriptor_msg != nullptr) {
  831. tensor_descriptor_msg->set_cmps_tab_offset(tab_offset);
  832. }
  833. }
  834. graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) {
  835. GeAttrValue attr_value;
  836. if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) {
  837. return GRAPH_FAILED;
  838. }
  839. return attr_value.GetValue(info);
  840. }
  841. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc,
  842. const CompressInfo &info) {
  843. (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info));
  844. }
  845. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo(
  846. const GeTensorDesc &tensor_desc) {
  847. return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO);
  848. }
  849. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  850. TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) {
  851. GeAttrValue attr_value;
  852. if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) {
  853. GELOGW("get attr alloffset_quantize_info fail.");
  854. }
  855. return attr_value.GetValue(info);
  856. }
  857. void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) {
  858. (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info));
  859. }
  860. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc,
  861. uint32_t &rc) {
  862. return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED;
  863. }
  864. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) {
  865. (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc);
  866. }
  867. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示