/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/infer_value_range_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/operator_factory_impl.h"
#include "graph/passes/folding_pass.h"
#include "common/ge/ge_util.h"
#include "init/gelib.h"

using std::unique_ptr;
namespace ge {
namespace {
// One switch case per data type: forwards the lower/upper boundary tensors of
// an output to the matching ConstructValueRange<TYPE> instantiation.
#define GET_DATA_BY_DTYPE(DTYPE, TYPE)                                                 \
  case (DTYPE):                                                                        \
    ConstructValueRange<TYPE>(lower_tensor, higher_tensor, output_tensor_value_range); \
    break;

// Runs the host cpu kernel of `node` on `inputs`, collecting the produced
// tensors into `outputs`. Tries FoldingPass::RunOpKernel first and falls back
// to the kernel registered for the node type.
Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
                                 std::vector<GeTensorPtr> &outputs) {
  // should use RunOpKernelWithCheck, RunOpKernel for ut test
  auto ret = FoldingPass::RunOpKernel(node, inputs, outputs);
  if (ret != SUCCESS) {
    auto op_kernel = folding_pass::GetKernelByType(node);
    if (op_kernel == nullptr) {
      GELOGE(PARAM_INVALID, "Calculate value range failed, no op kernel for node %s type %s",
             node->GetName().c_str(), node->GetType().c_str());
      return PARAM_INVALID;
    }

    ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs);
    if (ret != SUCCESS) {
      REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(),
                         node->GetType().c_str());
      GELOGE(INTERNAL_ERROR, "Calculate for node %s failed in constant folding", node->GetName().c_str());
      return ret;
    }
  }
  GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
  return SUCCESS;
}
}  // namespace

// Infers the value range of `node`'s outputs, either through the operator's
// registered infer-value-range function or by executing its cpu kernel on the
// boundary values of the inputs' value ranges.
graphStatus InferValueRangePass::Infer(NodePtr &node) {
  PrintInOutTensorShape(node, "before_infer_value_range");
  auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());

  // Use registered func to calculate value range
  if (!infer_value_range_param.use_cpu_kernel) {
    if (infer_value_range_param.infer_value_func == nullptr) {
      GELOGE(GRAPH_PARAM_INVALID, "The registered func to infer value range is nullptr.");
      return GRAPH_PARAM_INVALID;
    }
    Operator op = OpDescUtils::CreateOperatorFromNode(node);
    auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
    if (ret != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Node %s call infer value range function failed.", node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node: %s.", node->GetName().c_str());
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  }

  // Use CPU kernel func to calculate value range
  return ConstructInputAndInferValueRange(node);
}

// Returns true when `node` registered an infer-value-range func AND its inputs
// satisfy the registered precondition (dynamic shape, or value ranges known).
bool InferValueRangePass::NeedInfer(const NodePtr &node) {
  auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  if (!infer_value_range_param.is_initialized) {
    GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.",
           node->GetName().c_str());
    return false;
  }

  if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) {
    // Only do infer for node that all inputs are dynamic, such as shape
    if (InputIsDynamic(node)) {
      return true;
    }
    GELOGD("Node %s register func to infer value range and when_call is INPUT_IS_DYNAMIC, but check input "
           "failed.", node->GetName().c_str());
  } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) {
    // Only do infer for node that all inputs have value_range or node type of inputs is constant/const
    if (InputIsConstOrHasValueRange(node)) {
      return true;
    }
    GELOGD("Node %s register func to infer value range and when_call is INPUT_HAS_VALUE_RANGE, but check input "
           "failed.", node->GetName().c_str());
  }
  GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str());
  return false;
}

// A tensor desc "changed" for this pass exactly when its value range differs.
bool InferValueRangePass::TensorDescChanged(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) {
  std::vector<std::pair<int64_t, int64_t>> src_value_range;
  std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  (void)src->GetValueRange(src_value_range);
  (void)dst->GetValueRange(dst_value_range);
  return src_value_range != dst_value_range;
}

// Copies src's value range onto dst; `changed` reports whether dst's previous
// value range differed from the one being written.
graphStatus InferValueRangePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst,
                                                     bool &changed) {
  std::vector<std::pair<int64_t, int64_t>> src_value_range;
  std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  (void)src->GetValueRange(src_value_range);
  (void)dst->GetValueRange(dst_value_range);
  changed = (src_value_range != dst_value_range);
  dst->SetValueRange(src_value_range);
  return GRAPH_SUCCESS;
}

// Error-reporting hook invoked by the base pass when Infer fails.
void InferValueRangePass::AnalyzeFailedInfo(const NodePtr &node) {
  REPORT_CALL_ERROR("E19999", "Infer value range for node:%s(%s) failed.", node->GetName().c_str(),
                    node->GetType().c_str());
  GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infer value range failed. node: %s", node->GetName().c_str());
}

// Returns true when any input dim is unknown (-1) or the rank itself is
// unknown (-2). Early-returns on the first dynamic dim found.
bool InferValueRangePass::InputIsDynamic(const NodePtr &node) {
  auto cur_op_desc = node->GetOpDesc();
  for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
    for (auto dim : input_desc->GetShape().GetDims()) {
      if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
        return true;
      }
    }
  }
  return false;
}

// Returns true when every connected input either comes from a const/constant
// node or carries a value range (on this node's input desc or on the peer's
// output desc). Unconnected inputs are ignored.
bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) {
  bool input_is_const_or_has_value_range = true;
  auto cur_op_desc = node->GetOpDesc();
  auto in_data_anchors = node->GetAllInDataAnchors();
  for (size_t i = 0; i < in_data_anchors.size(); ++i) {
    auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    auto peer_node = peer_out_anchor->GetOwnerNode();
    if (peer_node == nullptr || peer_node->GetOpDesc() == nullptr) {
      continue;
    }
    if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
      continue;
    }

    const auto &input_desc = cur_op_desc->GetInputDesc(static_cast<uint32_t>(i));
    std::vector<std::pair<int64_t, int64_t>> value_range;
    (void)input_desc.GetValueRange(value_range);
    if (value_range.empty()) {
      // Fall back to the peer output desc: an upstream pass may have set the
      // range there without it being propagated to this input yet.
      int peer_out_idx = peer_out_anchor->GetIdx();
      auto peer_out_desc = peer_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
      if (peer_out_desc == nullptr) {
        // No reachable tensor desc means no value range to read.
        input_is_const_or_has_value_range = false;
        break;
      }
      (void)peer_out_desc->GetValueRange(value_range);
      if (value_range.empty()) {
        input_is_const_or_has_value_range = false;
        break;
      }
    }
  }
  return input_is_const_or_has_value_range;
}

// Fills `output_ptr` with one boundary (floor or ceiling) of tensor_desc's
// value range, cast to T. Fails when the range size does not match the shape.
template <typename T>
graphStatus InferValueRangePass::ConstructData(const GeTensorDesc &tensor_desc, bool use_floor_value,
                                               GeTensorPtr &output_ptr) {
  std::vector<std::pair<int64_t, int64_t>> value_range;
  (void)tensor_desc.GetValueRange(value_range);
  // One (lower, upper) pair is expected per tensor element.
  if (value_range.size() != static_cast<size_t>(tensor_desc.GetShape().GetShapeSize())) {
    REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
    GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", tensor_desc.GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }
auto value_range_data_num = value_range.size(); unique_ptr buf(new (std::nothrow) T[value_range_data_num]()); if (buf == nullptr) { REPORT_INNER_ERROR("E19999", "New buf failed"); GELOGE(MEMALLOC_FAILED, "new buf failed"); return GRAPH_FAILED; } for (auto j = 0; j < value_range_data_num; ++j) { auto value_range_j = use_floor_value ? value_range[j].first : value_range[j].second; buf[j] = static_cast(value_range_j); } if (output_ptr->SetData(reinterpret_cast(buf.get()), value_range_data_num * sizeof(T)) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "set data failed"); return GRAPH_FAILED; } return GRAPH_SUCCESS; } graphStatus InferValueRangePass::ConstructDataByType(const GeTensorDesc &tensor_desc, bool use_floor_value, GeTensorPtr &output_ptr) { graphStatus ret = GRAPH_SUCCESS; auto data_type = tensor_desc.GetDataType(); output_ptr->MutableTensorDesc().SetDataType(data_type); switch (data_type) { case DT_FLOAT: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_DOUBLE: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_UINT8: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_INT8: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_UINT16: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_INT16: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_INT32: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; case DT_INT64: ret = ConstructData(tensor_desc, use_floor_value, output_ptr); break; default: GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); ret = GRAPH_FAILED; } return ret; } vector InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) { vector input_tensors; auto cur_op_desc = node->GetOpDesc(); auto in_data_anchors = node->GetAllInDataAnchors(); for (auto i = 0; i < in_data_anchors.size(); ++i) { auto 
peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; } auto peer_node = peer_out_anchor->GetOwnerNode(); if (peer_node == nullptr) { continue; } // construct input tensor by constant node if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) { vector const_weight = OpDescUtils::MutableWeights(peer_node); if (const_weight.empty()) { REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(), peer_node->GetType().c_str()); GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(), peer_node->GetType().c_str()); return vector(); } // const/constant op has only one weight if (const_weight.at(0) == nullptr) { REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)", peer_node->GetName().c_str(), peer_node->GetType().c_str()); GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)", peer_node->GetName().c_str(), peer_node->GetType().c_str()); return vector(); } input_tensors.push_back(const_weight.at(0)); continue; } // construct input tensor by boundary of value range const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i); GeTensorPtr tmp_tensor_ptr = MakeShared(input_tensor_desc); if (tmp_tensor_ptr == nullptr) { REPORT_INNER_ERROR("E19999", "Make shared failed"); GELOGE(MEMALLOC_FAILED, "Make shared failed"); return vector(); } auto ret = ConstructDataByType(input_tensor_desc, use_floor_value, tmp_tensor_ptr); if (ret != GRAPH_SUCCESS) { REPORT_INNER_ERROR("E19999", "Input %s construct input tensor by boundary of value range failed.", input_tensor_desc.GetName().c_str()); GELOGE(GRAPH_PARAM_INVALID, "Input %s construct input tensor by boundary of value range failed.", input_tensor_desc.GetName().c_str()); return vector(); } input_tensors.push_back(tmp_tensor_ptr); } return input_tensors; } graphStatus 
InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) { auto inputs = ConstructInputTensors(node, true); if (inputs.empty()) { return GRAPH_PARAM_INVALID; } vector outputs_lower; auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower); if (ret != SUCCESS) { REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); return GRAPH_FAILED; } inputs = ConstructInputTensors(node, false); if (inputs.empty()) { return GRAPH_PARAM_INVALID; } vector outputs_higher; ret = RunCpuKernelForValueRange(node, inputs, outputs_higher); if (ret != SUCCESS) { REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str()); return GRAPH_FAILED; } // construct value range from output tensor OpDescPtr node_desc = node->GetOpDesc(); std::vector> output_tensor_value_range; size_t node_output_desc_size = node_desc->GetOutputsSize(); for (size_t i = 0; i < node_output_desc_size; ++i) { output_tensor_value_range.clear(); auto lower_tensor = outputs_lower[i]; auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize(); auto higher_tensor = outputs_higher[i]; auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize(); auto output_tensor_desc = node_desc->MutableOutputDesc(i); auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize(); if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) { GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str()); } auto data_type = output_tensor_desc->GetDataType(); switch (data_type) { GET_DATA_BY_DTYPE(DT_INT8, int8_t) GET_DATA_BY_DTYPE(DT_INT16, 
int16_t) GET_DATA_BY_DTYPE(DT_INT32, int32_t) GET_DATA_BY_DTYPE(DT_INT64, int64_t) GET_DATA_BY_DTYPE(DT_UINT8, uint8_t) GET_DATA_BY_DTYPE(DT_UINT16, uint16_t) GET_DATA_BY_DTYPE(DT_UINT32, uint32_t) GET_DATA_BY_DTYPE(DT_UINT64, uint64_t) GET_DATA_BY_DTYPE(DT_FLOAT, float) GET_DATA_BY_DTYPE(DT_DOUBLE, double) default: GELOGW("Data type:%s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str()); return GRAPH_FAILED; } output_tensor_desc->SetValueRange(output_tensor_value_range); } return GRAPH_SUCCESS; } template void InferValueRangePass::ConstructValueRange(const GeTensorPtr &left_tensor, const GeTensorPtr &right_tensor, std::vector> &value_range) { auto x = reinterpret_cast(left_tensor->GetData().GetData()); auto y = reinterpret_cast(right_tensor->GetData().GetData()); for (auto j = 0; j < left_tensor->GetTensorDesc().GetShape().GetShapeSize(); ++j) { auto left = static_cast(*(x + j)); auto right = static_cast(*(y + j)); value_range.emplace_back(std::make_pair(left, right)); } } } // namespace ge