You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

infer_value_range_pass.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/infer_value_range_pass.h"
  17. #include "common/util/error_manager/error_manager.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. #include "graph/operator_factory_impl.h"
  21. #include "graph/passes/folding_pass.h"
  22. #include "common/ge/ge_util.h"
  23. #include "init/gelib.h"
  24. using std::unique_ptr;
  25. namespace ge {
  26. namespace {
  27. #define GET_DATA_BY_DTYPE(DTYPE, TYPE) \
  28. case (DTYPE): \
  29. ConstructValueRange<TYPE>(lower_tensor, higher_tensor, output_tensor_value_range); \
  30. break;
  31. Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
  32. std::vector<GeTensorPtr> &outputs) {
  33. // should use RunOpKernelWithCheck, RunOpKernel for ut test
  34. auto ret = FoldingPass::RunOpKernel(node, inputs, outputs);
  35. if (ret != SUCCESS) {
  36. auto op_kernel = folding_pass::GetKernelByType(node);
  37. if (op_kernel == nullptr) {
  38. GELOGE(PARAM_INVALID, "Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(),
  39. node->GetType().c_str());
  40. return PARAM_INVALID;
  41. }
  42. ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs);
  43. if (ret != SUCCESS) {
  44. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(),
  45. node->GetType().c_str());
  46. GELOGE(INTERNAL_ERROR, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  47. return ret;
  48. }
  49. }
  50. GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
  51. return SUCCESS;
  52. }
  53. } // namespace
  54. graphStatus InferValueRangePass::Infer(NodePtr &node) {
  55. PrintInOutTensorShape(node, "before_infer_value_range");
  56. auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  57. // Use registered func to calculate value range
  58. if (!infer_value_range_param.use_cpu_kernel) {
  59. if (infer_value_range_param.infer_value_func == nullptr) {
  60. GELOGE(GRAPH_PARAM_INVALID, "The registered func to infer value range is nullptr.");
  61. return GRAPH_PARAM_INVALID;
  62. }
  63. Operator op = OpDescUtils::CreateOperatorFromNode(node);
  64. auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
  65. if (ret != GRAPH_SUCCESS) {
  66. REPORT_CALL_ERROR("E19999", "Node %s call infer value range function failed.", node->GetName().c_str());
  67. GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node: %s.", node->GetName().c_str());
  68. return GRAPH_FAILED;
  69. }
  70. return GRAPH_SUCCESS;
  71. }
  72. // Use CPU kernel func to calculate value range
  73. return ConstructInputAndInferValueRange(node);
  74. }
  75. bool InferValueRangePass::NeedInfer(const NodePtr &node) {
  76. auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  77. if (!infer_value_range_param.is_initialized) {
  78. GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.",
  79. node->GetName().c_str());
  80. return false;
  81. }
  82. if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) {
  83. // Only do infer for node that all inputs are dynamic, such as shape
  84. if (InputIsDynamic(node)) {
  85. return true;
  86. }
  87. GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input failed.",
  88. node->GetName().c_str());
  89. } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) {
  90. // Only do infer for node that all inputs have value_range or node type of inputs is constant/const
  91. if (InputIsConstOrHasValueRange(node)) {
  92. return true;
  93. }
  94. GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input failed.",
  95. node->GetName().c_str());
  96. }
  97. GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str());
  98. return false;
  99. }
  100. bool InferValueRangePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  101. bool changed = false;
  102. std::vector<std::pair<int64_t, int64_t>> src_value_range;
  103. std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  104. (void)src->GetValueRange(src_value_range);
  105. (void)dst->GetValueRange(dst_value_range);
  106. if (src_value_range != dst_value_range) {
  107. changed = true;
  108. }
  109. return changed;
  110. }
  111. graphStatus InferValueRangePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  112. changed = false;
  113. std::vector<std::pair<int64_t, int64_t>> src_value_range;
  114. std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  115. (void)src->GetValueRange(src_value_range);
  116. (void)dst->GetValueRange(dst_value_range);
  117. if (src_value_range != dst_value_range) {
  118. changed = true;
  119. }
  120. dst->SetValueRange(src_value_range);
  121. return GRAPH_SUCCESS;
  122. }
  123. void InferValueRangePass::AnalyzeFailedInfo(const NodePtr &node) {
  124. REPORT_CALL_ERROR("E19999", "Infer value range for node:%s(%s) failed.", node->GetName().c_str(),
  125. node->GetType().c_str());
  126. GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infer value range failed. node: %s", node->GetName().c_str());
  127. }
  128. bool InferValueRangePass::InputIsDynamic(const NodePtr &node) {
  129. bool input_is_dynamic = false;
  130. auto cur_op_desc = node->GetOpDesc();
  131. for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
  132. auto dims = input_desc->GetShape().GetDims();
  133. for (auto dim : dims) {
  134. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  135. input_is_dynamic = true;
  136. break;
  137. }
  138. }
  139. }
  140. return input_is_dynamic;
  141. }
  142. bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) {
  143. bool input_is_const_or_has_value_range = true;
  144. auto cur_op_desc = node->GetOpDesc();
  145. auto in_data_anchors = node->GetAllInDataAnchors();
  146. for (auto i = 0; i < in_data_anchors.size(); ++i) {
  147. auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
  148. if (peer_out_anchor == nullptr) {
  149. continue;
  150. }
  151. auto peer_node = peer_out_anchor->GetOwnerNode();
  152. if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) {
  153. continue;
  154. }
  155. if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
  156. continue;
  157. }
  158. const auto &input_desc = cur_op_desc->GetInputDesc(i);
  159. std::vector<std::pair<int64_t, int64_t>> value_range;
  160. (void)input_desc.GetValueRange(value_range);
  161. if (value_range.empty()) {
  162. int peer_out_idx = peer_out_anchor->GetIdx();
  163. auto peer_out_desc = peer_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
  164. (void)peer_out_desc->GetValueRange(value_range);
  165. if (value_range.empty()) {
  166. input_is_const_or_has_value_range = false;
  167. break;
  168. }
  169. }
  170. }
  171. return input_is_const_or_has_value_range;
  172. }
  173. template <typename T>
  174. graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr) {
  175. std::vector<std::pair<int64_t, int64_t>> value_range;
  176. (void)tensor_desc.GetValueRange(value_range);
  177. if (value_range.size() != tensor_desc.GetShape().GetShapeSize()) {
  178. REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
  179. GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
  180. return GRAPH_PARAM_INVALID;
  181. }
  182. auto value_range_data_num = value_range.size();
  183. unique_ptr<T[]> buf(new (std::nothrow) T[value_range_data_num]());
  184. if (buf == nullptr) {
  185. REPORT_INNER_ERROR("E19999", "New buf failed");
  186. GELOGE(MEMALLOC_FAILED, "new buf failed");
  187. return GRAPH_FAILED;
  188. }
  189. for (auto j = 0; j < value_range_data_num; ++j) {
  190. auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second;
  191. buf[j] = static_cast<T>(value_range_j);
  192. }
  193. if (output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) {
  194. GELOGE(GRAPH_FAILED, "set data failed");
  195. return GRAPH_FAILED;
  196. }
  197. return GRAPH_SUCCESS;
  198. }
  199. graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr) {
  200. graphStatus ret = GRAPH_SUCCESS;
  201. auto data_type = tensor_desc.GetDataType();
  202. output_ptr->MutableTensorDesc().SetDataType(data_type);
  203. switch (data_type) {
  204. case DT_FLOAT:
  205. ret = ConstructData<float>(tensor_desc, use_floor_value, output_ptr);
  206. break;
  207. case DT_DOUBLE:
  208. ret = ConstructData<double>(tensor_desc, use_floor_value, output_ptr);
  209. break;
  210. case DT_UINT8:
  211. ret = ConstructData<uint8_t>(tensor_desc, use_floor_value, output_ptr);
  212. break;
  213. case DT_INT8:
  214. ret = ConstructData<int8_t>(tensor_desc, use_floor_value, output_ptr);
  215. break;
  216. case DT_UINT16:
  217. ret = ConstructData<uint16_t>(tensor_desc, use_floor_value, output_ptr);
  218. break;
  219. case DT_INT16:
  220. ret = ConstructData<int16_t>(tensor_desc, use_floor_value, output_ptr);
  221. break;
  222. case DT_INT32:
  223. ret = ConstructData<int32_t>(tensor_desc, use_floor_value, output_ptr);
  224. break;
  225. case DT_INT64:
  226. ret = ConstructData<int64_t>(tensor_desc, use_floor_value, output_ptr);
  227. break;
  228. default:
  229. GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
  230. ret = GRAPH_FAILED;
  231. }
  232. return ret;
  233. }
  234. vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) {
  235. vector<ConstGeTensorPtr> input_tensors;
  236. auto cur_op_desc = node->GetOpDesc();
  237. auto in_data_anchors = node->GetAllInDataAnchors();
  238. for (auto i = 0; i < in_data_anchors.size(); ++i) {
  239. auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
  240. if (peer_out_anchor == nullptr) {
  241. continue;
  242. }
  243. auto peer_node = peer_out_anchor->GetOwnerNode();
  244. if (peer_node == nullptr) {
  245. continue;
  246. }
  247. // construct input tensor by constant node
  248. if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
  249. vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node);
  250. if (const_weight.empty()) {
  251. REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)",
  252. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  253. GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(),
  254. peer_node->GetType().c_str());
  255. return vector<ConstGeTensorPtr>();
  256. }
  257. // const/constant op has only one weight
  258. if (const_weight.at(0) == nullptr) {
  259. REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)",
  260. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  261. GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)",
  262. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  263. return vector<ConstGeTensorPtr>();
  264. }
  265. input_tensors.push_back(const_weight.at(0));
  266. continue;
  267. }
  268. // construct input tensor by boundary of value range
  269. const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i);
  270. GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc);
  271. if (tmp_tensor_ptr == nullptr) {
  272. REPORT_INNER_ERROR("E19999", "Make shared failed");
  273. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  274. return vector<ConstGeTensorPtr>();
  275. }
  276. auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr);
  277. if (ret != GRAPH_SUCCESS) {
  278. REPORT_INNER_ERROR("E19999", "Input %s construct input tensor by boundary of value range failed.",
  279. input_tensor_desc.GetName().c_str());
  280. GELOGE(GRAPH_PARAM_INVALID, "Input %s construct input tensor by boundary of value range failed.",
  281. input_tensor_desc.GetName().c_str());
  282. return vector<ConstGeTensorPtr>();
  283. }
  284. input_tensors.push_back(tmp_tensor_ptr);
  285. }
  286. return input_tensors;
  287. }
  288. graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) {
  289. auto inputs = ConstructInputTensors(node, true);
  290. if (inputs.empty()) {
  291. return GRAPH_PARAM_INVALID;
  292. }
  293. vector<GeTensorPtr> outputs_lower;
  294. auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower);
  295. if (ret != SUCCESS) {
  296. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
  297. GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  298. return GRAPH_FAILED;
  299. }
  300. inputs = ConstructInputTensors(node, false);
  301. if (inputs.empty()) {
  302. return GRAPH_PARAM_INVALID;
  303. }
  304. vector<GeTensorPtr> outputs_higher;
  305. ret = RunCpuKernelForValueRange(node, inputs, outputs_higher);
  306. if (ret != SUCCESS) {
  307. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
  308. GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  309. return GRAPH_FAILED;
  310. }
  311. // construct value range from output tensor
  312. OpDescPtr node_desc = node->GetOpDesc();
  313. std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range;
  314. size_t node_output_desc_size = node_desc->GetOutputsSize();
  315. for (size_t i = 0; i < node_output_desc_size; ++i) {
  316. output_tensor_value_range.clear();
  317. auto lower_tensor = outputs_lower[i];
  318. auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize();
  319. auto higher_tensor = outputs_higher[i];
  320. auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize();
  321. auto output_tensor_desc = node_desc->MutableOutputDesc(i);
  322. auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize();
  323. if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) {
  324. GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str());
  325. }
  326. auto data_type = output_tensor_desc->GetDataType();
  327. switch (data_type) {
  328. GET_DATA_BY_DTYPE(DT_INT8, int8_t)
  329. GET_DATA_BY_DTYPE(DT_INT16, int16_t)
  330. GET_DATA_BY_DTYPE(DT_INT32, int32_t)
  331. GET_DATA_BY_DTYPE(DT_INT64, int64_t)
  332. GET_DATA_BY_DTYPE(DT_UINT8, uint8_t)
  333. GET_DATA_BY_DTYPE(DT_UINT16, uint16_t)
  334. GET_DATA_BY_DTYPE(DT_UINT32, uint32_t)
  335. GET_DATA_BY_DTYPE(DT_UINT64, uint64_t)
  336. GET_DATA_BY_DTYPE(DT_FLOAT, float)
  337. GET_DATA_BY_DTYPE(DT_DOUBLE, double)
  338. default:
  339. GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
  340. return GRAPH_FAILED;
  341. }
  342. output_tensor_desc->SetValueRange(output_tensor_value_range);
  343. }
  344. return GRAPH_SUCCESS;
  345. }
  346. template <typename T>
  347. void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor,
  348. std::vector<std::pair<int64_t, int64_t>> &value_range) {
  349. auto x = reinterpret_cast<const T *>(left_tensor->GetData().GetData());
  350. auto y = reinterpret_cast<const T *>(right_tensor->GetData().GetData());
  351. for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) {
  352. auto left = static_cast<int64_t>(*(x + j));
  353. auto right = static_cast<int64_t>(*(y + j));
  354. value_range.emplace_back(std::make_pair(left, right));
  355. }
  356. }
  357. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示