You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ge_tensor.cc 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/ge_tensor.h"
  17. #include <cstdlib>
  18. #include <cstring>
  19. #include <iostream>
  20. #include <map>
  21. #include "debug/ge_attr_define.h"
  22. #include "debug/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "graph/ge_attr_value.h"
  25. #include "graph/model_serialize.h"
  26. #include "proto/ge_ir.pb.h"
  27. #include "utils/attr_utils.h"
  28. #include "utils/ge_ir_utils.h"
  29. #include "utils/tensor_utils.h"
  30. #include "utils/type_utils.h"
  31. namespace ge {
// File-local lookup tables used to (de)serialize DataType values into the
// proto::TensorDescriptor message.
namespace{
// Attr-map key under which "self-defined" data types (those without a direct
// proto::DataType counterpart) are stored as an integer attribute.
const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__";
// DataType values that map 1:1 onto the proto::DataType enum.
const std::map<DataType, ::ge::proto::DataType> kDataTypeMap = {
    {DT_UNDEFINED, proto::DT_UNDEFINED},
    {DT_FLOAT, proto::DT_FLOAT},
    {DT_FLOAT16, proto::DT_FLOAT16},
    {DT_INT8, proto::DT_INT8},
    {DT_UINT8, proto::DT_UINT8},
    {DT_INT16, proto::DT_INT16},
    {DT_UINT16, proto::DT_UINT16},
    {DT_INT32, proto::DT_INT32},
    {DT_INT64, proto::DT_INT64},
    {DT_UINT32, proto::DT_UINT32},
    {DT_UINT64, proto::DT_UINT64},
    {DT_BOOL, proto::DT_BOOL},
    {DT_DOUBLE, proto::DT_DOUBLE},
    {DT_DUAL, proto::DT_DUAL},
    {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8},
    {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8},
    {DT_COMPLEX64, proto::DT_COMPLEX64},
    {DT_COMPLEX128, proto::DT_COMPLEX128},
    {DT_QINT8, proto::DT_QINT8},
    {DT_QINT16, proto::DT_QINT16},
    {DT_QINT32, proto::DT_QINT32},
    {DT_QUINT8, proto::DT_QUINT8},
    {DT_QUINT16, proto::DT_QUINT16},
    {DT_RESOURCE, proto::DT_RESOURCE},
    {DT_STRING_REF, proto::DT_STRING_REF},
    {DT_STRING, proto::DT_STRING},
};
// DataType values serialized through the self-defined attribute instead of
// the proto dtype field. The integer codes are part of the on-disk format —
// do not renumber.
const std::map<DataType, int> kDataTypeSelfDefinedMap = {
    {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17},
    {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22},
};
}
// Default: owns a freshly-initialized (empty) ShapeDef proto.
GeShape::GeShape() { shape_def_.InitDefault(); }
// Builds a shape from explicit dimension values.
GeShape::GeShape(std::vector<int64_t> s) : GeShape() {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    for (auto i : s) {
      proto_msg->add_dim(i);
    }
  }
}
  77. size_t GeShape::GetDimNum() const {
  78. auto proto_msg = shape_def_.GetProtoMsg();
  79. if (proto_msg != nullptr) {
  80. if (proto_msg->dim_size() >= 0) {
  81. // check whether contain -2, if true, return -1
  82. for (auto i : proto_msg->dim()) {
  83. if (i == UNKNOWN_DIM_NUM) {
  84. return 0;
  85. }
  86. }
  87. return proto_msg->dim_size();
  88. } else {
  89. return 0;
  90. }
  91. }
  92. return 0;
  93. }
  94. int64_t GeShape::GetDim(size_t idx) const {
  95. auto proto_msg = shape_def_.GetProtoMsg();
  96. if (proto_msg != nullptr) {
  97. if (proto_msg->dim_size() > static_cast<int>(idx)) {
  98. return proto_msg->dim(static_cast<int>(idx));
  99. }
  100. }
  101. return 0;
  102. }
// Overwrites the dimension at index |idx| with |value|.
// Fails (GRAPH_FAILED) when the shape currently has no dims or |idx| is out
// of range.
// NOTE(review): when the proto message is absent this silently returns
// GRAPH_SUCCESS without writing anything — confirm callers expect that.
graphStatus GeShape::SetDim(size_t idx, int64_t value) {
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    auto dims = proto_msg->mutable_dim();
    GE_CHECK_NOTNULL(dims);
    if (dims->empty()) {
      GELOGE(GRAPH_FAILED, "shape is empty");
      return GRAPH_FAILED;
    }
    if (static_cast<int>(idx) >= dims->size()) {
      GELOGE(GRAPH_FAILED, "idx is out of range");
      return GRAPH_FAILED;
    }
    proto_msg->set_dim(static_cast<int>(idx), value);
  }
  return GRAPH_SUCCESS;
}
  120. std::vector<int64_t> GeShape::GetDims() const {
  121. vector<int64_t> dims;
  122. auto proto_msg = shape_def_.GetProtoMsg();
  123. if (proto_msg != nullptr) {
  124. for (auto i : proto_msg->dim()) {
  125. dims.push_back(i);
  126. }
  127. }
  128. return dims;
  129. }
  130. std::string GeShape::ToString() const {
  131. auto proto_msg = shape_def_.GetProtoMsg();
  132. if (proto_msg == nullptr) {
  133. return "";
  134. }
  135. std::stringstream ss;
  136. bool first = true;
  137. for (auto i : proto_msg->dim()) {
  138. if (first) {
  139. first = false;
  140. } else {
  141. ss << ",";
  142. }
  143. ss << i;
  144. }
  145. return ss.str();
  146. }
// Returns the total element count (product of all dims).
// Returns 0 for an empty dim list, and UNKNOWN_DIM (-1) when any dimension is
// unknown (-1) or the rank is unknown (-2). When the proto message is absent
// the initial value 1 falls through and is returned.
// NOTE(review): the product accumulates in int64_t with no overflow check —
// extremely large shapes could wrap; confirm acceptable.
int64_t GeShape::GetShapeSize() const {
  int64_t res = 1;
  auto proto_msg = shape_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    if (proto_msg->dim().empty()) {
      return 0;
    }
    for (auto i : proto_msg->dim()) {
      // if unknown shape, return -1
      if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) {
        return UNKNOWN_DIM;
      }
      res *= i;
    }
  }
  return res;
}
  164. ///
  165. /// @brief Check is unknown shape
  166. /// @return bool
  167. /// ///
  168. bool GeShape::IsUnknownShape() const {
  169. auto proto_msg = shape_def_.GetProtoMsg();
  170. if (proto_msg != nullptr) {
  171. for (auto i : proto_msg->dim()) {
  172. if (i < 0) {
  173. return true;
  174. }
  175. }
  176. }
  177. return false;
  178. }
  179. ///
  180. /// @brief Check is a scalar
  181. /// @return bool
  182. ///
  183. bool GeShape::IsScalar() const {
  184. auto proto_msg = shape_def_.GetProtoMsg();
  185. if (proto_msg != nullptr) {
  186. return proto_msg->dim().empty();
  187. }
  188. return false;
  189. }
// Attribute-map keys used to stash tensor-descriptor fields before they were
// promoted to first-class proto fields (see the migration constructor below).
const string TENSOR_UTILS_SIZE = "size";
const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size";
const string TENSOR_UTILS_REUSE_INPUT = "reuse_input";
const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor";
const string TENSOR_UTILS_DEVICE_TYPE = "device_type";
const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor";
const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt";
const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index";
const string TENSOR_UTILS_DATA_OFFSET = "data_offset";
const string TENSOR_UTILS_CMPS_SIZE = "cmps_size";
const string TENSOR_UTILS_CMPS_TAB = "cmps_tab";
const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset";
const string TENSOR_UTILS_CMPSINFO = "cmps_info";
const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info";
const string TENSOR_UTILS_RC = "rc";
const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index";
// Wraps an existing proto ShapeDef without copying; ownership is shared via
// |proto_owner|.
GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}
// Copy: deep-copies the other shape's ShapeDef.
GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); }
// Move: takes over the other shape's ShapeDef storage.
GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); }
GeShape &GeShape::operator=(const GeShape &other) {
  if (&other != this) {
    shape_def_.CopyValueFrom(other.shape_def_);
  }
  return *this;
}
// NOTE(review): unlike the move constructor, move-assignment calls
// CopyValueFrom on an rvalue rather than MoveValueFrom — confirm intended.
GeShape &GeShape::operator=(GeShape &&other) {
  if (&other != this) {
    shape_def_.CopyValueFrom(std::move(other.shape_def_));
  }
  return *this;
}
// Default: fresh descriptor with DT_FLOAT dtype; Init() sets ND formats and
// the has_out_attr flag.
GeTensorDesc::GeTensorDesc() {
  tensor_descriptor_.InitDefault();
  SetDataType(DT_FLOAT);
  Init();
}
// Builds a descriptor from shape, format and data type.
GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() {
  SetFormat(format);
  SetDataType(dt);
  ShapeReference() = std::move(shape);
}
// Copy: deep-copies the other descriptor's proto.
GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() {
  tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_);
}
// Move: takes over the other descriptor's proto storage.
GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() {
  tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_));
}
// Wraps an existing proto TensorDescriptor without copying.
// When the message predates first-class "out attr" fields (has_out_attr is
// false), each legacy attribute-map entry is migrated into its dedicated
// proto field; missing attributes fall back to the defaults below.
GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg)
    : tensor_descriptor_(proto_owner, proto_msg) {
  if (proto_msg != nullptr && !proto_msg->has_out_attr()) {
    proto_msg->set_has_out_attr(true);
    // Each Get* below is best-effort: on failure the local default is kept.
    int64_t size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size);
    proto_msg->set_size(size);
    int64_t weight_size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size);
    proto_msg->set_weight_size(weight_size);
    bool reuse_input = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input);
    proto_msg->set_reuse_input(reuse_input);
    bool output_tensor = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor);
    proto_msg->set_output_tensor(output_tensor);
    // Device type defaults to "NPU" when no attribute is present.
    string device_type = "NPU";
    (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type);
    proto_msg->set_device_type(device_type);
    bool input_tensor = false;
    (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor);
    proto_msg->set_input_tensor(input_tensor);
    int64_t real_dim_cnt = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt);
    proto_msg->set_real_dim_cnt(real_dim_cnt);
    int64_t reuse_input_index = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index);
    proto_msg->set_reuse_input_index(reuse_input_index);
    int64_t data_offset = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset);
    proto_msg->set_data_offset(data_offset);
    // Compression-related fields.
    int64_t cmps_size = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size);
    proto_msg->set_cmps_size(cmps_size);
    string cmps_tab;
    (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab);
    proto_msg->set_cmps_tab(cmps_tab);
    int64_t cmps_tab_offset = 0;
    (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset);
    proto_msg->set_cmps_tab_offset(cmps_tab_offset);
  }
}
// Field-by-field equality of the underlying proto::TensorDescriptor messages.
// Two descriptors with absent protos compare equal; one absent and one
// present compare unequal. IsEqual logs the mismatching field name.
bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const {
  const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg();
  const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg();
  if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) {
    // Message TensorDescriptor in ge_ir.proto
    return (IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") &&
            IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") &&
            // Message ShapeDef in ge_ir.proto: dims compared via string form
            IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()),
                    "TensorDescriptor.shape().dim()") &&
            IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") &&
            IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(),
                    "TensorDescriptor.has_out_attr()") &&
            IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") &&
            IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(),
                    "TensorDescriptor.weight_size()") &&
            IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(),
                    "TensorDescriptor.reuse_input()") &&
            IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(),
                    "TensorDescriptor.output_tensor()") &&
            IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(),
                    "TensorDescriptor.device_type()") &&
            IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(),
                    "TensorDescriptor.input_tensor()") &&
            IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(),
                    "TensorDescriptor.real_dim_cnt()") &&
            IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(),
                    "TensorDescriptor.reuse_input_index()") &&
            IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(),
                    "TensorDescriptor.data_offset()") &&
            IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") &&
            IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") &&
            IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(),
                    "TensorDescriptor.cmps_tab_offset()"));
  } else {
    // Equal only when both sides have no proto message at all.
    return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr));
  }
}
// Equality delegates to the full field-by-field proto comparison above.
bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const {
  return GeTensorDescAttrsAreEqual(r_ge_tensor_desc);
}
// Returns a GeShape that aliases the descriptor's embedded ShapeDef message,
// so mutations through the returned reference write directly into this
// descriptor. The mutable __shape_ cache is re-pointed (RefTo) on every call;
// when the proto is absent the alias wraps a null message.
GeShape &GeTensorDesc::ShapeReference() const {
  if (tensor_descriptor_.GetProtoMsg() != nullptr) {
    GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape());
    __shape_.RefTo(refShape);
  } else {
    GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr);
    __shape_.RefTo(refShape);
  }
  return __shape_;
}
// Shared constructor tail: establishes ND current/origin formats, NPU device
// type, and marks the proto as carrying first-class out-attr fields.
void GeTensorDesc::Init() {
  SetFormat(FORMAT_ND);
  SetOriginFormat(FORMAT_ND);
  TensorUtils::SetDeviceType(*this, DeviceType::NPU);
  if (tensor_descriptor_.GetProtoMsg() == nullptr) {
    GELOGE(GRAPH_FAILED, "ProtoType nullptr.");
    return;
  }
  tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true);
}
  347. ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() {
  348. if (tensor_descriptor_.GetProtoMsg() != nullptr) {
  349. return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr());
  350. }
  351. return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr);
  352. }
  353. ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const {
  354. if (tensor_descriptor_.GetProtoMsg() != nullptr) {
  355. return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(),
  356. tensor_descriptor_.GetProtoMsg()->mutable_attr());
  357. }
  358. return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr);
  359. }
  360. void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) {
  361. ShapeReference() = std::move(shape);
  362. SetFormat(format);
  363. SetDataType(dt);
  364. }
// Value copy of the current shape.
GeShape GeTensorDesc::GetShape() const { return ShapeReference(); }
// Mutable alias into the descriptor's stored shape.
GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); }
void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); }
// A single dim of -2 (UNKNOWN_DIM_NUM) marks the whole rank as unknown.
void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); }
  370. // for unknown shape
  371. graphStatus GeTensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) {
  372. std::vector<vector<int64_t>> shape_range;
  373. for (const auto &ele : range) {
  374. shape_range.emplace_back(std::vector<int64_t>({ele.first, ele.second}));
  375. }
  376. auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range);
  377. return ret ? GRAPH_SUCCESS : GRAPH_FAILED;
  378. }
// Reads the per-dimension [min, max] bounds previously stored by
// SetShapeRange, appending one pair per dimension to |range|.
// Fails when a stored entry does not have exactly two elements.
graphStatus GeTensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const {
  std::vector<vector<int64_t>> shape_range;
  (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range);
  for (const auto &ele : shape_range) {
    // Each entry must hold exactly two elements (the min/max of a pair).
    if (ele.size() != 2) {
      GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size());
      return GRAPH_FAILED;
    }
    std::pair<int64_t, int64_t> pair({ele[0], ele[1]});
    range.emplace_back(pair);
  }
  return GRAPH_SUCCESS;
}
  393. GeShape GeTensorDesc::GetOriginShape() const {
  394. vector<int64_t> origin_shape;
  395. if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) {
  396. return GeShape();
  397. }
  398. return GeShape(origin_shape);
  399. }
  400. void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) {
  401. std::vector<int64_t> origin_shape_tmp = origin_shape.GetDims();
  402. (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp);
  403. }
  404. Format GeTensorDesc::GetFormat() const {
  405. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  406. if (tensor_descriptor_msg != nullptr) {
  407. return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout());
  408. }
  409. return FORMAT_RESERVED;
  410. }
  411. void GeTensorDesc::SetFormat(Format format) {
  412. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  413. if (tensor_descriptor_msg != nullptr) {
  414. tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format));
  415. }
  416. }
  417. void GeTensorDesc::SetName(const std::string &name) {
  418. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  419. if (tensor_descriptor_msg != nullptr) {
  420. tensor_descriptor_msg->set_name(name);
  421. return;
  422. }
  423. GELOGW("[SetName]tensor_descriptor_msg is null.");
  424. }
  425. const std::string GeTensorDesc::GetName() const {
  426. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  427. if (tensor_descriptor_msg != nullptr) {
  428. return tensor_descriptor_msg->name();
  429. }
  430. GELOGW("[GetName]tensor_descriptor_msg is null.");
  431. return "";
  432. }
  433. Format GeTensorDesc::GetOriginFormat() const {
  434. std::string origin_format_str;
  435. if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) {
  436. // Can not get the certificate and it's not set, return directly
  437. return FORMAT_RESERVED;
  438. }
  439. if (origin_format_str == "RESERVED") {
  440. return FORMAT_RESERVED;
  441. }
  442. return TypeUtils::SerialStringToFormat(origin_format_str);
  443. }
  444. void GeTensorDesc::SetOriginFormat(Format origin_format) {
  445. std::string origin_format_str = "RESERVED";
  446. if (origin_format != FORMAT_RESERVED) {
  447. origin_format_str = TypeUtils::FormatToSerialString(origin_format);
  448. }
  449. (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str);
  450. }
  451. DataType GeTensorDesc::GetDataType() const {
  452. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  453. if (tensor_descriptor_msg == nullptr) {
  454. return DT_UNDEFINED;
  455. }
  456. auto &attr_map = *(tensor_descriptor_msg->mutable_attr());
  457. // Data type
  458. auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined);
  459. if (it_data_type != attr_map.end()) {
  460. int64_t data_type_proto = it_data_type->second.i();
  461. for (auto it : kDataTypeSelfDefinedMap) {
  462. if (it.second == data_type_proto) {
  463. return it.first;
  464. }
  465. }
  466. } else {
  467. auto data_type_proto = tensor_descriptor_msg->dtype();
  468. for (auto it : kDataTypeMap) {
  469. if (it.second == data_type_proto) {
  470. return it.first;
  471. }
  472. }
  473. }
  474. return DT_UNDEFINED;
  475. }
  476. void GeTensorDesc::SetDataType(DataType dataType) {
  477. auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg();
  478. if (tensor_descriptor_msg == nullptr) {
  479. return;
  480. }
  481. auto &attr_maps = *(tensor_descriptor_msg->mutable_attr());
  482. (void)attr_maps.erase(kKeyDataTypeSelfDefined);
  483. // Data type
  484. auto it = kDataTypeMap.find(dataType);
  485. if (it != kDataTypeMap.end()) {
  486. tensor_descriptor_msg->set_dtype(it->second);
  487. return;
  488. }
  489. auto it2 = kDataTypeSelfDefinedMap.find(dataType);
  490. if (it2 != kDataTypeSelfDefinedMap.end()) {
  491. attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second);
  492. }
  493. }
  494. void GeTensorDesc::SetOriginDataType(DataType origin_data_type) {
  495. std::string origin_data_type_str = "RESERVED";
  496. if (origin_data_type != DT_UNDEFINED) {
  497. origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type);
  498. }
  499. (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str);
  500. }
  501. DataType GeTensorDesc::GetOriginDataType() const {
  502. std::string origin_data_type_str;
  503. if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) {
  504. return DT_UNDEFINED;
  505. }
  506. if (origin_data_type_str == "RESERVED") {
  507. return DT_UNDEFINED;
  508. }
  509. return TypeUtils::SerialStringToDataType(origin_data_type_str);
  510. }
// Reads the ref-port index list attribute; empty when absent (best-effort).
std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const {
  vector<uint32_t> ref_port_index;
  (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index);
  return ref_port_index;
}
// Persists the ref-port index list as a list-int attribute (best-effort).
void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) {
  (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index);
}
  519. graphStatus GeTensorDesc::IsValid() const {
  520. auto dtype = this->GetDataType();
  521. auto format = this->GetFormat();
  522. if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) {
  523. return GRAPH_PARAM_INVALID;
  524. }
  525. return GRAPH_SUCCESS;
  526. }
  527. GeTensorDesc GeTensorDesc::Clone() const { return *this; }
// Copy-assignment: deep-copies the other descriptor's proto.
GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) {
  if (&desc != this) {
    tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_);
  }
  return *this;
}
// NOTE(review): like GeShape, move-assignment calls CopyValueFrom on an
// rvalue while the move constructor uses MoveValueFrom — confirm intended.
GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) {
  if (&desc != this) {
    tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_));
  }
  return *this;
}
  540. GeTensor::GeTensor::GeTensor() {
  541. tensor_def_.InitDefault();
  542. // Default init desc
  543. DescReference() = GeTensorDesc();
  544. }
// Descriptor-only tensor; no payload bytes are set.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; }
// Descriptor plus payload copied from a byte vector.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const vector<uint8_t> &data) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_data(data.data(), data.size());
  }
}
// Descriptor plus payload copied from a raw buffer; a null |data| leaves the
// payload unset.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr && data != nullptr) {
    proto_msg->set_data(data, size);
  }
}
// Move-descriptor variant.
// NOTE(review): set_data still copies the vector's bytes — the && on |data|
// does not avoid the copy here.
GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector<uint8_t> &&data) : GeTensor() {
  DescReference() = std::move(tensor_desc);
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_data(data.data(), data.size());
  }
}
// Descriptor plus payload copied from a Buffer; empty or null buffers are
// logged but still passed through to set_data.
GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() {
  DescReference() = tensor_desc;
  auto proto_msg = tensor_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    if (data.size() == 0) {
      GELOGI("GetSize res is 0.");
    }
    if (data.data() == nullptr) {
      GELOGI("data addr is null.");
    }
    proto_msg->set_data(data.GetData(), data.GetSize());
  }
}
// Wraps an existing proto TensorDef without copying.
GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg)
    : tensor_def_(proto_owner, proto_msg) {}
// Value copy of the descriptor.
GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); }
// Mutable alias into the tensor's embedded descriptor.
GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); }
// Returns a GeTensorDesc that aliases the TensorDef's embedded descriptor
// message; the mutable __desc_ cache is re-pointed (RefTo) on every call.
// When the proto is absent the alias wraps a null message.
GeTensorDesc &GeTensor::DescReference() const {
  if (tensor_def_.GetProtoMsg() != nullptr) {
    GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc());
    __desc_.RefTo(tensor_desc);
  } else {
    GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr);
    __desc_.RefTo(tensor_desc);
  }
  return __desc_;
}
  594. void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; }
  595. const Buffer GeTensor::GetData() const {
  596. auto proto_msg = tensor_def_.GetProtoMsg();
  597. if (proto_msg != nullptr) {
  598. return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data());
  599. }
  600. return Buffer();
  601. }
  602. Buffer GeTensor::MutableData() {
  603. auto proto_msg = tensor_def_.GetProtoMsg();
  604. if (proto_msg != nullptr) {
  605. return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data());
  606. }
  607. return Buffer();
  608. }
// Replaces the payload from a byte vector.
// NOTE(review): set_data copies the bytes, so the && overload gains nothing
// over the const& one — confirm whether a true move was intended.
graphStatus GeTensor::SetData(vector<uint8_t> &&data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
// Replaces the payload from a byte vector (copying).
graphStatus GeTensor::SetData(const vector<uint8_t> &data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
// Replaces the payload from a raw buffer; |data| must be non-null whenever
// |size| > 0.
graphStatus GeTensor::SetData(const uint8_t *data, size_t size) {
  if (size > 0) {
    GE_CHECK_NOTNULL(data);
  }
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  proto_msg->set_data(data, size);
  return GRAPH_SUCCESS;
}
// Replaces the payload from a Buffer. An empty or null buffer is logged but
// still forwarded to set_data (best-effort, not an error).
graphStatus GeTensor::SetData(const Buffer &data) {
  auto proto_msg = tensor_def_.GetProtoMsg();
  GE_CHECK_NOTNULL(proto_msg);
  if (data.size() == 0) {
    GELOGI("GetSize res is 0.");
  }
  if (data.data() == nullptr) {
    GELOGI("data addr is null.");
  }
  proto_msg->set_data(data.data(), data.size());
  return GRAPH_SUCCESS;
}
  642. GeTensor GeTensor::Clone() const {
  643. GeTensor tensor;
  644. tensor.tensor_def_.CopyValueFrom(tensor_def_);
  645. return tensor;
  646. }
// Copy constructor / copy assignment delegate to GeIrProtoHelper's own
// assignment operator.
// NOTE(review): this differs from Clone(), which uses CopyValueFrom — whether
// helper assignment shares or deep-copies depends on GeIrProtoHelper's
// semantics; confirm the aliasing behavior is intended.
GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; }
GeTensor &GeTensor::operator=(const GeTensor &other) {
  if (&other != this) {
    tensor_def_ = other.tensor_def_;
  }
  return *this;
}
  654. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc,
  655. int64_t &size) {
  656. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  657. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  658. size = static_cast<int64_t>(tensor_descriptor_msg->size());
  659. return GRAPH_SUCCESS;
  660. }
  661. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) {
  662. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  663. if (tensor_descriptor_msg != nullptr) {
  664. tensor_descriptor_msg->set_size(size);
  665. }
  666. }
  667. uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) {
  668. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  669. if (tensor_descriptor_msg != nullptr) {
  670. return static_cast<uint32_t>(tensor_descriptor_msg->weight_size());
  671. }
  672. return 0;
  673. }
  674. uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); }
  675. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) {
  676. if (tensor_ptr == nullptr) {
  677. return 0;
  678. }
  679. return GetWeightSize(*tensor_ptr);
  680. }
// Resolves the address of a tensor's weight data; nullptr (with an error log)
// for a null tensor pointer.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr,
                                                                                   uint8_t *base) {
  if (tensor_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "tensor_ptr is null.");
    return nullptr;
  }
  return GetWeightAddr(*tensor_ptr, base);
}
// Resolves the weight address as |base| + data_offset; when the offset is 0
// the bytes still live inline in the const op's own data.
// NOTE(review): the const_cast strips constness from GetData()'s buffer —
// callers must not write through the result for an inline weight.
uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) {
  if (base == nullptr) {
    GELOGE(GRAPH_FAILED, "base is null.");
    return nullptr;
  }
  int64_t weight_data_offset = 0;
  if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr;
  if (weight_data_offset == 0) {
    // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS.
    return const_cast<uint8_t *>(tensor.GetData().data());
  }
  return base + weight_data_offset;
}
  702. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc,
  703. uint32_t size) {
  704. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  705. if (tensor_descriptor_msg != nullptr) {
  706. tensor_descriptor_msg->set_weight_size(size);
  707. }
  708. }
  709. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc,
  710. bool &flag) {
  711. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  712. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  713. flag = tensor_descriptor_msg->reuse_input();
  714. return GRAPH_SUCCESS;
  715. }
  716. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) {
  717. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  718. if (tensor_descriptor_msg != nullptr) {
  719. tensor_descriptor_msg->set_reuse_input(flag);
  720. }
  721. }
  722. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc,
  723. bool &flag) {
  724. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  725. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  726. flag = tensor_descriptor_msg->output_tensor();
  727. return GRAPH_SUCCESS;
  728. }
  729. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) {
  730. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  731. if (tensor_descriptor_msg != nullptr) {
  732. tensor_descriptor_msg->set_output_tensor(flag);
  733. }
  734. }
  735. static map<uint32_t, string> device_to_str_map{
  736. {0, "NPU"}, {1, "CPU"},
  737. };
  738. static map<string, uint32_t> str_to_device_map{
  739. {"NPU", 0}, {"CPU", 1},
  740. };
  741. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc,
  742. DeviceType &type) {
  743. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  744. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  745. string type_str = tensor_descriptor_msg->device_type();
  746. type = DeviceType(str_to_device_map[type_str]);
  747. return GRAPH_SUCCESS;
  748. }
  749. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc,
  750. DeviceType type) {
  751. auto type_str = device_to_str_map[type];
  752. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  753. if (tensor_descriptor_msg != nullptr) {
  754. tensor_descriptor_msg->set_device_type(type_str);
  755. }
  756. }
  757. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc,
  758. bool &flag) {
  759. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  760. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  761. flag = tensor_descriptor_msg->input_tensor();
  762. return GRAPH_SUCCESS;
  763. }
  764. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) {
  765. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  766. if (tensor_descriptor_msg != nullptr) {
  767. tensor_descriptor_msg->set_input_tensor(flag);
  768. }
  769. }
  770. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc,
  771. uint32_t &cnt) {
  772. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  773. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  774. cnt = static_cast<uint32_t>(tensor_descriptor_msg->real_dim_cnt());
  775. return GRAPH_SUCCESS;
  776. }
  777. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc,
  778. uint32_t cnt) {
  779. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  780. if (tensor_descriptor_msg != nullptr) {
  781. tensor_descriptor_msg->set_real_dim_cnt(cnt);
  782. }
  783. }
  784. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  785. TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) {
  786. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  787. GE_CHECK_NOTNULL(tensor_descriptor_msg);
  788. idx = static_cast<uint32_t>(tensor_descriptor_msg->reuse_input_index());
  789. return GRAPH_SUCCESS;
  790. }
  791. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc,
  792. uint32_t idx) {
  793. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  794. if (tensor_descriptor_msg != nullptr) {
  795. tensor_descriptor_msg->set_reuse_input_index(idx);
  796. }
  797. }
  798. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc,
  799. int64_t &offset) {
  800. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  801. if (tensor_descriptor_msg != nullptr) {
  802. offset = tensor_descriptor_msg->data_offset();
  803. return GRAPH_SUCCESS;
  804. } else {
  805. GELOGW("tensor_descriptor_msg is nullptr.");
  806. return GRAPH_FAILED;
  807. }
  808. }
  809. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc,
  810. int64_t offset) {
  811. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  812. if (tensor_descriptor_msg != nullptr) {
  813. tensor_descriptor_msg->set_data_offset(offset);
  814. }
  815. }
  816. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc,
  817. uint32_t &cmp_size) {
  818. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  819. if (tensor_descriptor_msg != nullptr) {
  820. cmp_size = static_cast<uint32_t>(tensor_descriptor_msg->cmps_size());
  821. }
  822. return GRAPH_SUCCESS;
  823. }
  824. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc,
  825. uint32_t cmp_size) {
  826. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  827. if (tensor_descriptor_msg != nullptr) {
  828. tensor_descriptor_msg->set_cmps_size(cmp_size);
  829. }
  830. }
  831. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc,
  832. vector<uint8_t> &vec) {
  833. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  834. if (tensor_descriptor_msg != nullptr) {
  835. string str = tensor_descriptor_msg->cmps_tab();
  836. vec.assign(str.begin(), str.end());
  837. }
  838. return GRAPH_SUCCESS;
  839. }
  840. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc,
  841. const uint8_t *data, size_t size) {
  842. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  843. if (tensor_descriptor_msg != nullptr) {
  844. GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null.");
  845. string str((const char *)data, size);
  846. tensor_descriptor_msg->set_cmps_tab(str);
  847. }
  848. }
  849. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  850. TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) {
  851. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  852. if (tensor_descriptor_msg != nullptr) {
  853. tab_offset = tensor_descriptor_msg->cmps_tab_offset();
  854. }
  855. return GRAPH_SUCCESS;
  856. }
  857. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc,
  858. int64_t tab_offset) {
  859. auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
  860. if (tensor_descriptor_msg != nullptr) {
  861. tensor_descriptor_msg->set_cmps_tab_offset(tab_offset);
  862. }
  863. }
  864. graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) {
  865. GeAttrValue attr_value;
  866. if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) {
  867. return GRAPH_FAILED;
  868. }
  869. return attr_value.GetValue(info);
  870. }
  871. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc,
  872. const CompressInfo &info) {
  873. (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info));
  874. }
  875. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo(
  876. const GeTensorDesc &tensor_desc) {
  877. return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO);
  878. }
  879. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
  880. TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) {
  881. GeAttrValue attr_value;
  882. if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) {
  883. GELOGW("get attr alloffset_quantize_info fail.");
  884. }
  885. return attr_value.GetValue(info);
  886. }
  887. void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) {
  888. (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info));
  889. }
  890. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc,
  891. uint32_t &rc) {
  892. return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED;
  893. }
  894. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) {
  895. (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc);
  896. }
  897. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示