
infer_value_range_pass.cc 23 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/passes/infer_value_range_pass.h"
#include "common/formats/utils/formats_trans_utils.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/operator_factory_impl.h"
#include "graph/passes/constant_folding_pass.h"
#include "graph/utils/type_utils.h"
#include "common/ge/ge_util.h"

using std::unique_ptr;
namespace ge {
namespace {
#define GET_DATA_BY_DTYPE(DTYPE, TYPE)                                                                    \
  case (DTYPE):                                                                                           \
    ConstructValueRange<TYPE>(lower_boundary_tensor, upper_boundary_tensor, output_tensor_value_range);   \
    break;

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  desc_str += formats::RangeToString(shape_range);
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  desc_str += ",";
  desc_str += formats::RangeToString(shape_range);
  shape_range.clear();
}
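
// Compute the node's outputs on host from the given boundary inputs: first try the
// constant-folding op kernel, then fall back to the kernel registered for the node type.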
Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
                                 std::vector<GeTensorPtr> &outputs) {
  // RunOpKernelWithCheck, RunOpKernel for test
  auto ret = ConstantFoldingPass::RunOpKernel(node, inputs, outputs);
  if (ret != SUCCESS) {
    auto op_kernel = folding_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      GELOGW("Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(),
             node->GetType().c_str());
      return NOT_CHANGED;
    }
    ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs);
    if (ret != SUCCESS) {
      GELOGW("Calculate value range failed, node %s run cpu kernel failed.", node->GetName().c_str());
      return NOT_CHANGED;
    }
  }
  GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
  return SUCCESS;
}
}  // namespace
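
// Pass entry for a single node: if the op type registers an infer-value-range function, call it;
// otherwise use the cpu-kernel path. Inputs with unknown (-1) ranges skip the kernel and receive the
// worst-case range {1:-1}, unless the range also contains 0, in which case the node is skipped.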
graphStatus InferValueRangePass::Infer(NodePtr &node) {
  auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());

  // Use registered func to calculate value range
  if (!infer_value_range_param.use_cpu_kernel) {
    if (infer_value_range_param.infer_value_func == nullptr) {
      GELOGW("The registered func of node %s to infer value range is nullptr.", node->GetName().c_str());
      return GRAPH_NOT_CHANGED;
    }
    Operator op = OpDescUtils::CreateOperatorFromNode(node);
    auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
    if (ret != GRAPH_SUCCESS) {
      GELOGW("Node %s call infer value range func failed, ret: %u.", node->GetName().c_str(), ret);
      return GRAPH_NOT_CHANGED;
    }
    GELOGD("Node %s infer value range func succeed by registered func.", node->GetName().c_str());
    return GRAPH_SUCCESS;
  }

  // Deal with scenes with unknown value range
  bool has_unknown_value_range = false;
  bool has_zero_in_value_range = false;
  CheckInputValueRange(node, has_unknown_value_range, has_zero_in_value_range);
  if (has_unknown_value_range) {
    if (has_zero_in_value_range) {
      // When there is zero in input value range, it is unreasonable to always set output value range {1:-1}.
      GELOGW("Node %s has -1 and 0 in value range, skip setting value range.", node->GetName().c_str());
      return GRAPH_NOT_CHANGED;
    }
    GELOGI("Node %s has unknown value range in input tensors, set value range {1:-1}, and skip cpu kernel.",
           node->GetName().c_str());
    return GenerateWorstValueRange(node);
  }

  // Use CPU kernel func to calculate value range
  auto ret = ConstructInputAndInferValueRange(node);
  if (ret != GRAPH_SUCCESS) {
    GELOGW("Use CPU kernel to calculate value range failed. node: %s, ret: %u", node->GetName().c_str(), ret);
    return GRAPH_NOT_CHANGED;
  }
  GELOGD("Node %s infer value range func succeed by running cpu kernel.", node->GetName().c_str());
  return GRAPH_SUCCESS;
}
std::string InferValueRangePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const {
  std::stringstream ss;
  ss << "[";
  ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),";
  string range_str;
  SerialShapeRange(tensor_desc, range_str);
  ss << "(shape_range:" << range_str << "),";
  std::vector<std::pair<int64_t, int64_t>> value_range;
  (void)tensor_desc->GetValueRange(value_range);
  string value_range_str = formats::RangeToString(value_range);
  ss << "(value_range:" << value_range_str << ")]";
  return ss.str();
}
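
// A node is handled by this pass only if its type has registered value-range inference and the
// registered trigger matches: INPUT_IS_DYNAMIC (inputs contain unknown dims) or
// INPUT_HAS_VALUE_RANGE (every non-constant input carries a value range).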
bool InferValueRangePass::NeedInfer(const NodePtr &node) const {
  auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  if (!infer_value_range_param.is_initialized) {
    GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.",
           node->GetName().c_str());
    return false;
  }

  if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) {
    // Only do infer for node that all inputs are dynamic, such as shape
    if (InputIsDynamic(node)) {
      return true;
    }
    GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.",
           node->GetName().c_str());
  } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) {
    // Only do infer for node that all inputs have value_range or node type of inputs is constant/const
    if (InputIsConstOrHasValueRange(node)) {
      return true;
    }
    GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.",
           node->GetName().c_str());
  }

  GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str());
  return false;
}
bool InferValueRangePass::InputIsDynamic(const NodePtr &node) const {
  bool input_is_dynamic = false;
  auto cur_op_desc = node->GetOpDesc();
  for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
    auto dims = input_desc->GetShape().GetDims();
    for (auto dim : dims) {
      if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
        input_is_dynamic = true;
        break;
      }
    }
  }
  return input_is_dynamic;
}

bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) const {
  bool input_is_const_or_has_value_range = true;
  auto cur_op_desc = node->GetOpDesc();
  auto in_data_anchors = node->GetAllInDataAnchors();
  for (size_t i = 0; i < in_data_anchors.size(); ++i) {
    auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    auto peer_node = peer_out_anchor->GetOwnerNode();
    if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) {
      continue;
    }
    if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
      continue;
    }

    const auto &input_desc = cur_op_desc->GetInputDesc(i);
    std::vector<std::pair<int64_t, int64_t>> value_range;
    (void)input_desc.GetValueRange(value_range);
    if (value_range.empty()) {
      GELOGD("Node %s input %zu does not have value range, skip infer_value_range_pass for current node.",
             node->GetName().c_str(), i);
      input_is_const_or_has_value_range = false;
      break;
    }
  }
  return input_is_const_or_has_value_range;
}
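
// Scan all input value ranges and flag whether any boundary is unknown (-1) or zero; the caller uses
// these flags to decide between generating the worst-case range and skipping the node entirely.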
void InferValueRangePass::CheckInputValueRange(const NodePtr &node, bool &has_unknown_value_range,
                                               bool &has_zero_in_value_range) const {
  has_unknown_value_range = false;
  has_zero_in_value_range = false;
  auto cur_op_desc = node->GetOpDesc();
  for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
    std::vector<std::pair<int64_t, int64_t>> input_desc_value_range;
    input_desc->GetValueRange(input_desc_value_range);
    if (!input_desc_value_range.empty()) {
      for (const auto &range : input_desc_value_range) {
        if (range.first == 0 || range.second == 0) {
          GELOGD("Node %s input tensors have zero in value range %s.", node->GetName().c_str(),
                 formats::RangeToString(input_desc_value_range).c_str());
          has_zero_in_value_range = true;
        }
        if (range.first == -1 || range.second == -1) {
          GELOGD("Node %s input tensors have unknown value range, value range is %s.", node->GetName().c_str(),
                 formats::RangeToString(input_desc_value_range).c_str());
          has_unknown_value_range = true;
        }
      }
    }
  }
}
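
// Propagate the value range from src desc to dst desc; `changed` reports whether dst's range differed.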
graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  if (src == nullptr || dst == nullptr) {
    REPORT_CALL_ERROR("E19999", "While updating tensor desc, input desc is null.");
    GELOGE(GRAPH_FAILED, "[Param][check] While updating tensor desc, input desc is null.");
    return GRAPH_FAILED;
  }

  changed = false;
  std::vector<std::pair<int64_t, int64_t>> src_value_range;
  std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  (void)src->GetValueRange(src_value_range);
  (void)dst->GetValueRange(dst_value_range);
  if (src_value_range != dst_value_range) {
    GELOGD("While updating tensor desc, value range has been changed, src value range: %s, dst value range: %s.",
           formats::RangeToString(src_value_range).c_str(), formats::RangeToString(dst_value_range).c_str());
    changed = true;
  }

  dst->SetValueRange(src_value_range);
  return GRAPH_SUCCESS;
}
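
// Merge the value ranges of the subgraph output tensors into the parent node's output desc; any
// position where the subgraphs disagree degrades to the worst-case pair {1, -1}.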
graphStatus InferValueRangePass::UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src,
                                                           GeTensorDescPtr &dst) {
  std::vector<std::pair<int64_t, int64_t>> ref_out_tensor_value_range;
  auto ref_out_tensor = src.at(0);
  (void)ref_out_tensor->GetValueRange(ref_out_tensor_value_range);
  for (auto &ref_tensor : src) {
    std::vector<std::pair<int64_t, int64_t>> ref_tensor_value_range;
    (void)ref_tensor->GetValueRange(ref_tensor_value_range);
    if (ref_tensor_value_range.size() != ref_out_tensor_value_range.size()) {
      GELOGD("Update TensorDesc %s failed, rank of value ranges %s and %s are not the same, skip value range refresh.",
             dst->GetName().c_str(), formats::RangeToString(ref_out_tensor_value_range).c_str(),
             formats::RangeToString(ref_tensor_value_range).c_str());
      return GRAPH_SUCCESS;
    }
    for (size_t j = 0; j < ref_out_tensor_value_range.size(); j++) {
      if ((ref_out_tensor_value_range.at(j).first != ref_tensor_value_range.at(j).first) ||
          (ref_out_tensor_value_range.at(j).second != ref_tensor_value_range.at(j).second)) {
        ref_out_tensor_value_range[j] = std::make_pair(1, -1);
      }
    }
  }
  GELOGD("While updating output desc from subgraphs, set parent node desc value range %s.",
         formats::RangeToString(ref_out_tensor_value_range).c_str());
  dst->SetValueRange(ref_out_tensor_value_range);
  return GRAPH_SUCCESS;
}
graphStatus InferValueRangePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                                       GeTensorDescPtr &dst) {
  REPORT_INNER_ERROR("E19999",
                     "Update TensorDesc %s failed. In dynamic multi-dims size scene, there should be no value range.",
                     dst->GetName().c_str());
  GELOGE(GRAPH_FAILED,
         "[Update][TensorDesc] %s failed. In dynamic multi-dims size scene, there should be no value range.",
         dst->GetName().c_str());
  return GRAPH_FAILED;
}
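
// When some input range is unknown, give every element of each statically shaped output the
// worst-case value range {1:-1}.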
graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) {
  GELOGI("Node %s does not run cpu kernel, because input value range has -1.", node->GetName().c_str());
  OpDescPtr op_desc = node->GetOpDesc();
  for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
    auto output_desc = op_desc->MutableOutputDesc(i);
    if (output_desc == nullptr) {
      continue;
    }
    auto output_i_shape = output_desc->GetShape();
    auto output_i_shape_size = output_i_shape.GetShapeSize();
    if (output_i_shape_size < 0) {
      GELOGD("Node %s output shape is unknown, cannot infer value range, shape is %s.", node->GetName().c_str(),
             formats::ShapeToString(output_i_shape).c_str());
      return GRAPH_NOT_CHANGED;
    }

    std::vector<std::pair<int64_t, int64_t>> output_i_value_range(output_i_shape_size, {1, -1});
    if (output_i_shape.IsScalar()) {
      output_i_value_range.emplace_back(1, -1);
    }
    output_desc->SetValueRange(output_i_value_range);
    GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i,
           formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str());
  }
  return GRAPH_SUCCESS;
}
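
// Fill output_ptr with one value per value-range pair of tensor_desc: the lower (floor) boundary when
// use_floor_value is true, otherwise the upper boundary.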
template <typename T>
graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value,
                                               GeTensorPtr &output_ptr) {
  std::vector<std::pair<int64_t, int64_t>> value_range;
  (void)tensor_desc.GetValueRange(value_range);
  size_t value_range_data_num = value_range.size();
  auto tensor_shape = tensor_desc.GetShape();
  bool value_range_and_tensor_shape_matched = true;
  if (tensor_shape.IsScalar()) {
    // scalar tensor has only one value_range pair
    if (value_range_data_num != 1) {
      value_range_and_tensor_shape_matched = false;
    }
  } else {
    // normal tensor, value_range size is equal to tensor shape size.
    if (static_cast<int64_t>(value_range_data_num) != tensor_shape.GetShapeSize()) {
      value_range_and_tensor_shape_matched = false;
    }
  }
  if (!value_range_and_tensor_shape_matched) {
    GELOGW("Input %s value range and tensor shape do not match. Value range size is %zu, tensor shape is %s.",
           tensor_desc.GetName().c_str(), value_range_data_num, formats::ShapeToString(tensor_shape).c_str());
    return GRAPH_PARAM_INVALID;
  }

  unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
  if (buf == nullptr) {
    REPORT_INNER_ERROR("E19999", "New buf failed");
    GELOGE(MEMALLOC_FAILED, "New buf failed");
    return GRAPH_FAILED;
  }
  for (size_t j = 0; j < value_range_data_num; ++j) {
    auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second;
    buf[j] = static_cast<T>(value_range_j);
  }
  if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) {
    GELOGW("Set data failed while constructing value range input tensor.");
    return GRAPH_NOT_CHANGED;
  }
  return GRAPH_SUCCESS;
}
graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value,
                                                     GeTensorPtr &output_ptr) {
  graphStatus ret = GRAPH_SUCCESS;
  auto data_type = tensor_desc.GetDataType();
  output_ptr->MutableTensorDesc().SetDataType(data_type);
  switch (data_type) {
    case DT_FLOAT:
      ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_DOUBLE:
      ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT8:
      ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT8:
      ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_UINT16:
      ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT16:
      ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT32:
      ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    case DT_INT64:
      ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr);
      break;
    default:
      GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      ret = GRAPH_PARAM_INVALID;
  }
  return ret;
}
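
// Build the cpu kernel's input tensors: constant/const peers contribute their real weights, other
// inputs are synthesized from the chosen boundary of their value ranges; an empty vector signals failure.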
vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) {
  vector<ConstGeTensorPtr> input_tensors;
  auto cur_op_desc = node->GetOpDesc();
  auto in_data_anchors = node->GetAllInDataAnchors();
  for (size_t i = 0; i < in_data_anchors.size(); ++i) {
    auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    auto peer_node = peer_out_anchor->GetOwnerNode();
    if (peer_node == nullptr) {
      continue;
    }

    // construct input tensor by constant node
    if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
      vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node);
      if (const_weight.empty()) {
        GELOGW("MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(),
               peer_node->GetType().c_str());
        return vector<ConstGeTensorPtr>();
      }
      // const/constant op has only one weight
      if (const_weight.at(0) == nullptr) {
        GELOGW("MutableWeights failed, weight of constant is null, node name: %s(%s)",
               peer_node->GetName().c_str(), peer_node->GetType().c_str());
        return vector<ConstGeTensorPtr>();
      }
      input_tensors.push_back(const_weight.at(0));
      GELOGD("Node %s construct input tensor %zu by constant node.", node->GetName().c_str(), input_tensors.size());
      continue;
    }

    // construct input tensor by boundary of value range
    const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i);
    GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc);
    if (tmp_tensor_ptr == nullptr) {
      REPORT_INNER_ERROR("E19999", "Make shared failed");
      GELOGE(MEMALLOC_FAILED, "Make shared failed");
      return vector<ConstGeTensorPtr>();
    }
    auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr);
    if (ret != GRAPH_SUCCESS) {
      GELOGW("Construct input tensor by boundary of value range failed for input %s.",
             input_tensor_desc.GetName().c_str());
      return vector<ConstGeTensorPtr>();
    }
    input_tensors.push_back(tmp_tensor_ptr);
    GELOGD("Node %s construct input tensor %zu by input desc value range.", node->GetName().c_str(),
           input_tensors.size());
  }
  return input_tensors;
}
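
// Cpu-kernel path: run the kernel once on lower-boundary inputs and once on upper-boundary inputs,
// then zip the two sets of result tensors into per-element output value ranges.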
graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) {
  auto inputs = ConstructInputTensors(node, true);
  if (inputs.empty()) {
    return GRAPH_PARAM_INVALID;
  }
  vector<GeTensorPtr> lower_boundary_outputs;
  auto ret = RunCpuKernelForValueRange(node, inputs, lower_boundary_outputs);
  if (ret != SUCCESS) {
    GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }

  inputs = ConstructInputTensors(node, false);
  if (inputs.empty()) {
    return GRAPH_PARAM_INVALID;
  }
  vector<GeTensorPtr> upper_boundary_outputs;
  ret = RunCpuKernelForValueRange(node, inputs, upper_boundary_outputs);
  if (ret != SUCCESS) {
    GELOGW("Node %s run cpu kernel failed while calculating value range.", node->GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }

  // construct value range from output tensor
  OpDescPtr node_desc = node->GetOpDesc();
  std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range;
  size_t node_output_desc_size = node_desc->GetOutputsSize();
  for (size_t i = 0; i < node_output_desc_size; ++i) {
    output_tensor_value_range.clear();
    auto output_tensor_desc = node_desc->MutableOutputDesc(i);
    auto output_shape_size = output_tensor_desc->GetShape().GetShapeSize();
    auto lower_boundary_tensor = lower_boundary_outputs[i];
    auto lower_boundary_shape = lower_boundary_tensor->GetTensorDesc().GetShape();
    auto upper_boundary_tensor = upper_boundary_outputs[i];
    auto upper_boundary_shape = upper_boundary_tensor->GetTensorDesc().GetShape();
    if (lower_boundary_shape.GetShapeSize() != output_shape_size ||
        upper_boundary_shape.GetShapeSize() != output_shape_size) {
      GELOGD(
          "Cpu kernel result shapes %s, %s and output shape %s do not match, can not infer value range for output %s.",
          formats::ShapeToString(lower_boundary_shape).c_str(), formats::ShapeToString(upper_boundary_shape).c_str(),
          formats::ShapeToString(output_tensor_desc->GetShape()).c_str(), output_tensor_desc->GetName().c_str());
      return GRAPH_PARAM_INVALID;
    }

    auto data_type = output_tensor_desc->GetDataType();
    switch (data_type) {
      GET_DATA_BY_DTYPE(DT_INT8, int8_t)
      GET_DATA_BY_DTYPE(DT_INT16, int16_t)
      GET_DATA_BY_DTYPE(DT_INT32, int32_t)
      GET_DATA_BY_DTYPE(DT_INT64, int64_t)
      GET_DATA_BY_DTYPE(DT_UINT8, uint8_t)
      GET_DATA_BY_DTYPE(DT_UINT16, uint16_t)
      GET_DATA_BY_DTYPE(DT_UINT32, uint32_t)
      GET_DATA_BY_DTYPE(DT_UINT64, uint64_t)
      GET_DATA_BY_DTYPE(DT_FLOAT, float)
      GET_DATA_BY_DTYPE(DT_DOUBLE, double)
      default:
        GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
        return GRAPH_PARAM_INVALID;
    }
    output_tensor_desc->SetValueRange(output_tensor_value_range);
    GELOGD("Node %s calculates output %zu value range %s by running cpu kernel.", node->GetName().c_str(), i,
           formats::RangeToString(output_tensor_value_range).c_str());
  }
  return GRAPH_SUCCESS;
}
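
// Pair the lower-boundary and upper-boundary kernel outputs element-wise into {lower, upper} entries.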
template <typename T>
void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor,
                                              std::vector<std::pair<int64_t, int64_t>> &value_range) {
  auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData());
  auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData());
  if (x == nullptr || y == nullptr) {
    GELOGI("Output tensor of cpu kernel does not have data, no way to set value range.");
    return;
  }
  auto left_tensor_shape = left_tensor->GetTensorDesc().GetShape();
  for (auto j = 0; j < left_tensor_shape.GetShapeSize(); ++j) {
    auto left = static_cast<int64_t>(*(x + j));
    auto right = static_cast<int64_t>(*(y + j));
    value_range.emplace_back(left, right);
  }
  if (left_tensor_shape.IsScalar()) {
    GELOGD("When inferring value range, output tensors of cpu kernel are scalar tensors.");
    value_range.emplace_back(static_cast<int64_t>(*x), static_cast<int64_t>(*y));
  }
}
}  // namespace ge
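
The boundary-evaluation idea behind ConstructInputAndInferValueRange can be illustrated outside of GE. Below is a minimal, self-contained sketch; PropagateRange and the doubling lambda are hypothetical illustration names, not GE API. It propagates per-element value ranges through an element-wise operation by evaluating it on the lower bounds and on the upper bounds, which is sound only for operations that are monotonic in each input.

#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// One {lower, upper} pair per tensor element, mirroring the value ranges stored on GeTensorDesc.
using Range = std::pair<int64_t, int64_t>;

// Evaluate the op on all lower bounds and on all upper bounds, then zip the results into ranges.
// This mirrors the two RunCpuKernelForValueRange calls in the pass above.
std::vector<Range> PropagateRange(const std::vector<Range> &inputs,
                                  const std::function<int64_t(int64_t)> &op) {
  std::vector<Range> outputs;
  outputs.reserve(inputs.size());
  for (const auto &r : inputs) {
    outputs.emplace_back(op(r.first), op(r.second));  // lower-boundary run, upper-boundary run
  }
  return outputs;
}

int main() {
  // Suppose a Shape-like producer only guarantees that its two elements lie in [2, 8] and [3, 5].
  const std::vector<Range> in = {{2, 8}, {3, 5}};
  // Propagating through "multiply by 2" yields [4, 16] and [6, 10].
  for (const auto &r : PropagateRange(in, [](int64_t v) { return v * 2; })) {
    std::cout << "[" << r.first << ", " << r.second << "] ";
  }
  std::cout << std::endl;
  return 0;
}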

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module ME and the underlying hardware, serving as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.