diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 4847daf9..2dcf1cdb 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -162,6 +162,7 @@ set(TRAIN_SRC_LIST "host_kernels/broadcast_args_kernel.cc" "host_kernels/broadcast_gradient_args_kernel.cc" "host_kernels/cast_kernel.cc" + "host_kernels/constant_of_shape_kernel.cc" "host_kernels/concat_offset_kernel.cc" "host_kernels/concat_v2_kernel.cc" "host_kernels/dynamic_stitch_kernel.cc" @@ -514,6 +515,7 @@ set(INFER_SRC_LIST "host_kernels/rank_kernel.cc" "host_kernels/broadcast_args_kernel.cc" "host_kernels/fill_kernel.cc" + "host_kernels/constant_of_shape_kernel.cc" "host_kernels/empty_kernel.cc" "host_kernels/expanddims_kernel.cc" "host_kernels/reshape_kernel.cc" diff --git a/ge/common/types.cc b/ge/common/types.cc index de293d34..c6ba69f1 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -293,6 +293,7 @@ REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DEFINE(INITDATA, "InitData"); REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity"); REGISTER_OPTYPE_DEFINE(BITCAST, "Bitcast"); +REGISTER_OPTYPE_DEFINE(CONSTANTOFSHAPE, "ConstantOfShape"); /***************Ann special operator*************************/ REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean"); diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index badca5a3..2affe903 100644 --- a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -62,5 +62,6 @@ REGISTER_OP_CREATOR(RefMerge, GeDeletedOp); REGISTER_OP_CREATOR(RefSwitch, GeDeletedOp); REGISTER_OP_CREATOR(TransShape, GeDeletedOp); REGISTER_OP_CREATOR(Bitcast, GeDeletedOp); +REGISTER_OP_CREATOR(ConstantOfShape, GeDeletedOp); } // namespace ge_local } // namespace ge diff --git a/ge/host_kernels/constant_of_shape_kernel.cc b/ge/host_kernels/constant_of_shape_kernel.cc new file mode 100644 index 00000000..95faf482 --- /dev/null +++ b/ge/host_kernels/constant_of_shape_kernel.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_kernels/constant_of_shape_kernel.h" + +#include +#include + +#include "common/debug/log.h" +#include "common/fp16_t.h" +#include "common/op/ge_op_utils.h" +#include "framework/common/debug/ge_log.h" +#include "host_kernels/kernel_utils.h" +#include "graph/utils/type_utils.h" +#include "inc/kernel_factory.h" +#include "framework/common/types.h" + +namespace ge { +namespace { +const size_t kAddInputSize = 1; +const size_t kAddOutputSize = 1; +} // namespace + +Status ConstantOfShapeKernel::CheckParam( + const OpDescPtr &op_desc_ptr, const std::vector &input) { + + GE_CHECK_NOTNULL(op_desc_ptr); + // check how many inputs + if ((input.size() != kAddInputSize) + || (op_desc_ptr->GetOutputsSize() != kAddOutputSize)) { + GELOGE(PARAM_INVALID, + "Input number must be [%zu], output number must be [%zu].", + kAddInputSize, kAddOutputSize); + return PARAM_INVALID; + } + GE_CHECK_NOTNULL(input[0]); + // input vector elements must be 1-D or empty tensor + auto input_shape = input[0]->GetTensorDesc().GetShape(); + if (input_shape.GetShapeSize() == 0) { + GELOGD("Input 0 is empty tensor."); + } else { + size_t dim_num = input_shape.GetDimNum(); + if (dim_num != 1) { + GELOGE(PARAM_INVALID, "Shape input must be a 1-D tensor, but got [%zu]", + dim_num); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +Status ConstantOfShapeKernel::Compute(const ge::OpDescPtr op_desc_ptr, + const vector &input, + vector &v_output) { + GELOGI("ConstantOfShapeKernel in."); + if (CheckParam(op_desc_ptr, input) != SUCCESS) { + return NOT_CHANGED; + } + + GeShape out_shape; + int64_t fill_size = 1; + Status ret = PARAM_INVALID; + + ConstGeTensorPtr value = MakeShared(); + GE_CHECK_NOTNULL(value); + GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(0)); + GE_CHECK_NOTNULL(output_ptr); + + ConstGeTensorPtr dims = input.at(0); + if (dims->GetData().size() == 0) { + GELOGI("Input 0 is empty tensor, then output 0 is a scalar."); + out_shape = GeShape(); + } else { + std::vector vec_dim; + GE_RETURN_IF_ERROR(KernelUtils::CalcDims(dims, vec_dim, fill_size)); + out_shape = GeShape(vec_dim); + } + + DataType data_type; + if (!AttrUtils::GetTensor(op_desc_ptr, "value", value)) { + GELOGE(FAILED, "Get Attr value failed."); + return FAILED; + } + data_type = value->GetTensorDesc().GetDataType(); + + ret = PARAM_INVALID; + switch (data_type) { +#define CASE(dtype, type) \ + case dtype: \ + ret = KernelUtils::GenData( \ + fill_size, \ + *reinterpret_cast(value->GetData().data()), \ + output_ptr); \ + break; + CASE(DT_FLOAT, float) + CASE(DT_FLOAT16, fp16_t) + CASE(DT_INT8, int8_t) + CASE(DT_INT16, int16_t) + CASE(DT_UINT16, uint16_t) + CASE(DT_UINT8, uint8_t) + CASE(DT_INT32, int32_t) + CASE(DT_INT64, int64_t) + CASE(DT_UINT32, uint32_t) + CASE(DT_UINT64, uint64_t) + CASE(DT_BOOL, bool) + CASE(DT_DOUBLE, double) +#undef CASE + default: + GELOGE(PARAM_INVALID, "Invalid data type: [%s]", + TypeUtils::DataTypeToSerialString(data_type).c_str()); + break; + } + if (ret != SUCCESS) { + GELOGE(ret, "GenData failed, data_type: [%s]", + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return ret; + } + output_ptr->MutableTensorDesc().SetShape(out_shape); + output_ptr->MutableTensorDesc().SetDataType(data_type); + v_output.push_back(output_ptr); + + GELOGI("ConstantOfShapeKernel success."); + return SUCCESS; +} + +REGISTER_KERNEL(CONSTANTOFSHAPE, ConstantOfShapeKernel); +} // namespace ge diff --git a/ge/host_kernels/constant_of_shape_kernel.h b/ge/host_kernels/constant_of_shape_kernel.h new file mode 100644 index 00000000..fccc6789 --- /dev/null +++ b/ge/host_kernels/constant_of_shape_kernel.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_ +#define GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_ + +#include + +#include "inc/kernel.h" + +namespace ge { +class ConstantOfShapeKernel : public Kernel { + public: + Status Compute(const OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) override; + + private: + Status CheckParam(const OpDescPtr &op_desc_ptr, + const std::vector &input); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_ diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index ad284d07..6467648a 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -340,6 +340,7 @@ REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast"); +REGISTER_OPTYPE_DECLARE(CONSTANTOFSHAPE, "ConstantOfShape"); // ANN dedicated operator REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean");