Browse Source

add op constantofshape

pull/553/head
lilei 4 years ago
parent
commit
eaa36c48d7
6 changed files with 181 additions and 0 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/common/types.cc
  3. +1
    -0
      ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc
  4. +140
    -0
      ge/host_kernels/constant_of_shape_kernel.cc
  5. +36
    -0
      ge/host_kernels/constant_of_shape_kernel.h
  6. +1
    -0
      inc/framework/common/types.h

+ 2
- 0
ge/CMakeLists.txt View File

@@ -162,6 +162,7 @@ set(TRAIN_SRC_LIST
"host_kernels/broadcast_args_kernel.cc"
"host_kernels/broadcast_gradient_args_kernel.cc"
"host_kernels/cast_kernel.cc"
"host_kernels/constant_of_shape_kernel.cc"
"host_kernels/concat_offset_kernel.cc"
"host_kernels/concat_v2_kernel.cc"
"host_kernels/dynamic_stitch_kernel.cc"
@@ -514,6 +515,7 @@ set(INFER_SRC_LIST
"host_kernels/rank_kernel.cc"
"host_kernels/broadcast_args_kernel.cc"
"host_kernels/fill_kernel.cc"
"host_kernels/constant_of_shape_kernel.cc"
"host_kernels/empty_kernel.cc"
"host_kernels/expanddims_kernel.cc"
"host_kernels/reshape_kernel.cc"


+ 1
- 0
ge/common/types.cc View File

@@ -293,6 +293,7 @@ REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext");
REGISTER_OPTYPE_DEFINE(INITDATA, "InitData");
REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity");
REGISTER_OPTYPE_DEFINE(BITCAST, "Bitcast");
REGISTER_OPTYPE_DEFINE(CONSTANTOFSHAPE, "ConstantOfShape");

/***************Ann special operator*************************/
REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean");


+ 1
- 0
ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc View File

@@ -62,5 +62,6 @@ REGISTER_OP_CREATOR(RefMerge, GeDeletedOp);
REGISTER_OP_CREATOR(RefSwitch, GeDeletedOp);
REGISTER_OP_CREATOR(TransShape, GeDeletedOp);
REGISTER_OP_CREATOR(Bitcast, GeDeletedOp);
REGISTER_OP_CREATOR(ConstantOfShape, GeDeletedOp);
} // namespace ge_local
} // namespace ge

+ 140
- 0
ge/host_kernels/constant_of_shape_kernel.cc View File

@@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "host_kernels/constant_of_shape_kernel.h"

#include <memory>
#include <set>

#include "common/debug/log.h"
#include "common/fp16_t.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"
#include "framework/common/types.h"

namespace ge {
namespace {
const size_t kAddInputSize = 1;
const size_t kAddOutputSize = 1;
} // namespace

Status ConstantOfShapeKernel::CheckParam(
const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input) {

GE_CHECK_NOTNULL(op_desc_ptr);
// check how many inputs
if ((input.size() != kAddInputSize)
|| (op_desc_ptr->GetOutputsSize() != kAddOutputSize)) {
GELOGE(PARAM_INVALID,
"Input number must be [%zu], output number must be [%zu].",
kAddInputSize, kAddOutputSize);
return PARAM_INVALID;
}
GE_CHECK_NOTNULL(input[0]);
// input vector elements must be 1-D or empty tensor
auto input_shape = input[0]->GetTensorDesc().GetShape();
if (input_shape.GetShapeSize() == 0) {
GELOGD("Input 0 is empty tensor.");
} else {
size_t dim_num = input_shape.GetDimNum();
if (dim_num != 1) {
GELOGE(PARAM_INVALID, "Shape input must be a 1-D tensor, but got [%zu]",
dim_num);
return PARAM_INVALID;
}
}
return SUCCESS;
}

Status ConstantOfShapeKernel::Compute(const ge::OpDescPtr op_desc_ptr,
const vector<ge::ConstGeTensorPtr> &input,
vector<ge::GeTensorPtr> &v_output) {
GELOGI("ConstantOfShapeKernel in.");
if (CheckParam(op_desc_ptr, input) != SUCCESS) {
return NOT_CHANGED;
}

GeShape out_shape;
int64_t fill_size = 1;
Status ret = PARAM_INVALID;

ConstGeTensorPtr value = MakeShared<const GeTensor>();
GE_CHECK_NOTNULL(value);
GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
GE_CHECK_NOTNULL(output_ptr);

ConstGeTensorPtr dims = input.at(0);
if (dims->GetData().size() == 0) {
GELOGI("Input 0 is empty tensor, then output 0 is a scalar.");
out_shape = GeShape();
} else {
std::vector<int64_t> vec_dim;
GE_RETURN_IF_ERROR(KernelUtils::CalcDims<int64_t>(dims, vec_dim, fill_size));
out_shape = GeShape(vec_dim);
}

DataType data_type;
if (!AttrUtils::GetTensor(op_desc_ptr, "value", value)) {
GELOGE(FAILED, "Get Attr value failed.");
return FAILED;
}
data_type = value->GetTensorDesc().GetDataType();

ret = PARAM_INVALID;
switch (data_type) {
#define CASE(dtype, type) \
case dtype: \
ret = KernelUtils::GenData( \
fill_size, \
*reinterpret_cast<const type *>(value->GetData().data()), \
output_ptr); \
break;
CASE(DT_FLOAT, float)
CASE(DT_FLOAT16, fp16_t)
CASE(DT_INT8, int8_t)
CASE(DT_INT16, int16_t)
CASE(DT_UINT16, uint16_t)
CASE(DT_UINT8, uint8_t)
CASE(DT_INT32, int32_t)
CASE(DT_INT64, int64_t)
CASE(DT_UINT32, uint32_t)
CASE(DT_UINT64, uint64_t)
CASE(DT_BOOL, bool)
CASE(DT_DOUBLE, double)
#undef CASE
default:
GELOGE(PARAM_INVALID, "Invalid data type: [%s]",
TypeUtils::DataTypeToSerialString(data_type).c_str());
break;
}
if (ret != SUCCESS) {
GELOGE(ret, "GenData failed, data_type: [%s]",
TypeUtils::DataTypeToSerialString(data_type).c_str());
return ret;
}
output_ptr->MutableTensorDesc().SetShape(out_shape);
output_ptr->MutableTensorDesc().SetDataType(data_type);
v_output.push_back(output_ptr);

GELOGI("ConstantOfShapeKernel success.");
return SUCCESS;
}

REGISTER_KERNEL(CONSTANTOFSHAPE, ConstantOfShapeKernel);
} // namespace ge

+ 36
- 0
ge/host_kernels/constant_of_shape_kernel.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_
#define GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_

#include <vector>

#include "inc/kernel.h"

namespace ge {
class ConstantOfShapeKernel : public Kernel {
public:
Status Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
std::vector<GeTensorPtr> &v_output) override;

private:
Status CheckParam(const OpDescPtr &op_desc_ptr,
const std::vector<ConstGeTensorPtr> &input);
};
} // namespace ge

#endif // GE_GRAPH_PASSES_FOLDING_KERNELCONSTANT_OF_SHAPE_KERNEL_H_

+ 1
- 0
inc/framework/common/types.h View File

@@ -340,6 +340,7 @@ REGISTER_OPTYPE_DECLARE(INITDATA, "InitData");
REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape")
REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity");
REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast");
REGISTER_OPTYPE_DECLARE(CONSTANTOFSHAPE, "ConstantOfShape");

// ANN dedicated operator
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean");


Loading…
Cancel
Save