/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file array_ops.cpp
 * \brief
 */
#include "array_ops.h"

// NOTE(review): two standard-header names were lost when this file was
// extracted ("#include #include" with no targets). <climits> (INT64_MAX /
// INT32_MAX below) and <utility> (std::pair / std::make_pair) are the
// best-guess restorations — confirm against the original source tree.
#include <climits>
#include <utility>

#include "./util/op_log.h"
#include "./util/common_shape_fns.h"
#include "./util/array_ops_shape_fns.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/utils/node_utils.h"
#include "./util/error_util.h"
#include "util/util.h"

namespace ge {
// Common attribute / error-report key strings shared by the infer funcs below.
const char* const kShape = "shape";
const char* const kShapeDtype = "shape dtype";
const char* const kAttrShape = "attr shape";
const char* const kAttrDtype = "attr dtype";
const char* const kAttrAxis = "attr axis";
const char* const kAttrNumAxes = "attr num_axes";
const char* const kPreOpInputShapeRange = "_pre_op_in_range";
const int64_t kMaxDimNum = 8;

// Infer func for Unique: x must be 1-D; output idx mirrors x's shape with the
// dtype taken from attr "out_idx"; output y is unknown-dim (continued below).
IMPLEMT_INFERFUNC(Unique, UniqueInfer) {
  OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op);
  GeTensorDescPtr x_input = op_desc->MutableInputDesc(0);
  GeShape x_shape;
  if (WithRank(x_input, 1, x_shape) != GRAPH_SUCCESS) {
    ShapeErrReport(0, op.GetName(), DebugString(x_input->GetShape().GetDims()), "1D");
    OP_LOGE(op.GetName().c_str(), "input x must be 1-D");
    return GRAPH_FAILED;
  }
  DataType idx_type;
  if (op.GetAttr("out_idx", idx_type) != GRAPH_SUCCESS) {
    OP_LOGE(op.GetName().c_str(), "Op get attr out_idx failed");
    return GRAPH_FAILED;
  }
  // Output 1 (idx) has the same shape as x and the requested index dtype.
  GeTensorDescPtr idx_desc = op_desc->MutableOutputDesc(1);
  idx_desc->SetShape(x_shape);
  idx_desc->SetOriginShape(x_shape);
  idx_desc->SetDataType(idx_type);
GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0); y_desc->SetShape(GeShape({UNKNOWN_DIM})); y_desc->SetOriginShape(GeShape({UNKNOWN_DIM})); y_desc->SetDataType(x_input->GetDataType()); if (x_shape.GetShapeSize() == UNKNOWN_DIM) { return GRAPH_SUCCESS; } else { std::vector> range; int64_t max_dim = x_shape.GetDim(0); range.emplace_back(std::make_pair(1, max_dim)); y_desc->SetShapeRange(range); return GRAPH_SUCCESS; } } INFER_FUNC_REG(Unique, UniqueInfer); IMPLEMT_INFERFUNC(Const, ConstInfer) { auto value = op.get_attr_value(); auto valDesc = value.GetTensorDesc(); auto dims = valDesc.GetShape().GetDims(); auto attrDtype = valDesc.GetDataType(); TensorDesc outDesc = op.get_output_desc_y(); outDesc.SetDataType(ge::DataType(attrDtype)); outDesc.SetShape(Shape(dims)); (void)op.update_output_desc_y(outDesc); return GRAPH_SUCCESS; } INFER_FUNC_REG(Const, ConstInfer); IMPLEMT_INFERFUNC(Constant, ConstantInfer) { auto value = op.get_attr_value(); auto valDesc = value.GetTensorDesc(); auto dims = valDesc.GetShape().GetDims(); auto attrDtype = valDesc.GetDataType(); TensorDesc outDesc = op.get_output_desc_y(); outDesc.SetDataType(ge::DataType(attrDtype)); outDesc.SetShape(Shape(dims)); (void)op.update_output_desc_y(outDesc); return GRAPH_SUCCESS; } INFER_FUNC_REG(Constant, ConstantInfer); graphStatus ConstAndConstantInferFormat(ge::Operator& op) { OP_LOGI(op.GetName().c_str(), "Const infer format start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto format = op_desc->MutableOutputDesc(0)->GetOriginFormat(); ConstGeTensorPtr tensor_value; if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { OP_LOGE(op.GetName().c_str(), "Get attr value failed!"); return GRAPH_FAILED; } if (!tensor_value) { OP_LOGE(op.GetName().c_str(), "attr tensor is not exist!"); return GRAPH_FAILED; } auto tensor_ptr = const_cast(tensor_value.get()); tensor_ptr->MutableTensorDesc().SetOriginFormat(format); tensor_ptr->MutableTensorDesc().SetFormat(format); return GRAPH_SUCCESS; 
}

IMPLEMT_INFERFORMAT_FUNC(Const, ConstInferFormat) {
  return ConstAndConstantInferFormat(op);
}

INFER_FORMAT_FUNC_REG(Const, ConstInferFormat);

// Infer func for Snapshot: identity — y copies x's shape, dtype and range.
IMPLEMT_INFERFUNC(Snapshot, SnapshotInferFunc) {
  OP_LOGI(op.GetName().c_str(), "Snapshot infershape start");
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto input_desc_x = op_desc->MutableInputDesc("x");
  auto output_desc_y = op_desc->MutableOutputDesc("y");
  auto x_dims = input_desc_x->MutableShape().GetDims();
  auto x_type = input_desc_x->GetDataType();
  std::vector<std::pair<int64_t, int64_t>> x_range;
  input_desc_x->GetShapeRange(x_range);
  output_desc_y->SetShape(GeShape(x_dims));
  output_desc_y->SetOriginShape(GeShape(x_dims));
  output_desc_y->SetShapeRange(x_range);
  output_desc_y->SetDataType(x_type);
  OP_LOGI(op.GetName().c_str(), "Snapshot infershape end");
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(Snapshot, SnapshotInferFunc);

// Infer func for GuaranteeConst: identity pass-through of the full input desc.
IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) {
  TensorDesc tensorDesc = op.GetInputDesc("x");
  (void)op.UpdateOutputDesc("y", tensorDesc);
  return GRAPH_SUCCESS;
}

INFER_FUNC_REG(GuaranteeConst, GuaranteeConstInfer);

// Infer func for BroadcastArgs: both inputs are int32/int64 shape vectors of
// rank <= 1; y is the broadcast shape vector (continued on following lines).
IMPLEMT_INFERFUNC(BroadcastArgs, BroadcastArgsInferFunc) {
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
  auto x1_desc = op_desc->MutableInputDesc("x1");
  auto x2_desc = op_desc->MutableInputDesc("x2");
  auto y_desc = op_desc->MutableOutputDesc("y");
  auto x1_dims = x1_desc->GetShape().GetDims();
  auto x2_dims = x2_desc->GetShape().GetDims();
  auto data_type = x1_desc->GetDataType();
  std::vector<std::pair<int64_t, int64_t>> x1_range;
  std::vector<std::pair<int64_t, int64_t>> x2_range;
  std::vector<std::pair<int64_t, int64_t>> out_range;
  x1_desc->GetShapeRange(x1_range);
  x2_desc->GetShapeRange(x2_range);
  bool data_type_check = ((x1_desc->GetDataType() != DT_INT32 && x1_desc->GetDataType() != DT_INT64) ||
                          (x2_desc->GetDataType() != DT_INT32 && x2_desc->GetDataType() != DT_INT64));
  if (data_type_check) {
    // Fixed: the second operand of the message previously reported x1's dtype
    // twice; it now reports x2's dtype, matching the log statement below.
    string reason = "x1[" + std::to_string(x1_desc->GetDataType()) + "] + and + x2[" +
                    std::to_string(x2_desc->GetDataType()) + "] must DT_INT32 or DT_INT64";
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); GE_OP_LOGE(op.GetName().c_str(), "Data type check fail. x1[%u] and x2[%u] must DT_INT32 or DT_INT64", x1_desc->GetDataType(), x2_desc->GetDataType()); return GRAPH_PARAM_INVALID; } if (x1_dims.size() > 1 || x2_dims.size() > 1) { string reason = "x1[" + std::to_string(x1_dims.size()) + "] + and + x2[" + std::to_string(x2_dims.size()) + "] must be less than or equal to 1"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dims", reason); GE_OP_LOGE(op.GetName().c_str(), "Size check fail. x1[%u] and x2[%u] must be less than or equal to 1", x1_dims.size(), x2_dims.size()); return GRAPH_PARAM_INVALID; } if (x1_dims == UNKNOWN_RANK || x2_dims == UNKNOWN_RANK) { GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!"); y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetDataType(data_type); return GRAPH_SUCCESS; } if (x1_dims == UNKNOWN_SHAPE && x2_dims == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!"); y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetDataType(data_type); y_desc->SetShapeRange(x1_range); return GRAPH_SUCCESS; } else if (x1_dims == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!"); int64_t range_max = x2_dims.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetDataType(data_type); y_desc->SetShapeRange(out_range); return GRAPH_SUCCESS; } else if (x2_dims == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!"); int64_t range_max = x2_dims.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetDataType(data_type); y_desc->SetShapeRange(out_range); return 
GRAPH_SUCCESS; } if (x1_dims.empty()) { y_desc->SetShape(GeShape(x2_dims)); } else if (x2_dims.empty()) { y_desc->SetShape(GeShape(x1_dims)); } else { auto dims = x1_dims[0] > x2_dims[0] ? x1_dims : x2_dims; y_desc->SetShape(GeShape(dims)); } int64_t range_max = x1_dims.size() > x2_dims.size() ? x1_dims.size() : x2_dims.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); y_desc->SetShapeRange(out_range); y_desc->SetDataType(x1_desc->GetDataType()); return GRAPH_SUCCESS; } INFER_FUNC_REG(BroadcastArgs, BroadcastArgsInferFunc); IMPLEMT_INFERFUNC(BroadcastGradientArgs, BroadcastGradientArgsInfer) { auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto input_desc_x1 = op_desc->MutableInputDesc("x1"); auto input_desc_x2 = op_desc->MutableInputDesc("x2"); auto output_desc_y1 = op_desc->MutableOutputDesc("y1"); auto output_desc_y2 = op_desc->MutableOutputDesc("y2"); auto dims_x1 = input_desc_x1->MutableShape().GetDims(); auto dims_x2 = input_desc_x2->MutableShape().GetDims(); auto x1_type = input_desc_x1->GetDataType(); auto x2_type = input_desc_x2->GetDataType(); std::vector> x1_range; std::vector> x2_range; std::vector> out_range; input_desc_x1->GetShapeRange(x1_range); input_desc_x2->GetShapeRange(x2_range); if (dims_x1 == UNKNOWN_RANK || dims_x2 == UNKNOWN_RANK) { GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!"); output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetDataType(x1_type); output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetDataType(x2_type); return GRAPH_SUCCESS; } // Input Dim Num must be equal or smaller than 1 if (dims_x1 == UNKNOWN_SHAPE && dims_x2 == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!"); output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); 
output_desc_y1->SetDataType(x1_type); output_desc_y1->SetShapeRange(x1_range); output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetDataType(x2_type); output_desc_y2->SetShapeRange(x2_range); return GRAPH_SUCCESS; } else if (dims_x1 == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!"); int64_t range_max = dims_x2.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetDataType(x1_type); output_desc_y1->SetShapeRange(out_range); output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetDataType(x2_type); output_desc_y2->SetShapeRange(out_range); return GRAPH_SUCCESS; } else if (dims_x2 == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!"); int64_t range_max = dims_x1.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetDataType(x1_type); output_desc_y1->SetShapeRange(out_range); output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetDataType(x2_type); output_desc_y2->SetShapeRange(out_range); return GRAPH_SUCCESS; } GE_OP_LOGD(op.GetName().c_str(), "all two inputs are known shape!"); int64_t range_max = dims_x1.size() == 0 ? 
1 : dims_x1.size(); std::pair pair({1, range_max}); out_range.emplace_back(pair); output_desc_y1->SetDataType(x1_type); output_desc_y2->SetDataType(x2_type); output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y1->SetShapeRange(out_range); output_desc_y2->SetShapeRange(out_range); return GRAPH_SUCCESS; } INFER_FUNC_REG(BroadcastGradientArgs, BroadcastGradientArgsInfer); IMPLEMT_INFERFUNC(PreventGradient, PreventGradientInferFunc) { OP_LOGI(op.GetName().c_str(), "PreventGradient infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto input_desc_x = op_desc->MutableInputDesc("x"); auto output_desc_y = op_desc->MutableOutputDesc("y"); auto x_dims = input_desc_x->MutableShape().GetDims(); auto x_type = input_desc_x->GetDataType(); std::vector> x_range; input_desc_x->GetShapeRange(x_range); output_desc_y->SetShape(GeShape(x_dims)); output_desc_y->SetOriginShape(GeShape(x_dims)); output_desc_y->SetShapeRange(x_range); output_desc_y->SetDataType(x_type); OP_LOGI(op.GetName().c_str(), "PreventGradient infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(PreventGradient, PreventGradientInferFunc); IMPLEMT_INFERFUNC(StopGradient, StopGradientInferFunc) { OP_LOGI(op.GetName().c_str(), "StopGradient infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto input_desc_x = op_desc->MutableInputDesc("x"); auto output_desc_y = op_desc->MutableOutputDesc("y"); auto x_dims = input_desc_x->MutableShape().GetDims(); auto x_type = input_desc_x->GetDataType(); std::vector> x_range; input_desc_x->GetShapeRange(x_range); output_desc_y->SetShape(GeShape(x_dims)); output_desc_y->SetOriginShape(GeShape(x_dims)); output_desc_y->SetShapeRange(x_range); output_desc_y->SetShapeRange(x_range); output_desc_y->SetDataType(x_type); OP_LOGI(op.GetName().c_str(), "StopGradient 
infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(StopGradient, StopGradientInferFunc); IMPLEMT_INFERFUNC(ExpandDims, ExpandDimsInfer) { std::vector dep_inputs = {"axis"}; auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto node = NodeUtils::GetNodeFromOperator(op); if (node == nullptr) { GE_OP_LOGE(op.GetName().c_str(), "get null node ptr"); return GRAPH_FAILED; } auto x_desc = op_desc->MutableInputDesc("x"); auto axis_desc = op_desc->MutableInputDesc("axis"); auto y_desc = op_desc->MutableOutputDesc("y"); op_desc->SetOpInferDepends(dep_inputs); auto axis_type = axis_desc->GetDataType(); auto x_type = x_desc->GetDataType(); if (axis_type != DT_INT32 && axis_type != DT_INT64) { string reason = "axis dtype[" + std::to_string(axis_type) + "] must int32 or int64"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrDtype, reason); GE_OP_LOGE(op.GetName().c_str(), "axis dtype[%d] must int32 or int64", axis_type); return GRAPH_PARAM_INVALID; } bool is_x_unknonwn_rank = x_desc->MutableShape().GetDims() == UNKNOWN_RANK ? true : false; if (is_x_unknonwn_rank) { GE_OP_LOGD("input x shape is unknown rank!"); y_desc->SetUnknownDimNumShape(); y_desc->SetDataType(x_type); y_desc->SetOriginDataType(x_type); return GRAPH_SUCCESS; } int64_t axis_nums = axis_desc->MutableShape().GetShapeSize(); if (axis_nums != 1) { // Shape::GetDims().size() == 0, means it's a scalar, its shape is []. 
if (!(axis_nums == 0 && axis_desc->MutableShape().GetDims().size() == 0)) { string reason = "axis input must be a tensor with a single value, but [" + std::to_string(axis_nums) + "] nums"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason); GE_OP_LOGE(op.GetName().c_str(), "'axis' input must be a tensor with a single value, but %d nums", axis_nums); return GRAPH_PARAM_INVALID; } } GeTensorPtr tensor_axis = nullptr; graphStatus status = NodeUtils::GetInputConstData(node, "axis", tensor_axis); if (status != GRAPH_SUCCESS) { GE_OP_LOGI(op.GetName().c_str(), "Op get input const data of axis failed"); auto x_shape_size = x_desc->MutableShape().GetDims().size(); std::vector out_dims(x_shape_size + 1, UNKNOWN_DIM); y_desc->SetShape(GeShape(out_dims)); y_desc->SetOriginShape(GeShape(out_dims)); y_desc->SetDataType(x_type); y_desc->SetOriginDataType(x_type); // infer shape range std::vector> x_range; (void)x_desc->GetShapeRange(x_range); if (x_range.empty()) { GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range!"); return GRAPH_SUCCESS; } if (x_range.size() != x_shape_size) { GE_OP_LOGE(op.GetName().c_str(), "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size); return GRAPH_FAILED; } int64_t max_range_value = 1; for (const auto &ele : x_range) { if (ele.second > max_range_value) { max_range_value = ele.second; } } std::vector> y_range(x_shape_size + 1, std::pair({1, max_range_value})); y_desc->SetShapeRange(y_range); return GRAPH_SUCCESS; } auto pbuff = tensor_axis->GetData().GetData(); if (pbuff == nullptr) { GE_OP_LOGE(op.GetName().c_str(), "no const data when get data from tensor!"); return GRAPH_FAILED; } int64_t axis; if (axis_type == DT_INT32) { axis = *const_cast(reinterpret_cast(pbuff)); } else if (axis_type == DT_INT64) { axis = *const_cast(reinterpret_cast(pbuff)); } std::vector vec_dim; int32_t dim_num = x_desc->MutableShape().GetDimNum(); if (axis < -1 - dim_num || axis > dim_num) { 
string reason = "axis[" + std::to_string(axis) + "] is not in [" + std::to_string(-1 - dim_num) + " , " + std::to_string(dim_num) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason); GE_OP_LOGE(op.GetName().c_str(), "axis[%d] is not in [%d, %d]", axis, -1 - dim_num, dim_num); return GRAPH_PARAM_INVALID; } if (axis < 0) { axis += dim_num + 1; } for (int i = 0; i < dim_num; i++) { vec_dim.push_back(x_desc->MutableShape().GetDim(i)); } vec_dim.emplace(vec_dim.begin() + axis, 1); y_desc->SetShape(GeShape(vec_dim)); y_desc->SetOriginShape(GeShape(vec_dim)); y_desc->SetDataType(x_type); y_desc->SetOriginDataType(x_type); // infer shape range auto x_shape_size = x_desc->MutableShape().GetDims().size(); std::vector> x_range; (void)x_desc->GetShapeRange(x_range); if (x_range.empty()) { GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range, so break!"); return GRAPH_SUCCESS; } if (x_range.size() != x_shape_size) { GE_OP_LOGE(op.GetName().c_str(), "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size); return GRAPH_FAILED; } x_range.emplace(x_range.begin() + axis, std::pair{1, 1}); y_desc->SetShapeRange(x_range); return GRAPH_SUCCESS; } INFER_FUNC_REG(ExpandDims, ExpandDimsInfer); template static graphStatus ValidateShape(const GeTensorPtr& tenosr, int64_t& product, int& unknow_index, GeShape& output, Operator& op) { int64_t dim_num = tenosr->MutableTensorDesc().MutableShape().GetDim(0); T* shape_data = const_cast(reinterpret_cast(tenosr->GetData().GetData())); std::vector out_dims = output.GetDims(); if (shape_data == nullptr) { GE_OP_LOGE(op.GetName().c_str(), "truth shape data is invalid"); return GRAPH_PARAM_INVALID; } for (int64_t i = 0; i < dim_num; i++) { if (shape_data[i] == -1) { if (unknow_index != -1) { string reason = "only one dim may be -1, not both dim[ " + std::to_string(unknow_index) + "] and dim[" + std::to_string(i) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), 
kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%lld] and dim[%lld]", unknow_index, i); return GRAPH_PARAM_INVALID; } unknow_index = i; out_dims.push_back(1); } else if (shape_data[i] < 0) { string reason = "Size[" + std::to_string(i) + "] must be non-negative"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Size[%lld] must be non-negative", i); return GRAPH_PARAM_INVALID; } else { if (shape_data[i] != 0 && product > (INT64_MAX / shape_data[i])) { string reason = "Mul overflow of int64, product[" + std::to_string(product) + "] shape_data[" + std::to_string((int64_t)shape_data[i]) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Mul overflow of int64, product[%lld] shape_data[%lld]", product, (int64_t)shape_data[i]); return GRAPH_PARAM_INVALID; } out_dims.push_back(shape_data[i]); product *= shape_data[i]; } } output = GeShape(out_dims); return GRAPH_SUCCESS; } static graphStatus CaffeReshapeInferShape(const vector& dims, const int64_t& axis, const int64_t& num_axes, Operator& op) { GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto x_desc = op_desc->MutableInputDesc("x"); auto shape_desc = op_desc->MutableInputDesc("shape"); auto y_desc = op_desc->MutableOutputDesc("y"); auto x_dims = x_desc->GetShape().GetDims(); auto data_type = x_desc->GetDataType(); if (x_dims == UNKNOWN_RANK || dims == UNKNOWN_RANK) { GE_OP_LOGD("Input data is unknown_rank"); y_desc->SetShape(GeShape(UNKNOWN_RANK)); y_desc->SetOriginShape(GeShape(UNKNOWN_RANK)); y_desc->SetDataType(data_type); return GRAPH_SUCCESS; } if (x_dims == UNKNOWN_SHAPE) { GE_OP_LOGD("Input data is unknown_shape."); y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); y_desc->SetDataType(data_type); return GRAPH_SUCCESS; } int64_t inferred_axis = -1; int64_t 
constant_count = 1; vector copy_axes; // parsing dims for (size_t i = 0; i < dims.size(); ++i) { const int64_t shape_dim_i = dims[i]; if (shape_dim_i == 0) { copy_axes.push_back(i); } else if (shape_dim_i == -1) { if (inferred_axis != -1) { string reason = "only one dim may be -1, not both dim[ " + std::to_string(inferred_axis) + "] and dim[" + std::to_string(i) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%ld] and dim[%zu]", inferred_axis, i); return GRAPH_PARAM_INVALID; } inferred_axis = i; } else { constant_count *= shape_dim_i; } } // parsing start axis and end axis Shape bottom_shape = op.GetInputDesc("x").GetShape(); const int64_t bottom_shape_size = bottom_shape.GetDims().size(); int64_t start_axis = 0; if (axis >= 0) { start_axis = axis; } else { start_axis = axis + bottom_shape_size + 1; } if (start_axis < 0 || start_axis > bottom_shape_size) { int64_t range = -1 - bottom_shape_size; // if axis >=0 , axis range [0, bottom_shape_size], else axis < 0, axis range [-1 - bottom_shape_size, -1] // axis range [-1 - bottom_shape_size, bottom_shape_size] string reason = "axis's range is not in [" + std::to_string(range) + ", " + std::to_string(bottom_shape_size) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); GE_OP_LOGE(op.GetName().c_str(), "reshape param axis is invalid, axis's range is not in [%ld, %ld]", range, bottom_shape_size); return GRAPH_PARAM_INVALID; } int64_t end_axis = 0; if (num_axes < -1) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrNumAxes, "it must be greater than or equal to -1"); GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be greater than or equal to -1"); return GRAPH_PARAM_INVALID; } else if (num_axes == -1) { end_axis = bottom_shape_size; } else { end_axis = start_axis + num_axes; } if (end_axis > bottom_shape_size) { GeInfershapeErrReport(op.GetName(), 
op.GetOpType(), kAttrNumAxes, "num_axes must be less than or equal to " + std::to_string((bottom_shape_size - start_axis))); GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be less than or equal to %ld", bottom_shape_size - start_axis); return GRAPH_PARAM_INVALID; } // construct top shape vector bottom_dims = bottom_shape.GetDims(); const int64_t num_axes_replaced = end_axis - start_axis; const int64_t num_axes_retained = bottom_shape_size - num_axes_replaced; const int64_t num_new_axes = dims.size(); vector top_shape(num_axes_retained + num_new_axes); size_t top_shape_index = 0; for (int64_t i = 0; i < start_axis; ++i) { top_shape[top_shape_index] = bottom_dims[i]; top_shape_index++; } for (int64_t i = 0; i < num_new_axes; ++i) { top_shape[top_shape_index] = dims[i]; top_shape_index++; } for (int64_t i = end_axis; i < bottom_shape_size; ++i) { top_shape[top_shape_index] = bottom_dims[i]; top_shape_index++; } if (top_shape_index != top_shape.size()) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), "infer shape size", "top_shape_index not equal to top_shape size"); GE_OP_LOGE(op.GetName().c_str(), "reshape infer shape faied, top_shape_index not equal to top_shape size"); return GRAPH_FAILED; } // product of [0,start_axis) + [end_axis, bottom_shape_size) int64_t explicit_count = constant_count; int64_t bottom_count_all = 1; for (int i = 0; i < bottom_shape_size; ++i) { bottom_count_all *= bottom_dims[i]; if (i < start_axis || i >= end_axis) { explicit_count *= bottom_dims[i]; } } // parsing dim 0 and -1 for (size_t i = 0; i < copy_axes.size(); ++i) { const int64_t copy_axis_index = copy_axes[i]; if ((start_axis + copy_axis_index) >= bottom_shape_size) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, "there was no corresponding bottom axis for dim 0"); GE_OP_LOGE(op.GetName().c_str(), "there was no corresponding bottom axis for dim 0."); return GRAPH_FAILED; } top_shape[start_axis + copy_axis_index] = 
bottom_dims[start_axis + copy_axis_index]; explicit_count *= bottom_dims[start_axis + copy_axis_index]; } if (inferred_axis >= 0) { if (bottom_count_all % explicit_count != 0) { string reason = "The shape of the input cannot be divisible by the product " "of the specified dimensions, the product is [" + std::to_string(explicit_count) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); GE_OP_LOGE( op.GetName().c_str(), "The shape of the input cannot be divisible by the product of the specified dimensions, the product is %ld", explicit_count); return GRAPH_FAILED; } const int64_t inferred_dim = bottom_count_all / explicit_count; top_shape[start_axis + inferred_axis] = inferred_dim; } int64_t top_count_all = 1; for (size_t i = 0; i < top_shape.size(); ++i) { top_count_all *= top_shape[i]; } if (top_count_all != bottom_count_all) { string reason = "output tensor count [ " + std::to_string(top_count_all) + "] does not match input tensor count [" + std::to_string(bottom_count_all) + "]."; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); GE_OP_LOGE(op.GetName().c_str(), "output tensor count %lld does not match input tensor count %ld.", top_count_all, bottom_count_all); return GRAPH_FAILED; } // updata output shape info TensorDesc td = op.GetOutputDesc("y"); td.SetShape(Shape(top_shape)); td.SetDataType(op.GetInputDesc("x").GetDataType()); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; } bool IsEmptyTensor(GeTensorDescPtr tensor_desc) { bool is_empty = false; for (const auto &dim : tensor_desc->MutableShape().GetDims()) { if (dim == 0) { is_empty = true; break; } } return is_empty; } template graphStatus GetOutShapeFromTensor(OpDescPtr op_desc, GeTensorPtr tensor, std::vector &v_out) { auto shape_desc = tensor->MutableTensorDesc(); T* shape_data = const_cast(reinterpret_cast(tensor->GetData().GetData())); if (shape_data == nullptr) { GE_OP_LOGE(op_desc->GetName().c_str(), "const shape data is invalid"); return 
GRAPH_PARAM_INVALID; } for (int i = 0; i < shape_desc.MutableShape().GetDim(0); i++) { v_out.emplace_back(shape_data[i]); } return GRAPH_SUCCESS; } graphStatus EmptyTensorProcess(const Operator &op, const GeTensorDesc &x_desc, const GeTensorPtr &shape_tensor, GeTensorDesc &out_desc) { GE_OP_LOGD("Start empty-tensor preprocess!"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto shape_type = op_desc->MutableInputDesc("shape")->GetDataType(); std::vector shape_shape; graphStatus ret = GRAPH_SUCCESS; if (shape_type == DT_INT32) { ret = GetOutShapeFromTensor(op_desc, shape_tensor, shape_shape); } else if (shape_type == DT_INT64) { ret = GetOutShapeFromTensor(op_desc, shape_tensor, shape_shape); } else { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype, "Dim type must be DT_INT32 or DT_INT64."); GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or DT_INT64."); return GRAPH_PARAM_INVALID; } if (ret != GRAPH_SUCCESS) { return ret; } GE_OP_LOGD(op.GetName().c_str(), "x shape: %s shape shape: %s", x_desc.GetShape().ToString().c_str(), GeShape(shape_shape).ToString().c_str()); int64_t num_of_neg_1 = 0; int64_t product = 1; for (auto &dim : shape_shape) { if (dim == -1) { // -1 stand for highest dim here num_of_neg_1++; dim = 0; } product *= dim; } // check valid if ((num_of_neg_1 == 0 && product == 0) || (num_of_neg_1 == 1)) { out_desc.SetShape(GeShape(shape_shape)); out_desc.SetOriginShape(GeShape(shape_shape)); out_desc.SetDataType(x_desc.GetDataType()); out_desc.SetOriginDataType(x_desc.GetDataType()); return GRAPH_SUCCESS; } GE_OP_LOGE(op.GetName().c_str(), "Param is invalid!.Please check!Input shape contains -1 num is %ld, product is %ld", num_of_neg_1, product); return GRAPH_FAILED; } IMPLEMT_INFERFUNC(Reshape, ReshapeInfer) { bool zero_flag = false; vector attr_dims; if (op.GetAttr("shape", attr_dims) == GRAPH_SUCCESS) { for (size_t i = 0; i < attr_dims.size(); ++i) { if (attr_dims[i] == 0) { zero_flag = true; break; } } } 
std::vector dep_inputs = {"shape"}; auto op_desc = OpDescUtils::GetOpDescFromOperator(op); op_desc->SetOpInferDepends(dep_inputs); auto x_desc = op_desc->MutableInputDesc("x"); auto y_desc = op_desc->MutableOutputDesc("y"); int64_t attr_axis = 0; op.GetAttr("axis", attr_axis); int64_t attr_num_axes = -1; op.GetAttr("num_axes", attr_num_axes); if (attr_axis != 0 || attr_num_axes != -1 || zero_flag) { GE_OP_LOGI(op.GetName().c_str(), "Get reshape_param successfully, shape size is %u, axis is %ld, num_axes is %ld", attr_dims.size(), attr_axis, attr_num_axes); graphStatus caffe_reshape_ret = CaffeReshapeInferShape(attr_dims, attr_axis, attr_num_axes, op); return caffe_reshape_ret; } GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start"); GeTensorPtr tensor = nullptr; auto node = NodeUtils::GetNodeFromOperator(op); if (node == nullptr) { OP_LOGE(op.GetName().c_str(), "get null node ptr!"); return GRAPH_PARAM_INVALID; } graphStatus state = NodeUtils::GetInputConstData(node, "shape", tensor); if (state != GRAPH_SUCCESS) { GE_OP_LOGW(op.GetName().c_str(), "Op get input const data of shape failed"); auto input_shape = op_desc->MutableInputDesc("x")->MutableShape(); auto shape_input_desc = op_desc->MutableInputDesc("shape"); auto shape_shape = shape_input_desc->MutableShape(); // because shape's value stand for output shape, so it should be smaller than 1 dim auto shape_rank = shape_shape.GetDims().size(); if (shape_rank > 1) { string reason = "shape dim[" + std::to_string(shape_shape.GetDims().size()) + "] should be smaller or equal than 1"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "shape dim[%zu] should be smaller or equal than 1", shape_shape.GetDims().size()); return GRAPH_PARAM_INVALID; } if (shape_shape.GetDims() != UNKNOWN_RANK && shape_shape.GetDims() != UNKNOWN_SHAPE) { auto x_type = op_desc->MutableInputDesc("x")->GetDataType(); auto td = op_desc->MutableOutputDesc("y"); int64_t rank = (shape_rank 
== 0) ? 0 : shape_shape.GetDims().at(0); td->SetShape(GeShape(std::vector(rank, UNKNOWN_DIM))); td->SetOriginShape(GeShape(std::vector(rank, UNKNOWN_DIM))); td->SetDataType(x_type); // calc shape range if (input_shape.GetDims() == UNKNOWN_RANK) { GE_OP_LOGD("input x is unknown rank!no way to set shape range!"); return GRAPH_SUCCESS; } auto input_shape_size = input_shape.GetShapeSize(); int64_t range_max = 1; if (input_shape_size <= 0) { // unknown dim , by input shape range calc output range std::vector> x_range; (void)op_desc->MutableInputDesc("x")->GetShapeRange(x_range); if (x_range.empty()) { return GRAPH_SUCCESS; } ge::array_ops::ReshapeRangeInfer(op, x_range, range_max); } else { // known dim, shape size as range_max range_max = input_shape_size; } range_max = (range_max > INT32_MAX) ? INT32_MAX : range_max; std::vector> y_range(rank, {1, range_max}); td->SetShapeRange(y_range); return GRAPH_SUCCESS; } auto x_type = op_desc->MutableInputDesc("x")->GetDataType(); auto td = op_desc->MutableOutputDesc("y"); td->SetShape(GeShape({-2})); td->SetOriginShape(GeShape({-2})); td->SetDataType(x_type); return GRAPH_SUCCESS; } if (IsEmptyTensor(x_desc)) { return EmptyTensorProcess(op, *x_desc, tensor, *y_desc); } std::vector> x_range; std::vector> y_range; op_desc->MutableInputDesc("x")->GetShapeRange(x_range); int64_t product = 1; int unknow_index = -1; GeShape output_shape; DataType shape_type = op_desc->MutableInputDesc("shape")->GetDataType(); int64_t shape_size = op_desc->MutableInputDesc("shape")->MutableShape().GetShapeSize(); graphStatus ret = GRAPH_SUCCESS; if (shape_type == DT_INT32) { ret = ValidateShape(tensor, product, unknow_index, output_shape, op); } else if (shape_type == DT_INT64) { ret = ValidateShape(tensor, product, unknow_index, output_shape, op); } else if (shape_size > 0) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype, "Dim type must be DT_INT32 or DT_INT64."); GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or 
DT_INT64."); return GRAPH_PARAM_INVALID; } if (ret != GRAPH_SUCCESS) { GE_OP_LOGE(op.GetName().c_str(), "ValidateShape failed, ret: %d", ret); return ret; } auto input_shape = op_desc->MutableInputDesc("x")->MutableShape(); int64_t input_size = input_shape.GetShapeSize(); // If input tensor is scalar,then input_size will return 0, assign to 1, which means convert scalar to vector. if (input_size == 0 && output_shape.GetShapeSize() == 1) { input_size = 1; } if (unknow_index != -1) { if (product <= 0) { GE_OP_LOGE(op.GetName().c_str(), "Reshape Op can't infer an empty tensor"); return GRAPH_PARAM_INVALID; } if (input_shape.GetShapeSize() < 0) { GE_OP_LOGI("input x and input shape is all unknown!"); auto td = op_desc->MutableOutputDesc("y"); output_shape.SetDim(unknow_index, -1); td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType()); td->SetShape(output_shape); td->SetOriginShape(output_shape); td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType()); auto max_input_dims = 1; // If last op does not set shape range ,do not set shape range if (x_range.empty()) { GE_OP_LOGI(op.GetName().c_str(), "input x doesnot have shape range!"); } else { // If last op have already set shape range, try best to infer shape range ge::array_ops::ReshapeRangeInfer(op, x_range, y_range, output_shape); } td->SetShapeRange(y_range); return GRAPH_SUCCESS; } int64_t missing = input_size / product; if (product * missing != input_size) { string reason = "The shape of the input cannot be divisible from [" + std::to_string(product) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "The shape of the input cannot be divisible from %lld", product); return GRAPH_PARAM_INVALID; } output_shape.SetDim(unknow_index, missing); } auto dims = input_shape.GetDims(); bool is_exist_unknown_shape = false; for (auto ele : dims) { is_exist_unknown_shape = (ele == -1) ? 
true : false; if (!is_exist_unknown_shape) { continue; } } if (SetScalarOutputDesc(string("x"), string("y"), op_desc, output_shape)) { return GRAPH_SUCCESS; } // Shape_size is 0, means shape tensor value is [], implying convert vector/scalar to scalar bool convert_to_scalar = (shape_size == 0 && (input_size == 1 || (input_size == 0 && input_shape.GetDims().size() == 0))); // Output_shape.GetShapeSize() > 0 and input_size <= 0 for dynamic shape bool shape_check_ok = ((input_size == output_shape.GetShapeSize()) || ((output_shape.GetShapeSize() > 0) && (input_size <= 0)) || (is_exist_unknown_shape && (output_shape.GetShapeSize() > 0))); if (!shape_check_ok && !convert_to_scalar) { string reason = "Shape size is [" + std::to_string(shape_size) + "], input tensor with [" + std::to_string(input_size) + "] values, is input dynamic shape [" + std::to_string(is_exist_unknown_shape) + "], but requested shape has [" + std::to_string(output_shape.GetShapeSize()) + "] values"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Shape size is %lld, input tensor with %lld values, is input dynamic shape :%d, but \ requested shape has %lld values", shape_size, input_size, is_exist_unknown_shape, output_shape.GetShapeSize()); return GRAPH_PARAM_INVALID; } auto td = op_desc->MutableOutputDesc("y"); td->SetShape(output_shape); td->SetOriginShape(output_shape); td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType()); td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType()); return GRAPH_SUCCESS; } INFER_FUNC_REG(Reshape, ReshapeInfer); IMPLEMT_INFERFORMAT_FUNC(Reshape, ReshapeInferFormat) { GE_OP_LOGI(op.GetName().c_str(), "Reshape infer format start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto input_descs = op_desc->GetAllInputsDescPtr(); auto output_descs = op_desc->GetAllOutputsDescPtr(); for (const auto& input_desc : input_descs) { if (input_desc->GetShape().GetDimNum() < 4) { 
input_desc->SetOriginFormat(FORMAT_ND); input_desc->SetFormat(FORMAT_ND); } } for (const auto& output_desc : output_descs) { if (output_desc->GetShape().GetDimNum() < 4) { output_desc->SetOriginFormat(FORMAT_ND); output_desc->SetFormat(FORMAT_ND); } } (void)op_desc->DefaultInferFormat(); for (const auto& input_desc : input_descs) { if (input_desc->GetShape().GetDimNum() < 4) { input_desc->SetOriginFormat(FORMAT_ND); input_desc->SetFormat(FORMAT_ND); } } for (const auto& output_desc : output_descs) { if (output_desc->GetShape().GetDimNum() < 4) { output_desc->SetOriginFormat(FORMAT_ND); output_desc->SetFormat(FORMAT_ND); } } return GRAPH_SUCCESS; } INFER_FORMAT_FUNC_REG(Reshape, ReshapeInferFormat); IMPLEMT_VERIFIER(Squeeze, SqueezeVerify) { GE_OP_LOGD("Enter SqueezeVerify"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto axis = op.get_attr_axis(); auto input_desc_x = op_desc->MutableInputDesc("x"); auto xShape = input_desc_x->MutableShape().GetDims(); std::vector> x_range; input_desc_x->GetShapeRange(x_range); if ((xShape != UNKNOWN_RANK) && (!x_range.empty()) && (x_range.size() != xShape.size())) { // if it has set shape range, it should be same with input dim num GE_OP_LOGE("x_shape_range num [%zu] does not match x dims_num [%zu]", x_range.size(), xShape.size()); return GRAPH_FAILED; } auto node = NodeUtils::GetNodeFromOperator(op); if (node == nullptr) { GE_OP_LOGE("node pointer is nullptr"); return GRAPH_FAILED; } bool is_unknow = false; auto status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknow); if (status != GRAPH_SUCCESS) { GE_OP_LOGE("Get node unknown shape status failed!"); return GRAPH_FAILED; } if (is_unknow) { // when input is unknown , no way to check param "axis" whether valid. 
Do check when running return GRAPH_SUCCESS; } if (axis.size() > 0) { for (unsigned i = 0; i < axis.size(); i++) { if (axis[i] < 0) axis[i] += xShape.size(); bool flag = (0 <= axis[i]) && (axis[i] < static_cast(xShape.size())); if (!flag) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, "axis value is out of range of [-rank(input), rank(input))."); GE_OP_LOGE(op.GetName().c_str(), "axis value is out of range of [-rank(input), rank(input))."); return GRAPH_FAILED; } if (!(xShape[axis[i]] == 1)) { GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, "input shape has dim not equal to 1."); GE_OP_LOGE(op.GetName().c_str(), "input shape has dim not equal to 1."); return GRAPH_FAILED; } } } GE_OP_LOGD("SqueezeVerify Success!"); return GRAPH_SUCCESS; } VERIFY_FUNC_REG(Squeeze, SqueezeVerify); IMPLEMT_INFERFUNC(Squeeze, SqueezeInfer) { GE_OP_LOGD("Enter Squeeze Infershape!"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto axis = op.get_attr_axis(); auto input_desc_x = op_desc->MutableInputDesc("x"); auto output_desc_y = op_desc->MutableOutputDesc("y"); auto input_shape = input_desc_x->MutableShape(); int64_t dim_size = input_shape.GetDimNum(); auto x_data_type = input_desc_x->GetDataType(); int32_t axis_num = axis.size(); // process -2(UnknownRank) if (input_shape.GetDims() == UNKNOWN_RANK) { GE_OP_LOGD("Input x shape is -2!"); output_desc_y->SetShape(GeShape(UNKNOWN_RANK)); output_desc_y->SetOriginShape(GeShape(UNKNOWN_RANK)); output_desc_y->SetDataType(x_data_type); return GRAPH_SUCCESS; } std::vector> x_range; std::vector> y_range; input_desc_x->GetShapeRange(x_range); std::unordered_set squeeze_dims; for (int32_t i = 0; i < axis_num; ++i) { int32_t dim = axis[i]; if (dim < -dim_size || dim >= dim_size) { string reason = "Tried to squeeze dim index[" + std::to_string(dim) + "] for tensor with [" + std::to_string(dim_size) + "] dimensions"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); 
GE_OP_LOGE(op.GetName().c_str(), "Tried to squeeze dim index[%d] for tensor with [%lld] dimensions", dim, dim_size); return GRAPH_FAILED; } if (dim < 0) { dim = dim_size + dim; } squeeze_dims.insert(dim); } vector out_shape; for (int i = 0; i < dim_size; i++) { auto exist_dim = input_shape.GetDim(i); // If squeeze_set is non-empty, only squeeze those dimensions. if (!squeeze_dims.empty()) { if (squeeze_dims.count(i) > 0) { // If dim is -1 and been pointed by axis , do think -1 is 1.because no method to do verify if (exist_dim != 1 && exist_dim != UNKNOWN_DIM) { string reason = "Can not squeeze dim[" + std::to_string(i) + "], expected a dimension of 1, got [" + std::to_string(exist_dim) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); GE_OP_LOGE(op.GetName().c_str(), "Can not squeeze dim[%d], expected a dimension of 1, got %lld", i, exist_dim); return GRAPH_FAILED; } } else { out_shape.emplace_back(exist_dim); // after verified, it has ensure x_range ele num is same with dims num if (!x_range.empty()) { y_range.emplace_back(x_range[i]); } } } else { // Copy over all non-1-length dimensions. // here no methed to ensure which -1 is 1, so do warning if (exist_dim != 1) { if (exist_dim == -1) { GE_OP_LOGW("the [%d] dim is -1, it will not execute squeeze on it! 
maybe influence result", exist_dim); } out_shape.emplace_back(exist_dim); // after verified, it has ensure x_range ele num is same with dims num if (!x_range.empty()) { y_range.emplace_back(x_range[i]); } } } } output_desc_y->SetShape(GeShape(out_shape)); output_desc_y->SetOriginShape(GeShape(out_shape)); output_desc_y->SetDataType(x_data_type); if (!y_range.empty()) { output_desc_y->SetShapeRange(y_range); } return GRAPH_SUCCESS; } INFER_FUNC_REG(Squeeze, SqueezeInfer); IMPLEMT_INFERFUNC(Unsqueeze, UnsqueezeInfer) { auto axis_arr = op.get_attr_axes(); auto axis_nums = axis_arr.size(); if (axis_nums <= 0) { string reason = "Axis_nums[" + std::to_string(axis_nums) + "] must be greater than 0"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); GE_OP_LOGE(op.GetName().c_str(), "Axis_nums[%zu] must be greater than 0", axis_nums); return GRAPH_PARAM_INVALID; } std::unordered_set values(axis_arr.begin(), axis_arr.end()); if (values.size() != axis_arr.size()) { string reason = "Axis attribute must not contain any duplicates."; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); GE_OP_LOGE(op.GetName().c_str(), "Axis attribute must not contain any duplicates."); return GRAPH_PARAM_INVALID; } Shape input_shape = op.get_input_desc_x().GetShape(); int64_t dim_num = input_shape.GetDimNum() + axis_nums; std::vector vec_dim(dim_num, 0); for (size_t i = 0; i < axis_nums; i++) { int64_t axis = axis_arr[i]; if ((axis < -dim_num) || (axis > (dim_num - 1))) { string reason = "axis[" + std::to_string(axis_nums) + "]'s range is not in [" + std::to_string(-dim_num) + ", " + std::to_string(dim_num - 1) + "]"; GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); GE_OP_LOGE(op.GetName().c_str(), "Axis %ld not in [%ld, %ld]", axis, -dim_num, dim_num); return GRAPH_PARAM_INVALID; } if (axis < 0) { axis += dim_num; } vec_dim.at(axis) = 1; } int64_t index = 0; for (int64_t i = 0; i < dim_num; i++) { if (vec_dim.at(i) != 1) { 
vec_dim.at(i) = input_shape.GetDim(index); index++; } } TensorDesc td = op.get_output_desc_y(); td.SetShape(Shape(vec_dim)); td.SetDataType(op.get_input_desc_x().GetDataType()); (void)op.update_output_desc_y(td); return GRAPH_SUCCESS; } INFER_FUNC_REG(Unsqueeze, UnsqueezeInfer); IMPLEMT_INFERFUNC(Rank, RankInfer) { OP_LOGI(op.GetName().c_str(), "Rank infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto output_desc_y = op_desc->MutableOutputDesc("y"); std::vector oShapeVector; output_desc_y->SetShape(GeShape(oShapeVector)); output_desc_y->SetOriginShape(GeShape(oShapeVector)); output_desc_y->SetDataType(DT_INT32); OP_LOGI(op.GetName().c_str(), "Rank infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(Rank, RankInfer); IMPLEMT_INFERFUNC(Size, SizeInfer) { OP_LOGI(op.GetName().c_str(), "Size infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto output_desc_y = op_desc->MutableOutputDesc("y"); std::vector oShapeVector; output_desc_y->SetShape(GeShape(oShapeVector)); DataType out_type = DT_INT32; GeAttrValue out_type_value; op_desc->GetAttr("dtype", out_type_value); out_type_value.GetValue(out_type); output_desc_y->SetDataType(out_type); OP_LOGI(op.GetName().c_str(), "Size infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(Size, SizeInfer); COMMON_INFER_FUNC_REG(Data, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); COMMON_INFER_FUNC_REG(PlaceHolder, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); COMMON_INFER_FUNC_REG(End, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); IMPLEMT_INFERFUNC(PlaceholderWithDefault, PlaceholderWithDefaultInfer) { TensorDesc input_desc = op.GetInputDesc("x"); auto dims = input_desc.GetShape().GetDims(); auto data_type = input_desc.GetDataType(); TensorDesc output_desc = op.GetOutputDesc("y"); output_desc.SetDataType(ge::DataType(data_type)); output_desc.SetShape(Shape(dims)); (void)op.UpdateOutputDesc("y", output_desc); return GRAPH_SUCCESS; } INFER_FUNC_REG(PlaceholderWithDefault, 
PlaceholderWithDefaultInfer); IMPLEMT_INFERFUNC(Shape, ShapeInfer) { auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto td = op_desc->MutableOutputDesc("y"); auto input_dims = op_desc->MutableInputDesc("x")->MutableShape().GetDims(); if (input_dims == UNKNOWN_RANK) { td->SetShape(ge::GeShape(UNKNOWN_SHAPE)); td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE)); td->SetShapeRange(std::vector>{{1,kMaxDimNum}}); } else { int64_t size = static_cast(input_dims.size()); std::vector size_v{size}; td->SetShape(ge::GeShape(size_v)); td->SetOriginShape(ge::GeShape(size_v)); } uint32_t out_type = DT_INT32; (void)op.GetAttr("dtype", out_type); td->SetDataType((DataType)out_type); std::vector> inRange; op_desc->MutableInputDesc("x")->GetShapeRange(inRange); if (!inRange.empty()) { std::vector pre_op_range; pre_op_range.resize(2*inRange.size()); for (int i = 0; i < pre_op_range.size(); i = i + 2) { pre_op_range[i] = inRange[i/2].first; pre_op_range[i + 1] = inRange[i/2].second; } ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range); OP_LOGD(op.GetName().c_str(), "Shape op set pre_op_range success"); } return GRAPH_SUCCESS; } INFER_FUNC_REG(Shape, ShapeInfer); IMPLEMT_INFERFUNC(ShapeN, ShapeNInfer) { auto op_desc = OpDescUtils::GetOpDescFromOperator(op); for (size_t i = 0; i < op.GetInputsSize(); i++) { auto td = op_desc->MutableOutputDesc(i); auto input_dims = op_desc->MutableInputDesc(i)->MutableShape().GetDims(); if (input_dims == UNKNOWN_RANK) { td->SetShape(ge::GeShape(UNKNOWN_SHAPE)); td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE)); td->SetShapeRange(std::vector>{{1,kMaxDimNum}}); } else { int64_t size = static_cast(input_dims.size()); GE_OP_LOGD(op.GetName().c_str(), "output value %ld", size); std::vector size_v{size}; td->SetShape(ge::GeShape(size_v)); td->SetOriginShape(ge::GeShape(size_v)); } uint32_t out_type = DT_INT32; (void)op.GetAttr("dtype", out_type); td->SetDataType((DataType)out_type); std::vector> inRange; 
op_desc->MutableInputDesc(i)->GetShapeRange(inRange); if (!inRange.empty()) { std::vector pre_op_range; pre_op_range.resize(2*inRange.size()); for (int i = 0; i < pre_op_range.size(); i = i + 2) { pre_op_range[i] = inRange[i/2].first; pre_op_range[i + 1] = inRange[i/2].second; } ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range); OP_LOGD(op.GetName().c_str(), "ShapeN op set pre_op_range success"); } } return GRAPH_SUCCESS; } INFER_FUNC_REG(ShapeN, ShapeNInfer); IMPLEMT_INFERFUNC(IdentityN, IdentityNInfer) { OP_LOGI(op.GetName().c_str(), "IdentityN infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); for (size_t i = 0; i < op.GetInputsSize(); i++) { auto input_desc = op_desc->MutableInputDesc(i); auto input_dims = input_desc->MutableShape().GetDims(); auto output_desc = op_desc->MutableOutputDesc(i); auto intput_dtype = input_desc->GetDataType(); std::vector> input_range; input_desc->GetShapeRange(input_range); output_desc->SetShape(GeShape(input_dims)); output_desc->SetOriginShape(GeShape(input_dims)); output_desc->SetDataType(intput_dtype); output_desc->SetShapeRange(input_range); } OP_LOGI(op.GetName().c_str(), "IdentityN infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(IdentityN, IdentityNInfer); IMPLEMT_INFERFUNC(Identity, IdentityInfer) { OP_LOGI(op.GetName().c_str(), "Identity infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); auto input_desc_x = op_desc->MutableInputDesc("x"); auto output_desc_y = op_desc->MutableOutputDesc("y"); std::vector vec_dim; vec_dim = input_desc_x->MutableShape().GetDims(); std::vector> x_range; input_desc_x->GetShapeRange(x_range); DataType data_type = input_desc_x->GetDataType(); output_desc_y->SetDataType(data_type); output_desc_y->SetShape(GeShape(vec_dim)); output_desc_y->SetOriginShape(GeShape(vec_dim)); output_desc_y->SetShapeRange(x_range); OP_LOGI(op.GetName().c_str(), "Identity infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(Identity, 
IdentityInfer); IMPLEMT_INFERFUNC(ReadVariableOp, ReadVariableOpInfer) { TensorDesc input_desc = op.GetInputDesc("x"); (void)op.UpdateOutputDesc("y", input_desc); return GRAPH_SUCCESS; } INFER_FUNC_REG(ReadVariableOp, ReadVariableOpInfer); template static void CaclDims(const Tensor& data, std::vector& vec_dim) { int32_t size = data.GetSize() / sizeof(T); for (int32_t i = 0; i < size; i++) { T dim = *((T*)data.GetData() + i); if (dim != 0) { vec_dim.push_back(dim); } else { vec_dim.clear(); break; } } } template static void CaclDims(const GeTensorPtr& data, std::vector& vec_dim) { int32_t size = data->GetData().GetSize() / sizeof(T); for (int32_t i = 0; i < size; i++) { void* data_ptr = (void*)data->GetData().GetData(); if (data_ptr == nullptr) { return; } T dim = *((T*)data_ptr + i); if (dim != 0) { vec_dim.push_back(dim); } else { vec_dim.clear(); break; } } } IMPLEMT_INFERFUNC(Empty, EmptyInfer) { OP_LOGI(op.GetName().c_str(), "Empty infershape start"); auto op_desc = OpDescUtils::GetOpDescFromOperator(op); std::vector dep_inputs = {"shape"}; op_desc->SetOpInferDepends(dep_inputs); auto input_desc_shape = op_desc->MutableInputDesc("shape"); auto output_desc_y = op_desc->MutableOutputDesc("y"); auto dtype = op.get_attr_dtype(); std::vector> shape_range; std::vector> y_range; input_desc_shape->GetShapeRange(shape_range); DataType data_type = input_desc_shape->GetDataType(); std::vector vec_dim; if (data_type == DT_INT32) { vec_dim = input_desc_shape->MutableShape().GetDims(); } else { GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "Empty only support shape type 'DT_INT32'"); GE_OP_LOGE(op.GetName().c_str(), "Empty only support shape type 'DT_INT32'"); return GRAPH_PARAM_INVALID; } if (vec_dim == UNKNOWN_RANK) { GE_OP_LOGD(op.GetName().c_str(), "all inputs are unknown rank!"); output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetDataType((DataType)dtype); return GRAPH_SUCCESS; } if 
(vec_dim == UNKNOWN_SHAPE) { GE_OP_LOGD(op.GetName().c_str(), "shape is unknown shape!"); std::pair pair({1, shape_range.size()}); y_range.emplace_back(pair); output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetDataType((DataType)dtype); output_desc_y->SetShapeRange(y_range); return GRAPH_SUCCESS; } auto node = NodeUtils::GetNodeFromOperator(op); if (node == nullptr) { OP_LOGE(op.GetName().c_str(), "Get null node ptr."); return GRAPH_PARAM_INVALID; } GeTensorPtr shape_data; std::vector shape_dims; auto result = NodeUtils::GetInputConstData(node, "shape", shape_data); if(result == GRAPH_SUCCESS) { DataType data_type = shape_data->GetTensorDesc().GetDataType(); if (data_type == DT_INT32) { CaclDims(shape_data,shape_dims); } else if (data_type == DT_INT64) { CaclDims(shape_data, shape_dims); } OP_LOGD(op.GetName().c_str(), "Get input const data success."); std::pair pair({1,shape_range.size()}); y_range.emplace_back(pair); output_desc_y->SetShape(GeShape(shape_dims)); output_desc_y->SetOriginShape(GeShape(shape_dims)); output_desc_y->SetDataType((DataType)dtype); output_desc_y->SetShapeRange(y_range); return GRAPH_SUCCESS; } else { OP_LOGD(op.GetName().c_str(), "Get input const data failed!"); std::pair pair({1,shape_range.size()}); y_range.emplace_back(pair); output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); output_desc_y->SetDataType((DataType)dtype); output_desc_y->SetShapeRange(y_range); return GRAPH_SUCCESS; } output_desc_y->SetShape(GeShape(vec_dim)); output_desc_y->SetOriginShape(GeShape(vec_dim)); output_desc_y->SetDataType((DataType)dtype); OP_LOGD(op.GetName().c_str(), "Empty infershape end"); return GRAPH_SUCCESS; } INFER_FUNC_REG(Empty, EmptyInfer); IMPLEMT_INFERFUNC(Where, WhereInfer) { OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op); GeTensorDescPtr x_desc = op_desc->MutableInputDesc(0); GeShape x_shape; if 
(WithRankAtLeast(x_desc, 1, x_shape) != GRAPH_SUCCESS) { OP_LOGE(op.GetName().c_str(), "input x must be at least 1D."); return GRAPH_FAILED; } if (WithRankAtMost(x_desc, 5, x_shape) != GRAPH_SUCCESS) { OP_LOGE(op.GetName().c_str(), "input x must be at most 5D."); return GRAPH_FAILED; } GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0); y_desc->SetDataType(DT_INT64); vector y_shape; auto input_dims = x_shape.GetDims(); int64_t input_shape_size = x_shape.GetShapeSize(); if (input_shape_size != UNKNOWN_DIM) { // input shape: known y_shape.push_back(UNKNOWN_DIM); y_shape.push_back(input_dims.size()); std::vector> range; int64_t dims_num = x_shape.GetDimNum(); range.emplace_back(std::make_pair(1, input_shape_size)); range.emplace_back(std::make_pair(dims_num, dims_num)); y_desc->SetShapeRange(range); } else { if (input_dims == UNKNOWN_RANK) { // input shape: unknown rank y_shape.push_back(UNKNOWN_DIM); y_shape.push_back(UNKNOWN_DIM); } else { // input shape: unknown dims y_shape.push_back(UNKNOWN_DIM); y_shape.push_back(input_dims.size()); } } y_desc->SetShape(GeShape(y_shape)); y_desc->SetOriginShape(GeShape(y_shape)); return GRAPH_SUCCESS; } INFER_FUNC_REG(Where, WhereInfer); IMPLEMT_INFERFUNC(TransShape, TransShapeInfer) { TensorDesc y_desc = op.GetOutputDesc("y"); vector output_shape; auto ret = op.GetAttr("outShape", output_shape); if (ret != GRAPH_SUCCESS) { OP_LOGE(op.GetName().c_str(), "Failed to get attribute value."); return GRAPH_SUCCESS; } y_desc.SetShape(Shape(output_shape)); if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) { return GRAPH_FAILED; } return GRAPH_SUCCESS; } INFER_FUNC_REG(TransShape, TransShapeInfer); // ----------------SortV2 Begin------------------- IMPLEMT_INFERFUNC(SortV2, SortV2InferShape) { TensorDesc tensordesc_input = op.GetInputDesc("x"); Shape input_shape = tensordesc_input.GetShape(); DataType input_dtype = tensordesc_input.GetDataType(); std::vector dims_input = input_shape.GetDims(); TensorDesc tensordesc_output1 = 
op.GetOutputDesc("y"); tensordesc_output1.SetShape(ge::Shape(dims_input)); tensordesc_output1.SetDataType(input_dtype); (void)op.UpdateOutputDesc("y", tensordesc_output1); return GRAPH_SUCCESS; } IMPLEMT_VERIFIER(SortV2, SortV2Verify) { return GRAPH_SUCCESS; } INFER_FUNC_REG(SortV2, SortV2InferShape); VERIFY_FUNC_REG(SortV2, SortV2Verify); // ----------------SortV2 END--------------------- // ----------------Expand Begin------------------- template static bool ExpandCalDim(const Tensor &data, std::vector &vec_dim, std::vector &vec_x) { uint32_t size_shape = data.GetSize() / sizeof(T); uint32_t size_x = vec_x.size(); if (size_shape < size_x) { uint32_t diff = size_x - size_shape; for (int32_t i = 0; i < size_x; i++) { if (i < diff) { vec_dim.push_back(vec_x[i]); } else { T dim = *((T *)data.GetData() + (i - diff)); if ((vec_x[i] != dim) && (vec_x[i] != 1) && (dim != 1)) { return false; } if (vec_x[i] > dim) { vec_dim.push_back(vec_x[i]); } else { vec_dim.push_back(dim); } } } } else { uint32_t diff = size_shape - size_x; for (int32_t i = 0; i < size_shape; i++) { T dim = *((T *)data.GetData() + i); if (i < diff) { vec_dim.push_back(dim); } else { if ((vec_x[i - diff] != dim) && (vec_x[i-diff] != 1) && (dim != 1)) { return false; } if (vec_x[i - diff] > dim) { vec_dim.push_back(vec_x[i - diff]); } else { vec_dim.push_back(dim); } } } } return true; } IMPLEMT_COMMON_INFERFUNC(ExpandInferShape) { Shape x_shape = op.GetInputDesc("x").GetShape(); DataType x_dtype = op.GetInputDesc("x").GetDataType(); std::vector dims_x = x_shape.GetDims(); Tensor data; std::vector vec_dim; TensorDesc td = op.GetOutputDesc("y"); if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) { OP_LOGE(op.GetName().c_str(), "Get constValue failed of [shape]"); return GRAPH_FAILED; } else { DataType data_type = data.GetTensorDesc().GetDataType(); std::vector vec_dim; if (data_type == DT_INT32) { if (!ExpandCalDim (data, vec_dim, dims_x)) { OP_LOGE(op.GetName().c_str(), "Data shape are not 
compatible!"); return GRAPH_FAILED; } } else if (data_type == DT_INT64) { if (!ExpandCalDim (data, vec_dim, dims_x)) { OP_LOGE(op.GetName().c_str(), "Data shape are not compatible!"); return GRAPH_FAILED; } } else { OP_LOGE(op.GetName().c_str(), "Data type not supported!"); return GRAPH_PARAM_INVALID; } td.SetShape(ge::Shape(vec_dim)); td.SetDataType(x_dtype); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; } } COMMON_INFER_FUNC_REG(Expand, ExpandInferShape); // ----------------Expand END--------------------- // ----------------ExpandD Begin------------------- IMPLEMT_COMMON_INFERFUNC(ExpandDInferShape) { Shape x_shape = op.GetInputDesc("x").GetShape(); DataType x_dtype = op.GetInputDesc("x").GetDataType(); std::vector shape; op.GetAttr("shape", shape); std::vector dims_x = x_shape.GetDims(); TensorDesc td = op.GetOutputDesc("y"); std::vector dim_vec; if (shape.size() < dims_x.size()) { std::vector dims_tmp = shape; shape = dims_x; dims_x = dims_tmp; } if (shape.size() != dims_x.size()) { int dec = shape.size() - dims_x.size(); for (int i = 0; i < dec; i++) { dims_x.insert(dims_x.begin(), (int64_t)1); } } for (size_t i = 0; i < shape.size(); i++) { if ((shape[i] != dims_x[i]) && (shape[i] != 1) && (dims_x[i] != 1)) { OP_LOGE(op.GetName().c_str(), "The input shape and attr shape are not compatible."); return GRAPH_FAILED; } if (shape[i] > dims_x[i]) { dim_vec.push_back(shape[i]); } else { dim_vec.push_back(dims_x[i]); } } td.SetShape(ge::Shape(dim_vec)); td.SetDataType(x_dtype); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; } COMMON_INFER_FUNC_REG(ExpandD, ExpandDInferShape); // ----------------Expand END--------------------- } // namespace ge